Zen Stores
zenml.zen_stores
special
ZenStores define ways to store ZenML-relevant data locally or remotely.
base_zen_store
Base Zen Store implementation.
BaseZenStore (BaseModel, ZenStoreInterface, ABC)
Base class for accessing and persisting ZenML core objects.
Attributes:
| Name | Type | Description |
|---|---|---|
| config | StoreConfiguration | The configuration of the store. |
Source code in zenml/zen_stores/base_zen_store.py
class BaseZenStore(
BaseModel,
ZenStoreInterface,
ABC,
):
"""Base class for accessing and persisting ZenML core objects.
Attributes:
config: The configuration of the store.
"""
config: StoreConfiguration
TYPE: ClassVar[StoreType]
CONFIG_TYPE: ClassVar[Type[StoreConfiguration]]
@model_validator(mode="before")
@classmethod
@before_validator_handler
def convert_config(cls, data: Dict[str, Any]) -> Dict[str, Any]:
"""Method to infer the correct type of the config and convert.
Args:
data: The provided configuration object, can potentially be a
generic object
Raises:
ValueError: If the provided config object's type does not match
any of the current implementations.
Returns:
The converted configuration object.
"""
if data["config"].type == StoreType.SQL:
from zenml.zen_stores.sql_zen_store import SqlZenStoreConfiguration
data["config"] = SqlZenStoreConfiguration(
**data["config"].model_dump()
)
elif data["config"].type == StoreType.REST:
from zenml.zen_stores.rest_zen_store import (
RestZenStoreConfiguration,
)
data["config"] = RestZenStoreConfiguration(
**data["config"].model_dump()
)
else:
raise ValueError(
f"Unknown type '{data['config'].type}' for the configuration."
)
return data
# ---------------------------------
# Initialization and configuration
# ---------------------------------
def __init__(
self,
skip_default_registrations: bool = False,
**kwargs: Any,
) -> None:
"""Create and initialize a store.
Args:
skip_default_registrations: If `True`, the creation of the default
stack and user in the store will be skipped.
**kwargs: Additional keyword arguments to pass to the Pydantic
constructor.
Raises:
RuntimeError: If the store cannot be initialized.
AuthorizationException: If the store cannot be initialized due to
authentication errors.
"""
super().__init__(**kwargs)
try:
self._initialize()
# Handle cases where the ZenML server is not available
except ConnectionError as e:
error_message = (
"Cannot connect to the ZenML database because the ZenML server "
f"at {self.url} is not running."
)
if urlparse(self.url).hostname in ["localhost", "127.0.0.1"]:
recommendation = (
"Please run `zenml down` and `zenml up` to restart the "
"server."
)
else:
recommendation = (
"Please run `zenml disconnect` and `zenml connect --url "
f"{self.url}` to reconnect to the server."
)
raise RuntimeError(f"{error_message}\n{recommendation}") from e
except AuthorizationException as e:
raise AuthorizationException(
f"Authorization failed for store at '{self.url}'. Please check "
f"your credentials: {str(e)}"
)
except Exception as e:
zenml_pro_extra = ""
if ".zenml.io" in self.url:
zenml_pro_extra = (
ZENML_PRO_CONNECTION_ISSUES_SUSPENDED_PAUSED_TENANT_HINT
)
raise RuntimeError(
f"Error initializing {self.type.value} store with URL "
f"'{self.url}': {str(e)}" + zenml_pro_extra
) from e
if not skip_default_registrations:
logger.debug("Initializing database")
self._initialize_database()
else:
logger.debug("Skipping database initialization")
@staticmethod
def get_store_class(store_type: StoreType) -> Type["BaseZenStore"]:
"""Returns the class of the given store type.
Args:
store_type: The type of the store to get the class for.
Returns:
The class of the given store type.
Raises:
TypeError: If the store type is unsupported.
"""
if store_type == StoreType.SQL:
from zenml.zen_stores.sql_zen_store import SqlZenStore
return SqlZenStore
elif store_type == StoreType.REST:
from zenml.zen_stores.rest_zen_store import RestZenStore
return RestZenStore
else:
raise TypeError(
f"No store implementation found for store type "
f"`{store_type.value}`."
)
@staticmethod
def get_store_config_class(
store_type: StoreType,
) -> Type["StoreConfiguration"]:
"""Returns the store config class of the given store type.
Args:
store_type: The type of the store to get the class for.
Returns:
The config class of the given store type.
"""
store_class = BaseZenStore.get_store_class(store_type)
return store_class.CONFIG_TYPE
@staticmethod
def get_store_type(url: str) -> StoreType:
"""Returns the store type associated with a URL schema.
Args:
url: The store URL.
Returns:
The store type associated with the supplied URL schema.
Raises:
TypeError: If no store type was found to support the supplied URL.
"""
from zenml.zen_stores.rest_zen_store import RestZenStoreConfiguration
from zenml.zen_stores.sql_zen_store import SqlZenStoreConfiguration
if SqlZenStoreConfiguration.supports_url_scheme(url):
return StoreType.SQL
elif RestZenStoreConfiguration.supports_url_scheme(url):
return StoreType.REST
else:
raise TypeError(f"No store implementation found for URL: {url}.")
@staticmethod
def create_store(
config: StoreConfiguration,
skip_default_registrations: bool = False,
**kwargs: Any,
) -> "BaseZenStore":
"""Create and initialize a store from a store configuration.
Args:
config: The store configuration to use.
skip_default_registrations: If `True`, the creation of the default
stack and user in the store will be skipped.
**kwargs: Additional keyword arguments to pass to the store class
Returns:
The initialized store.
"""
logger.debug(f"Creating store with config '{config}'...")
store_class = BaseZenStore.get_store_class(config.type)
store = store_class(
config=config,
skip_default_registrations=skip_default_registrations,
**kwargs,
)
return store
@staticmethod
def get_default_store_config(path: str) -> StoreConfiguration:
"""Get the default store configuration.
The default store is a SQLite store that saves the DB contents on the
local filesystem.
Args:
path: The local path where the store DB will be stored.
Returns:
The default store configuration.
"""
from zenml.zen_stores.sql_zen_store import SqlZenStoreConfiguration
config = SqlZenStoreConfiguration(
type=StoreType.SQL,
url=SqlZenStoreConfiguration.get_local_url(path),
secrets_store=SqlSecretsStoreConfiguration(
type=SecretsStoreType.SQL,
),
)
return config
def _initialize_database(self) -> None:
"""Initialize the database on first use."""
@property
def url(self) -> str:
"""The URL of the store.
Returns:
The URL of the store.
"""
return self.config.url
@property
def type(self) -> StoreType:
"""The type of the store.
Returns:
The type of the store.
"""
return self.TYPE
def validate_active_config(
self,
active_workspace_name_or_id: Optional[Union[str, UUID]] = None,
active_stack_id: Optional[UUID] = None,
config_name: str = "",
) -> Tuple[WorkspaceResponse, StackResponse]:
"""Validate the active configuration.
Call this method to validate the supplied active workspace and active
stack values.
This method is guaranteed to return valid workspace ID and stack ID
values. If the supplied workspace and stack are not set or are not valid
(e.g. they do not exist or are not accessible), the default workspace and
default workspace stack will be returned in their stead.
Args:
active_workspace_name_or_id: The name or ID of the active workspace.
active_stack_id: The ID of the active stack.
config_name: The name of the configuration to validate (used in the
displayed logs/messages).
Returns:
A tuple containing the active workspace and active stack.
"""
active_workspace: WorkspaceResponse
if active_workspace_name_or_id:
try:
active_workspace = self.get_workspace(
active_workspace_name_or_id
)
except KeyError:
active_workspace = self._get_default_workspace()
logger.warning(
f"The current {config_name} active workspace is no longer "
f"available. Resetting the active workspace to "
f"'{active_workspace.name}'."
)
else:
active_workspace = self._get_default_workspace()
logger.info(
f"Setting the {config_name} active workspace "
f"to '{active_workspace.name}'."
)
active_stack: StackResponse
# Sanitize the active stack
if active_stack_id:
# Ensure that the active stack is still valid
try:
active_stack = self.get_stack(stack_id=active_stack_id)
except KeyError:
logger.warning(
"The current %s active stack is no longer available. "
"Resetting the active stack to default.",
config_name,
)
active_stack = self._get_default_stack(
workspace_id=active_workspace.id
)
else:
if active_stack.workspace.id != active_workspace.id:
logger.warning(
"The current %s active stack is not part of the active "
"workspace. Resetting the active stack to default.",
config_name,
)
active_stack = self._get_default_stack(
workspace_id=active_workspace.id
)
else:
logger.warning(
"Setting the %s active stack to default.",
config_name,
)
active_stack = self._get_default_stack(
workspace_id=active_workspace.id
)
return active_workspace, active_stack
def get_store_info(self) -> ServerModel:
"""Get information about the store.
Returns:
Information about the store.
"""
from zenml.zen_stores.sql_zen_store import SqlZenStore
server_config = ServerConfiguration.get_server_config()
deployment_type = server_config.deployment_type
auth_scheme = server_config.auth_scheme
metadata = server_config.metadata
secrets_store_type = SecretsStoreType.NONE
if isinstance(self, SqlZenStore) and self.config.secrets_store:
secrets_store_type = self.config.secrets_store.type
use_legacy_dashboard = server_config.use_legacy_dashboard
return ServerModel(
id=GlobalConfiguration().user_id,
active=True,
version=zenml.__version__,
deployment_type=deployment_type,
database_type=ServerDatabaseType.OTHER,
debug=IS_DEBUG_ENV,
secrets_store_type=secrets_store_type,
auth_scheme=auth_scheme,
server_url=server_config.server_url or "",
dashboard_url=server_config.dashboard_url or "",
analytics_enabled=GlobalConfiguration().analytics_opt_in,
metadata=metadata,
use_legacy_dashboard=use_legacy_dashboard,
)
def is_local_store(self) -> bool:
"""Check if the store is local or connected to a local ZenML server.
Returns:
True if the store is local, False otherwise.
"""
return self.get_store_info().is_local()
# -----------------------------
# Default workspaces and stacks
# -----------------------------
@property
def _default_workspace_name(self) -> str:
"""Get the default workspace name.
Returns:
The default workspace name.
"""
return os.getenv(
ENV_ZENML_DEFAULT_WORKSPACE_NAME, DEFAULT_WORKSPACE_NAME
)
def _get_default_workspace(self) -> WorkspaceResponse:
"""Get the default workspace.
Raises:
KeyError: If the default workspace doesn't exist.
Returns:
The default workspace.
"""
try:
return self.get_workspace(self._default_workspace_name)
except KeyError:
raise KeyError("Unable to find default workspace.")
def _get_default_stack(
self,
workspace_id: UUID,
) -> StackResponse:
"""Get the default stack for a user in a workspace.
Args:
workspace_id: ID of the workspace.
Returns:
The default stack in the workspace.
Raises:
KeyError: if the workspace or default stack doesn't exist.
"""
default_stacks = self.list_stacks(
StackFilter(
workspace_id=workspace_id,
name=DEFAULT_STACK_AND_COMPONENT_NAME,
)
)
if default_stacks.total == 0:
raise KeyError(
f"No default stack found in workspace {workspace_id}."
)
return default_stacks.items[0]
def get_external_user(self, user_id: UUID) -> UserResponse:
"""Get a user by external ID.
Args:
user_id: The external ID of the user.
Returns:
The user with the supplied external ID.
Raises:
KeyError: If the user doesn't exist.
"""
users = self.list_users(UserFilter(external_user_id=user_id))
if users.total == 0:
raise KeyError(f"User with external ID '{user_id}' not found.")
return users.items[0]
model_config = ConfigDict(
# Validate attributes when assigning them. We need to set this in order
# to have a mix of mutable and immutable attributes
validate_assignment=True,
# Ignore extra attributes from configs of previous ZenML versions
extra="ignore",
)
type: StoreType
property
readonly
The type of the store.
Returns:
| Type | Description |
|---|---|
| StoreType | The type of the store. |
url: str
property
readonly
The URL of the store.
Returns:
| Type | Description |
|---|---|
| str | The URL of the store. |
__init__(self, skip_default_registrations=False, **kwargs)
special
Create and initialize a store.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| skip_default_registrations | bool | If `True`, the creation of the default stack and user in the store will be skipped. | False |
| **kwargs | Any | Additional keyword arguments to pass to the Pydantic constructor. | {} |
Exceptions:
| Type | Description |
|---|---|
| RuntimeError | If the store cannot be initialized. |
| AuthorizationException | If the store cannot be initialized due to authentication errors. |
Source code in zenml/zen_stores/base_zen_store.py
def __init__(
self,
skip_default_registrations: bool = False,
**kwargs: Any,
) -> None:
"""Create and initialize a store.
Args:
skip_default_registrations: If `True`, the creation of the default
stack and user in the store will be skipped.
**kwargs: Additional keyword arguments to pass to the Pydantic
constructor.
Raises:
RuntimeError: If the store cannot be initialized.
AuthorizationException: If the store cannot be initialized due to
authentication errors.
"""
super().__init__(**kwargs)
try:
self._initialize()
# Handle cases where the ZenML server is not available
except ConnectionError as e:
error_message = (
"Cannot connect to the ZenML database because the ZenML server "
f"at {self.url} is not running."
)
if urlparse(self.url).hostname in ["localhost", "127.0.0.1"]:
recommendation = (
"Please run `zenml down` and `zenml up` to restart the "
"server."
)
else:
recommendation = (
"Please run `zenml disconnect` and `zenml connect --url "
f"{self.url}` to reconnect to the server."
)
raise RuntimeError(f"{error_message}\n{recommendation}") from e
except AuthorizationException as e:
raise AuthorizationException(
f"Authorization failed for store at '{self.url}'. Please check "
f"your credentials: {str(e)}"
)
except Exception as e:
zenml_pro_extra = ""
if ".zenml.io" in self.url:
zenml_pro_extra = (
ZENML_PRO_CONNECTION_ISSUES_SUSPENDED_PAUSED_TENANT_HINT
)
raise RuntimeError(
f"Error initializing {self.type.value} store with URL "
f"'{self.url}': {str(e)}" + zenml_pro_extra
) from e
if not skip_default_registrations:
logger.debug("Initializing database")
self._initialize_database()
else:
logger.debug("Skipping database initialization")
convert_config(data, validation_info)
classmethod
Wrapper method to handle the raw data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cls | | the class handler | required |
| data | Any | the raw input data | required |
| validation_info | ValidationInfo | the context of the validation. | required |
Returns:
| Type | Description |
|---|---|
| Any | the validated data |
Source code in zenml/zen_stores/base_zen_store.py
def before_validator(
cls: Type[BaseModel], data: Any, validation_info: ValidationInfo
) -> Any:
"""Wrapper method to handle the raw data.
Args:
cls: the class handler
data: the raw input data
validation_info: the context of the validation.
Returns:
the validated data
"""
data = model_validator_data_handler(
raw_data=data, base_class=cls, validation_info=validation_info
)
return method(cls=cls, data=data)
create_store(config, skip_default_registrations=False, **kwargs)
staticmethod
Create and initialize a store from a store configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | StoreConfiguration | The store configuration to use. | required |
| skip_default_registrations | bool | If `True`, the creation of the default stack and user in the store will be skipped. | False |
| **kwargs | Any | Additional keyword arguments to pass to the store class. | {} |
Returns:
| Type | Description |
|---|---|
| BaseZenStore | The initialized store. |
Source code in zenml/zen_stores/base_zen_store.py
@staticmethod
def create_store(
config: StoreConfiguration,
skip_default_registrations: bool = False,
**kwargs: Any,
) -> "BaseZenStore":
"""Create and initialize a store from a store configuration.
Args:
config: The store configuration to use.
skip_default_registrations: If `True`, the creation of the default
stack and user in the store will be skipped.
**kwargs: Additional keyword arguments to pass to the store class
Returns:
The initialized store.
"""
logger.debug(f"Creating store with config '{config}'...")
store_class = BaseZenStore.get_store_class(config.type)
store = store_class(
config=config,
skip_default_registrations=skip_default_registrations,
**kwargs,
)
return store
get_default_store_config(path)
staticmethod
Get the default store configuration.
The default store is a SQLite store that saves the DB contents on the local filesystem.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | The local path where the store DB will be stored. | required |
Returns:
| Type | Description |
|---|---|
| StoreConfiguration | The default store configuration. |
Source code in zenml/zen_stores/base_zen_store.py
@staticmethod
def get_default_store_config(path: str) -> StoreConfiguration:
"""Get the default store configuration.
The default store is a SQLite store that saves the DB contents on the
local filesystem.
Args:
path: The local path where the store DB will be stored.
Returns:
The default store configuration.
"""
from zenml.zen_stores.sql_zen_store import SqlZenStoreConfiguration
config = SqlZenStoreConfiguration(
type=StoreType.SQL,
url=SqlZenStoreConfiguration.get_local_url(path),
secrets_store=SqlSecretsStoreConfiguration(
type=SecretsStoreType.SQL,
),
)
return config
get_external_user(self, user_id)
Get a user by external ID.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_id | UUID | The external ID of the user. | required |
Returns:
| Type | Description |
|---|---|
| UserResponse | The user with the supplied external ID. |
Exceptions:
| Type | Description |
|---|---|
| KeyError | If the user doesn't exist. |
Source code in zenml/zen_stores/base_zen_store.py
def get_external_user(self, user_id: UUID) -> UserResponse:
"""Get a user by external ID.
Args:
user_id: The external ID of the user.
Returns:
The user with the supplied external ID.
Raises:
KeyError: If the user doesn't exist.
"""
users = self.list_users(UserFilter(external_user_id=user_id))
if users.total == 0:
raise KeyError(f"User with external ID '{user_id}' not found.")
return users.items[0]
get_store_class(store_type)
staticmethod
Returns the class of the given store type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| store_type | StoreType | The type of the store to get the class for. | required |
Returns:
| Type | Description |
|---|---|
| Type[BaseZenStore] | The class of the given store type. |
Exceptions:
| Type | Description |
|---|---|
| TypeError | If the store type is unsupported. |
Source code in zenml/zen_stores/base_zen_store.py
@staticmethod
def get_store_class(store_type: StoreType) -> Type["BaseZenStore"]:
"""Returns the class of the given store type.
Args:
store_type: The type of the store to get the class for.
Returns:
The class of the given store type.
Raises:
TypeError: If the store type is unsupported.
"""
if store_type == StoreType.SQL:
from zenml.zen_stores.sql_zen_store import SqlZenStore
return SqlZenStore
elif store_type == StoreType.REST:
from zenml.zen_stores.rest_zen_store import RestZenStore
return RestZenStore
else:
raise TypeError(
f"No store implementation found for store type "
f"`{store_type.value}`."
)
get_store_config_class(store_type)
staticmethod
Returns the store config class of the given store type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| store_type | StoreType | The type of the store to get the class for. | required |
Returns:
| Type | Description |
|---|---|
| Type[StoreConfiguration] | The config class of the given store type. |
Source code in zenml/zen_stores/base_zen_store.py
@staticmethod
def get_store_config_class(
store_type: StoreType,
) -> Type["StoreConfiguration"]:
"""Returns the store config class of the given store type.
Args:
store_type: The type of the store to get the class for.
Returns:
The config class of the given store type.
"""
store_class = BaseZenStore.get_store_class(store_type)
return store_class.CONFIG_TYPE
get_store_info(self)
Get information about the store.
Returns:
| Type | Description |
|---|---|
| ServerModel | Information about the store. |
Source code in zenml/zen_stores/base_zen_store.py
def get_store_info(self) -> ServerModel:
"""Get information about the store.
Returns:
Information about the store.
"""
from zenml.zen_stores.sql_zen_store import SqlZenStore
server_config = ServerConfiguration.get_server_config()
deployment_type = server_config.deployment_type
auth_scheme = server_config.auth_scheme
metadata = server_config.metadata
secrets_store_type = SecretsStoreType.NONE
if isinstance(self, SqlZenStore) and self.config.secrets_store:
secrets_store_type = self.config.secrets_store.type
use_legacy_dashboard = server_config.use_legacy_dashboard
return ServerModel(
id=GlobalConfiguration().user_id,
active=True,
version=zenml.__version__,
deployment_type=deployment_type,
database_type=ServerDatabaseType.OTHER,
debug=IS_DEBUG_ENV,
secrets_store_type=secrets_store_type,
auth_scheme=auth_scheme,
server_url=server_config.server_url or "",
dashboard_url=server_config.dashboard_url or "",
analytics_enabled=GlobalConfiguration().analytics_opt_in,
metadata=metadata,
use_legacy_dashboard=use_legacy_dashboard,
)
get_store_type(url)
staticmethod
Returns the store type associated with a URL schema.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| url | str | The store URL. | required |
Returns:
| Type | Description |
|---|---|
| StoreType | The store type associated with the supplied URL schema. |
Exceptions:
| Type | Description |
|---|---|
| TypeError | If no store type was found to support the supplied URL. |
Source code in zenml/zen_stores/base_zen_store.py
@staticmethod
def get_store_type(url: str) -> StoreType:
"""Returns the store type associated with a URL schema.
Args:
url: The store URL.
Returns:
The store type associated with the supplied URL schema.
Raises:
TypeError: If no store type was found to support the supplied URL.
"""
from zenml.zen_stores.rest_zen_store import RestZenStoreConfiguration
from zenml.zen_stores.sql_zen_store import SqlZenStoreConfiguration
if SqlZenStoreConfiguration.supports_url_scheme(url):
return StoreType.SQL
elif RestZenStoreConfiguration.supports_url_scheme(url):
return StoreType.REST
else:
raise TypeError(f"No store implementation found for URL: {url}.")
is_local_store(self)
Check if the store is local or connected to a local ZenML server.
Returns:
| Type | Description |
|---|---|
| bool | True if the store is local, False otherwise. |
Source code in zenml/zen_stores/base_zen_store.py
def is_local_store(self) -> bool:
"""Check if the store is local or connected to a local ZenML server.
Returns:
True if the store is local, False otherwise.
"""
return self.get_store_info().is_local()
validate_active_config(self, active_workspace_name_or_id=None, active_stack_id=None, config_name='')
Validate the active configuration.
Call this method to validate the supplied active workspace and active stack values.
This method is guaranteed to return valid workspace ID and stack ID values. If the supplied workspace and stack are not set or are not valid (e.g. they do not exist or are not accessible), the default workspace and default workspace stack will be returned in their stead.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| active_workspace_name_or_id | Union[str, uuid.UUID] | The name or ID of the active workspace. | None |
| active_stack_id | Optional[uuid.UUID] | The ID of the active stack. | None |
| config_name | str | The name of the configuration to validate (used in the displayed logs/messages). | '' |
Returns:
| Type | Description |
|---|---|
| Tuple[zenml.models.v2.core.workspace.WorkspaceResponse, zenml.models.v2.core.stack.StackResponse] | A tuple containing the active workspace and active stack. |
Source code in zenml/zen_stores/base_zen_store.py
def validate_active_config(
self,
active_workspace_name_or_id: Optional[Union[str, UUID]] = None,
active_stack_id: Optional[UUID] = None,
config_name: str = "",
) -> Tuple[WorkspaceResponse, StackResponse]:
"""Validate the active configuration.
Call this method to validate the supplied active workspace and active
stack values.
This method is guaranteed to return valid workspace ID and stack ID
values. If the supplied workspace and stack are not set or are not valid
(e.g. they do not exist or are not accessible), the default workspace and
default workspace stack will be returned in their stead.
Args:
active_workspace_name_or_id: The name or ID of the active workspace.
active_stack_id: The ID of the active stack.
config_name: The name of the configuration to validate (used in the
displayed logs/messages).
Returns:
A tuple containing the active workspace and active stack.
"""
active_workspace: WorkspaceResponse
if active_workspace_name_or_id:
try:
active_workspace = self.get_workspace(
active_workspace_name_or_id
)
except KeyError:
active_workspace = self._get_default_workspace()
logger.warning(
f"The current {config_name} active workspace is no longer "
f"available. Resetting the active workspace to "
f"'{active_workspace.name}'."
)
else:
active_workspace = self._get_default_workspace()
logger.info(
f"Setting the {config_name} active workspace "
f"to '{active_workspace.name}'."
)
active_stack: StackResponse
# Sanitize the active stack
if active_stack_id:
# Ensure that the active stack is still valid
try:
active_stack = self.get_stack(stack_id=active_stack_id)
except KeyError:
logger.warning(
"The current %s active stack is no longer available. "
"Resetting the active stack to default.",
config_name,
)
active_stack = self._get_default_stack(
workspace_id=active_workspace.id
)
else:
if active_stack.workspace.id != active_workspace.id:
logger.warning(
"The current %s active stack is not part of the active "
"workspace. Resetting the active stack to default.",
config_name,
)
active_stack = self._get_default_stack(
workspace_id=active_workspace.id
)
else:
logger.warning(
"Setting the %s active stack to default.",
config_name,
)
active_stack = self._get_default_stack(
workspace_id=active_workspace.id
)
return active_workspace, active_stack
migrations
special
Alembic database migration utilities.
alembic
Alembic utilities wrapper.
The Alembic class defined here acts as a wrapper around the Alembic library that automatically configures Alembic to use the ZenML SQL store database connection.
Alembic
Alembic environment and migration API.
This class provides a wrapper around the Alembic library that automatically configures Alembic to use the ZenML SQL store database connection.
Source code in zenml/zen_stores/migrations/alembic.py
class Alembic:
"""Alembic environment and migration API.
This class provides a wrapper around the Alembic library that automatically
configures Alembic to use the ZenML SQL store database connection.
"""
def __init__(
self,
engine: Engine,
metadata: MetaData = SQLModel.metadata,
context: Optional[EnvironmentContext] = None,
**kwargs: Any,
) -> None:
"""Initialize the Alembic wrapper.
Args:
engine: The SQLAlchemy engine to use.
metadata: The SQLAlchemy metadata to use.
context: The Alembic environment context to use. If not set, a new
context is created pointing to the ZenML migrations directory.
**kwargs: Additional keyword arguments to pass to the Alembic
environment context.
"""
self.engine = engine
self.metadata = metadata
self.context_kwargs = kwargs
self.config = Config()
self.config.set_main_option(
"script_location", str(Path(__file__).parent)
)
self.script_directory = ScriptDirectory.from_config(self.config)
if context is None:
self.environment_context = EnvironmentContext(
self.config, self.script_directory
)
else:
self.environment_context = context
def db_is_empty(self) -> bool:
"""Check if the database is empty.
Returns:
True if the database is empty, False otherwise.
"""
# Check the existence of any of the SQLModel tables
return not self.engine.dialect.has_table(
self.engine.connect(), schemas.StackSchema.__tablename__
)
def run_migrations(
self,
fn: Optional[Callable[[_RevIdType, MigrationContext], List[Any]]],
) -> None:
"""Run an online migration function in the current migration context.
Args:
fn: Migration function to run. If not set, the function configured
externally by the Alembic CLI command is used.
"""
fn_context_args: Dict[Any, Any] = {}
if fn is not None:
fn_context_args["fn"] = fn
with self.engine.connect() as connection:
self.environment_context.configure(
connection=connection,
target_metadata=self.metadata,
include_object=include_object,
compare_type=True,
render_as_batch=True,
**fn_context_args,
**self.context_kwargs,
)
with self.environment_context.begin_transaction():
self.environment_context.run_migrations()
def head_revisions(self) -> List[str]:
"""Get the head database revisions.
Returns:
List of head revisions.
"""
head_revisions: List[str] = []
def do_get_head_rev(rev: _RevIdType, context: Any) -> List[Any]:
nonlocal head_revisions
for r in self.script_directory.get_heads():
if r is None:
continue
head_revisions.append(r)
return []
self.run_migrations(do_get_head_rev)
return head_revisions
def current_revisions(self) -> List[str]:
"""Get the current database revisions.
Returns:
List of current revisions.
"""
current_revisions: List[str] = []
def do_get_current_rev(rev: _RevIdType, context: Any) -> List[Any]:
nonlocal current_revisions
for r in self.script_directory.get_all_current(
rev # type:ignore [arg-type]
):
if r is None:
continue
current_revisions.append(r.revision)
return []
self.run_migrations(do_get_current_rev)
return current_revisions
def stamp(self, revision: str) -> None:
"""Stamp the revision table with the given revision without running any migrations.
Args:
revision: String revision target.
"""
def do_stamp(rev: _RevIdType, context: Any) -> List[Any]:
return self.script_directory._stamp_revs(revision, rev)
self.run_migrations(do_stamp)
def upgrade(self, revision: str = "heads") -> None:
"""Upgrade the database to a later version.
Args:
revision: String revision target.
"""
def do_upgrade(rev: _RevIdType, context: Any) -> List[Any]:
return self.script_directory._upgrade_revs(
revision,
rev, # type:ignore [arg-type]
)
self.run_migrations(do_upgrade)
def downgrade(self, revision: str) -> None:
"""Revert the database to a previous version.
Args:
revision: String revision target.
"""
def do_downgrade(rev: _RevIdType, context: Any) -> List[Any]:
return self.script_directory._downgrade_revs(
revision,
rev, # type:ignore [arg-type]
)
self.run_migrations(do_downgrade)
__init__(self, engine, metadata=MetaData(), context=None, **kwargs)
special
Initialize the Alembic wrapper.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| engine | Engine | The SQLAlchemy engine to use. | required |
| metadata | MetaData | The SQLAlchemy metadata to use. | MetaData() |
| context | Optional[alembic.runtime.environment.EnvironmentContext] | The Alembic environment context to use. If not set, a new context is created pointing to the ZenML migrations directory. | None |
| **kwargs | Any | Additional keyword arguments to pass to the Alembic environment context. | {} |
Source code in zenml/zen_stores/migrations/alembic.py
def __init__(
self,
engine: Engine,
metadata: MetaData = SQLModel.metadata,
context: Optional[EnvironmentContext] = None,
**kwargs: Any,
) -> None:
"""Initialize the Alembic wrapper.
Args:
engine: The SQLAlchemy engine to use.
metadata: The SQLAlchemy metadata to use.
context: The Alembic environment context to use. If not set, a new
context is created pointing to the ZenML migrations directory.
**kwargs: Additional keyword arguments to pass to the Alembic
environment context.
"""
self.engine = engine
self.metadata = metadata
self.context_kwargs = kwargs
self.config = Config()
self.config.set_main_option(
"script_location", str(Path(__file__).parent)
)
self.script_directory = ScriptDirectory.from_config(self.config)
if context is None:
self.environment_context = EnvironmentContext(
self.config, self.script_directory
)
else:
self.environment_context = context
current_revisions(self)
Get the current database revisions.
Returns:
| Type | Description |
|---|---|
| List[str] | List of current revisions. |
Source code in zenml/zen_stores/migrations/alembic.py
def current_revisions(self) -> List[str]:
"""Get the current database revisions.
Returns:
List of current revisions.
"""
current_revisions: List[str] = []
def do_get_current_rev(rev: _RevIdType, context: Any) -> List[Any]:
nonlocal current_revisions
for r in self.script_directory.get_all_current(
rev # type:ignore [arg-type]
):
if r is None:
continue
current_revisions.append(r.revision)
return []
self.run_migrations(do_get_current_rev)
return current_revisions
db_is_empty(self)
Check if the database is empty.
Returns:
| Type | Description |
|---|---|
| bool | True if the database is empty, False otherwise. |
Source code in zenml/zen_stores/migrations/alembic.py
def db_is_empty(self) -> bool:
"""Check if the database is empty.
Returns:
True if the database is empty, False otherwise.
"""
# Check the existence of any of the SQLModel tables
return not self.engine.dialect.has_table(
self.engine.connect(), schemas.StackSchema.__tablename__
)
downgrade(self, revision)
Revert the database to a previous version.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| revision | str | String revision target. | required |
Source code in zenml/zen_stores/migrations/alembic.py
def downgrade(self, revision: str) -> None:
"""Revert the database to a previous version.
Args:
revision: String revision target.
"""
def do_downgrade(rev: _RevIdType, context: Any) -> List[Any]:
return self.script_directory._downgrade_revs(
revision,
rev, # type:ignore [arg-type]
)
self.run_migrations(do_downgrade)
head_revisions(self)
Get the head database revisions.
Returns:
| Type | Description |
|---|---|
| List[str] | List of head revisions. |
Source code in zenml/zen_stores/migrations/alembic.py
def head_revisions(self) -> List[str]:
"""Get the head database revisions.
Returns:
List of head revisions.
"""
head_revisions: List[str] = []
def do_get_head_rev(rev: _RevIdType, context: Any) -> List[Any]:
nonlocal head_revisions
for r in self.script_directory.get_heads():
if r is None:
continue
head_revisions.append(r)
return []
self.run_migrations(do_get_head_rev)
return head_revisions
run_migrations(self, fn)
Run an online migration function in the current migration context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Optional[Callable[[Union[str, Sequence[str]], alembic.runtime.migration.MigrationContext], List[Any]]] | Migration function to run. If not set, the function configured externally by the Alembic CLI command is used. | required |
Source code in zenml/zen_stores/migrations/alembic.py
def run_migrations(
self,
fn: Optional[Callable[[_RevIdType, MigrationContext], List[Any]]],
) -> None:
"""Run an online migration function in the current migration context.
Args:
fn: Migration function to run. If not set, the function configured
externally by the Alembic CLI command is used.
"""
fn_context_args: Dict[Any, Any] = {}
if fn is not None:
fn_context_args["fn"] = fn
with self.engine.connect() as connection:
self.environment_context.configure(
connection=connection,
target_metadata=self.metadata,
include_object=include_object,
compare_type=True,
render_as_batch=True,
**fn_context_args,
**self.context_kwargs,
)
with self.environment_context.begin_transaction():
self.environment_context.run_migrations()
stamp(self, revision)
Stamp the revision table with the given revision without running any migrations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| revision | str | String revision target. | required |
Source code in zenml/zen_stores/migrations/alembic.py
def stamp(self, revision: str) -> None:
"""Stamp the revision table with the given revision without running any migrations.
Args:
revision: String revision target.
"""
def do_stamp(rev: _RevIdType, context: Any) -> List[Any]:
return self.script_directory._stamp_revs(revision, rev)
self.run_migrations(do_stamp)
upgrade(self, revision='heads')
Upgrade the database to a later version.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| revision | str | String revision target. | 'heads' |
Source code in zenml/zen_stores/migrations/alembic.py
def upgrade(self, revision: str = "heads") -> None:
"""Upgrade the database to a later version.
Args:
revision: String revision target.
"""
def do_upgrade(rev: _RevIdType, context: Any) -> List[Any]:
return self.script_directory._upgrade_revs(
revision,
rev, # type:ignore [arg-type]
)
self.run_migrations(do_upgrade)
AlembicVersion (Base)
Alembic version table.
Source code in zenml/zen_stores/migrations/alembic.py
class AlembicVersion(Base): # type: ignore[valid-type,misc]
"""Alembic version table."""
__tablename__ = "alembic_version"
version_num = Column(String, nullable=False, primary_key=True)
include_object(object, name, type_, *args, **kwargs)
Function used to exclude tables from the migration scripts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| object | Any | The schema item object to check. | required |
| name | str | The name of the object to check. | required |
| type_ | str | The type of the object to check. | required |
| *args | Any | Additional arguments. | () |
| **kwargs | Any | Additional keyword arguments. | {} |
Returns:
| Type | Description |
|---|---|
| bool | True if the object should be included, False otherwise. |
Source code in zenml/zen_stores/migrations/alembic.py
def include_object(
object: Any, name: str, type_: str, *args: Any, **kwargs: Any
) -> bool:
"""Function used to exclude tables from the migration scripts.
Args:
object: The schema item object to check.
name: The name of the object to check.
type_: The type of the object to check.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
True if the object should be included, False otherwise.
"""
return not (type_ == "table" and name in exclude_tables)
utils
ZenML database migration, backup and recovery utilities.
MigrationUtils (BaseModel)
Utilities for database migration, backup and recovery.
Source code in zenml/zen_stores/migrations/utils.py
class MigrationUtils(BaseModel):
"""Utilities for database migration, backup and recovery."""
url: URL
connect_args: Dict[str, Any]
engine_args: Dict[str, Any]
_engine: Optional[Engine] = None
_master_engine: Optional[Engine] = None
def create_engine(self, database: Optional[str] = None) -> Engine:
"""Get the SQLAlchemy engine for a database.
Args:
database: The name of the database. If not set, a master engine
will be returned.
Returns:
The SQLAlchemy engine.
"""
url = self.url._replace(database=database)
return create_engine(
url=url,
connect_args=self.connect_args,
**self.engine_args,
)
@property
def engine(self) -> Engine:
"""The SQLAlchemy engine.
Returns:
The SQLAlchemy engine.
"""
if self._engine is None:
self._engine = self.create_engine(database=self.url.database)
return self._engine
@property
def master_engine(self) -> Engine:
"""The SQLAlchemy engine for the master database.
Returns:
The SQLAlchemy engine for the master database.
"""
if self._master_engine is None:
self._master_engine = self.create_engine()
return self._master_engine
@classmethod
def is_mysql_missing_database_error(cls, error: OperationalError) -> bool:
"""Checks if the given error is due to a missing database.
Args:
error: The error to check.
Returns:
Whether the error is due to the MySQL database not existing.
"""
from pymysql.constants.ER import BAD_DB_ERROR
if not isinstance(error.orig, pymysql.err.OperationalError):
return False
error_code = cast(int, error.orig.args[0])
return error_code == BAD_DB_ERROR
def database_exists(
self,
database: Optional[str] = None,
) -> bool:
"""Check if a database exists.
Args:
database: The name of the database to check. If not set, the
database name from the configuration will be used.
Returns:
Whether the database exists.
Raises:
OperationalError: If connecting to the database failed.
"""
database = database or self.url.database
engine = self.create_engine(database=database)
try:
engine.connect()
except OperationalError as e:
if self.is_mysql_missing_database_error(e):
return False
else:
logger.exception(
f"Failed to connect to mysql database `{database}`.",
)
raise
else:
return True
def drop_database(
self,
database: Optional[str] = None,
) -> None:
"""Drops a mysql database.
Args:
database: The name of the database to drop. If not set, the
database name from the configuration will be used.
"""
database = database or self.url.database
with self.master_engine.connect() as conn:
# drop the database if it exists
logger.info(f"Dropping database '{database}'")
conn.execute(text(f"DROP DATABASE IF EXISTS `{database}`"))
def create_database(
self,
database: Optional[str] = None,
drop: bool = False,
) -> None:
"""Creates a mysql database.
Args:
database: The name of the database to create. If not set, the
database name from the configuration will be used.
drop: Whether to drop the database if it already exists.
"""
database = database or self.url.database
if drop:
self.drop_database(database=database)
with self.master_engine.connect() as conn:
logger.info(f"Creating database '{database}'")
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`"))
def backup_database_to_storage(
self, store_db_info: Callable[[Dict[str, Any]], None]
) -> None:
"""Backup the database to a storage location.
Backup the database to an abstract storage location. The storage
location is specified by a function that is called repeatedly to
store the database information. The function is called with a single
argument, which is a dictionary containing either the table schema or
table data. The dictionary contains the following keys:
* `table`: The name of the table.
* `create_stmt`: The table creation statement.
* `data`: A list of rows in the table.
Args:
store_db_info: The function to call to store the database
information.
"""
metadata = MetaData()
metadata.reflect(bind=self.engine)
with self.engine.connect() as conn:
for table in metadata.sorted_tables:
# 1. extract the table creation statements
create_table_construct = CreateTable(table)
create_table_stmt = str(create_table_construct).strip()
for column in create_table_construct.columns:
# enclosing all column names in backticks. This is because
# some column names are reserved keywords in MySQL. For
# example, keys and values. So, instead of tracking all
# keywords, we just enclose all column names in backticks.
# enclose the first word in the column definition in
# backticks
words = str(column).split()
words[0] = f"`{words[0]}`"
create_table_stmt = create_table_stmt.replace(
f"\n\t{str(column)}", " ".join(words)
)
# if any double quotes are used for column names, replace them
# with backticks
create_table_stmt = create_table_stmt.replace('"', "") + ";"
# enclose all table names in backticks. This is because some
# table names are reserved keywords in MySQL (e.g key
# and trigger).
create_table_stmt = create_table_stmt.replace(
f"CREATE TABLE {table.name}",
f"CREATE TABLE `{table.name}`",
)
# do the same for references to other tables
# (i.e. foreign key constraints) by replacing REFERENCES <word>
# with REFERENCES `<word>`
# use a regular expression for this
create_table_stmt = re.sub(
r"REFERENCES\s+(\w+)",
r"REFERENCES `\1`",
create_table_stmt,
)
# In SQLAlchemy, the CreateTable statement may not always
# include unique constraints explicitly if they are implemented
# as unique indexes instead. To make sure we get all unique
# constraints, including those implemented as indexes, we
# extract the unique constraints from the table schema and add
# them to the create table statement.
# Extract the unique constraints from the table schema
unique_constraints = []
for index in table.indexes:
if index.unique:
unique_columns = [
f"`{column.name}`" for column in index.columns
]
unique_constraints.append(
f"UNIQUE KEY `{index.name}` ({', '.join(unique_columns)})"
)
# Add the unique constraints to the create table statement
if unique_constraints:
# Remove the closing parenthesis, semicolon and any
# whitespaces at the end of the create table statement
create_table_stmt = re.sub(
r"\s*\)\s*;\s*$", "", create_table_stmt
)
create_table_stmt = (
create_table_stmt
+ ", \n\t"
+ ", \n\t".join(unique_constraints)
+ "\n);"
)
# Store the table schema
store_db_info(
dict(table=table.name, create_stmt=create_table_stmt)
)
# 2. extract the table data in batches
# If the table has a `created` column, we use it to sort
# the rows in the table starting with the oldest rows.
# This is to ensure that the rows are inserted in the
# correct order, since some tables have inner foreign key
# constraints.
if "created" in table.columns:
order_by = [table.columns["created"]]
else:
order_by = []
if "id" in table.columns:
# If the table has an `id` column, we also use it to sort
# the rows in the table, even if we already use "created"
# to sort the rows. We need a unique field to sort the rows,
# to break the tie between rows with the same "created"
# date, otherwise the same entry might end up multiple times
# in subsequent pages.
order_by.append(table.columns["id"])
# Fetch the number of rows in the table
row_count = conn.scalar(
select(func.count()).select_from(table)
)
# Fetch the data from the table in batches
if row_count is not None:
batch_size = 50
for i in range(0, row_count, batch_size):
rows = conn.execute(
table.select()
.order_by(*order_by)
.limit(batch_size)
.offset(i)
).fetchall()
store_db_info(
dict(
table=table.name,
data=[row._asdict() for row in rows],
),
)
def restore_database_from_storage(
self, load_db_info: Callable[[], Generator[Dict[str, Any], None, None]]
) -> None:
"""Restore the database from a backup storage location.
Restores the database from an abstract storage location. The storage
location is specified by a function that is called repeatedly to
load the database information from the external storage chunk by chunk.
The function must yield a dictionary containing either the table schema
or table data. The dictionary contains the following keys:
* `table`: The name of the table.
* `create_stmt`: The table creation statement.
* `data`: A list of rows in the table.
The function must return `None` when there is no more data to load.
Args:
load_db_info: The function to call to load the database
information.
"""
# Drop and re-create the primary database
self.create_database(drop=True)
metadata = MetaData()
with self.engine.begin() as connection:
# read the DB information one JSON object at a time
for table_dump in load_db_info():
table_name = table_dump["table"]
if "create_stmt" in table_dump:
# execute the table creation statement
connection.execute(text(table_dump["create_stmt"]))
# Reload the database metadata after creating the table
metadata.reflect(bind=self.engine)
if "data" in table_dump:
# insert the data into the database
table = metadata.tables[table_name]
for row in table_dump["data"]:
# Convert column values to the correct type
for column in table.columns:
# Blob columns are stored as binary strings
if column.type.python_type is bytes and isinstance(
row[column.name], str
):
# Convert the string to bytes
row[column.name] = bytes(
row[column.name], "utf-8"
)
# Insert the rows into the table
connection.execute(
table.insert().values(table_dump["data"])
)
def backup_database_to_file(self, dump_file: str) -> None:
"""Backup the database to a file.
This method dumps the entire database into a JSON file. Instead of
using a SQL dump, we use a proprietary JSON dump because:
* it is (mostly) not dependent on the SQL dialect or database version
* it is safer with respect to SQL injection attacks
* it is easier to read and debug
The JSON file contains a list of JSON objects instead of a single JSON
object, because it allows for buffered reading and writing of the file
and thus reduces the memory footprint. Each JSON object can contain
either schema or data information about a single table. For tables with
a large amount of data, the data is split into multiple JSON objects
with the first object always containing the schema.
The format of the dump is as depicted in the following example:
```json
{
"table": "table1",
"create_stmt": "CREATE TABLE table1 (id INTEGER NOT NULL, "
"name VARCHAR(255), PRIMARY KEY (id))"
}
{
"table": "table1",
"data": [
{
"id": 1,
"name": "foo"
},
{
"id": 1,
"name": "bar"
},
...
]
}
{
"table": "table1",
"data": [
{
"id": 101,
"name": "fee"
},
{
"id": 102,
"name": "bee"
},
...
]
}
```
Args:
dump_file: The path to the dump file.
"""
# create the directory if it does not exist
dump_path = os.path.dirname(os.path.abspath(dump_file))
if not os.path.exists(dump_path):
os.makedirs(dump_path)
if self.url.drivername == "sqlite":
# For a sqlite database, we can just make a copy of the database
# file
assert self.url.database is not None
shutil.copyfile(
self.url.database,
dump_file,
)
return
with open(dump_file, "w") as f:
def json_dump(obj: Dict[str, Any]) -> None:
"""Dump a JSON object to the dump file.
Args:
obj: The JSON object to dump.
"""
# Write the data to the JSON file. Use an encoder that
# can handle datetime, Decimal and other types.
json.dump(
obj,
f,
indent=4,
default=pydantic_encoder,
)
f.write("\n")
# Call the generic backup method with a function that dumps the
# JSON objects to the dump file
self.backup_database_to_storage(json_dump)
logger.debug(f"Database backed up to {dump_file}")
def restore_database_from_file(self, dump_file: str) -> None:
"""Restore the database from a backup dump file.
See the documentation of the `backup_database_to_file` method for
details on the format of the dump file.
Args:
dump_file: The path to the dump file.
Raises:
RuntimeError: If the database cannot be restored successfully.
"""
if not os.path.exists(dump_file):
raise RuntimeError(
f"Database backup file '{dump_file}' does not "
f"exist or is not accessible."
)
if self.url.drivername == "sqlite":
# For a sqlite database, we just overwrite the database file
# with the backup file
assert self.url.database is not None
shutil.copyfile(
dump_file,
self.url.database,
)
return
# read the DB dump file one JSON object at a time
with open(dump_file, "r") as f:
def json_load() -> Generator[Dict[str, Any], None, None]:
"""Generator that loads the JSON objects in the dump file.
Yields:
The loaded JSON objects.
"""
buffer = ""
while True:
chunk = f.readline()
if not chunk:
break
buffer += chunk
if chunk.rstrip() == "}":
yield json.loads(buffer)
buffer = ""
# Call the generic restore method with a function that loads the
# JSON objects from the dump file
self.restore_database_from_storage(json_load)
logger.info(f"Database successfully restored from '{dump_file}'")
def backup_database_to_memory(self) -> List[Dict[str, Any]]:
"""Backup the database in memory.
Returns:
The in-memory representation of the database backup.
Raises:
RuntimeError: If the database cannot be backed up successfully.
"""
if self.url.drivername == "sqlite":
# For a sqlite database, this is not supported.
raise RuntimeError(
"In-memory backup is not supported for sqlite databases."
)
db_dump: List[Dict[str, Any]] = []
def store_in_mem(obj: Dict[str, Any]) -> None:
"""Store a JSON object in the in-memory database backup.
Args:
obj: The JSON object to store.
"""
db_dump.append(obj)
# Call the generic backup method with a function that stores the
# JSON objects in the in-memory database backup
self.backup_database_to_storage(store_in_mem)
logger.debug("Database backed up in memory")
return db_dump
def restore_database_from_memory(
self, db_dump: List[Dict[str, Any]]
) -> None:
"""Restore the database from an in-memory backup.
Args:
db_dump: The in-memory database backup to restore from generated
by the `backup_database_to_memory` method.
Raises:
RuntimeError: If the database cannot be restored successfully.
"""
if self.url.drivername == "sqlite":
# For a sqlite database, this is not supported.
raise RuntimeError(
"In-memory backup is not supported for sqlite databases."
)
def load_from_mem() -> Generator[Dict[str, Any], None, None]:
"""Generator that loads the JSON objects from the in-memory backup.
Yields:
The loaded JSON objects.
"""
for obj in db_dump:
yield obj
# Call the generic restore method with a function that loads the
# JSON objects from the in-memory database backup
self.restore_database_from_storage(load_from_mem)
logger.info("Database successfully restored from memory")
@classmethod
def _copy_database(cls, src_engine: Engine, dst_engine: Engine) -> None:
"""Copy the database from one engine to another.
This method assumes that the destination database exists and is empty.
Args:
src_engine: The source SQLAlchemy engine.
dst_engine: The destination SQLAlchemy engine.
"""
src_metadata = MetaData()
src_metadata.reflect(bind=src_engine)
dst_metadata = MetaData()
dst_metadata.reflect(bind=dst_engine)
# @event.listens_for(src_metadata, "column_reflect")
# def generalize_datatypes(inspector, tablename, column_dict):
# column_dict["type"] = column_dict["type"].as_generic(allow_nulltype=True)
# Create all tables in the target database
for table in src_metadata.sorted_tables:
table.create(bind=dst_engine)
# Refresh target metadata after creating the tables
dst_metadata.clear()
dst_metadata.reflect(bind=dst_engine)
# Copy all data from the source database to the destination database
with src_engine.begin() as src_conn:
with dst_engine.begin() as dst_conn:
for src_table in src_metadata.sorted_tables:
dst_table = dst_metadata.tables[src_table.name]
insert = dst_table.insert()
# If the table has a `created` column, we use it to sort
# the rows in the table starting with the oldest rows.
# This is to ensure that the rows are inserted in the
# correct order, since some tables have inner foreign key
# constraints.
if "created" in src_table.columns:
order_by = [src_table.columns["created"]]
else:
order_by = []
if "id" in src_table.columns:
# If the table has an `id` column, we also use it to
# sort the rows in the table, even if we already use
# "created" to sort the rows. We need a unique field to
# sort the rows, to break the tie between rows with the
# same "created" date, otherwise the same entry might
# end up multiple times in subsequent pages.
order_by.append(src_table.columns["id"])
row_count = src_conn.scalar(
select(func.count()).select_from(src_table)
)
# Copy rows in batches
if row_count is not None:
batch_size = 50
for i in range(0, row_count, batch_size):
rows = src_conn.execute(
src_table.select()
.order_by(*order_by)
.limit(batch_size)
.offset(i)
).fetchall()
dst_conn.execute(
insert, [row._asdict() for row in rows]
)
def backup_database_to_db(self, backup_db_name: str) -> None:
"""Backup the database to a backup database.
Args:
backup_db_name: Backup database name to backup to.
"""
# Re-create the backup database
self.create_database(
database=backup_db_name,
drop=True,
)
backup_engine = self.create_engine(database=backup_db_name)
self._copy_database(self.engine, backup_engine)
logger.debug(
f"Database backed up to the `{backup_db_name}` backup database."
)
def restore_database_from_db(self, backup_db_name: str) -> None:
"""Restore the database from the backup database.
Args:
backup_db_name: Backup database name to restore from.
Raises:
RuntimeError: If the backup database does not exist.
"""
if not self.database_exists(database=backup_db_name):
raise RuntimeError(
f"Backup database `{backup_db_name}` does not exist."
)
backup_engine = self.create_engine(database=backup_db_name)
# Drop and re-create the primary database
self.create_database(
drop=True,
)
self._copy_database(backup_engine, self.engine)
logger.debug(
f"Database restored from the `{backup_db_name}` "
"backup database."
)
model_config = ConfigDict(arbitrary_types_allowed=True)
engine: Engine
property
readonly
The SQLAlchemy engine.
Returns:
| Type | Description |
|---|---|
| Engine | The SQLAlchemy engine. |
master_engine: Engine
property
readonly
The SQLAlchemy engine for the master database.
Returns:
| Type | Description |
|---|---|
| Engine | The SQLAlchemy engine for the master database. |
backup_database_to_db(self, backup_db_name)
Backup the database to a backup database.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| backup_db_name | str | Backup database name to backup to. | required |
Source code in zenml/zen_stores/migrations/utils.py
def backup_database_to_db(self, backup_db_name: str) -> None:
"""Backup the database to a backup database.
Args:
backup_db_name: Backup database name to backup to.
"""
# Re-create the backup database
self.create_database(
database=backup_db_name,
drop=True,
)
backup_engine = self.create_engine(database=backup_db_name)
self._copy_database(self.engine, backup_engine)
logger.debug(
f"Database backed up to the `{backup_db_name}` backup database."
)
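For orientation, here is a minimal usage sketch of the database-to-database backup flow. The class name `MigrationUtils`, its import path, and the `url`/`connect_args`/`engine_args` constructor fields are assumptions inferred from the source listings on this page, not a documented API:

```python
from sqlalchemy.engine import make_url

# Assumed import path and class name for the utility documented here.
from zenml.zen_stores.migrations.utils import MigrationUtils

# Assumed constructor fields: `url`, `connect_args` and `engine_args` are all
# referenced by the create_engine() source shown on this page.
utils = MigrationUtils(
    url=make_url("mysql+pymysql://user:password@localhost:3306/zenml"),
    connect_args={},
    engine_args={},
)

# Snapshot the primary database into a dedicated backup database ...
utils.backup_database_to_db(backup_db_name="zenml_backup")

# ... run a risky migration here ...

# ... and roll back to the snapshot if anything went wrong.
utils.restore_database_from_db(backup_db_name="zenml_backup")
```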
backup_database_to_file(self, dump_file)
Backup the database to a file.
This method dumps the entire database into a JSON file. Instead of using a SQL dump, we use a proprietary JSON dump because:
* it is (mostly) not dependent on the SQL dialect or database version
* it is safer with respect to SQL injection attacks
* it is easier to read and debug
The JSON file contains a list of JSON objects instead of a single JSON object, because it allows for buffered reading and writing of the file and thus reduces the memory footprint. Each JSON object can contain either schema or data information about a single table. For tables with a large amount of data, the data is split into multiple JSON objects with the first object always containing the schema.
The format of the dump is as depicted in the following example:
{
"table": "table1",
"create_stmt": "CREATE TABLE table1 (id INTEGER NOT NULL, "
"name VARCHAR(255), PRIMARY KEY (id))"
}
{
"table": "table1",
"data": [
{
"id": 1,
"name": "foo"
},
{
"id": 1,
"name": "bar"
},
...
]
}
{
"table": "table1",
"data": [
{
"id": 101,
"name": "fee"
},
{
"id": 102,
"name": "bee"
},
...
]
}
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dump_file | str | The path to the dump file. | required |
Source code in zenml/zen_stores/migrations/utils.py
def backup_database_to_file(self, dump_file: str) -> None:
"""Backup the database to a file.
This method dumps the entire database into a JSON file. Instead of
using a SQL dump, we use a proprietary JSON dump because:
* it is (mostly) not dependent on the SQL dialect or database version
* it is safer with respect to SQL injection attacks
* it is easier to read and debug
The JSON file contains a list of JSON objects instead of a single JSON
object, because it allows for buffered reading and writing of the file
and thus reduces the memory footprint. Each JSON object can contain
either schema or data information about a single table. For tables with
a large amount of data, the data is split into multiple JSON objects
with the first object always containing the schema.
The format of the dump is as depicted in the following example:
```json
{
"table": "table1",
"create_stmt": "CREATE TABLE table1 (id INTEGER NOT NULL, "
"name VARCHAR(255), PRIMARY KEY (id))"
}
{
"table": "table1",
"data": [
{
"id": 1,
"name": "foo"
},
{
"id": 1,
"name": "bar"
},
...
]
}
{
"table": "table1",
"data": [
{
"id": 101,
"name": "fee"
},
{
"id": 102,
"name": "bee"
},
...
]
}
```
Args:
dump_file: The path to the dump file.
"""
# create the directory if it does not exist
dump_path = os.path.dirname(os.path.abspath(dump_file))
if not os.path.exists(dump_path):
os.makedirs(dump_path)
if self.url.drivername == "sqlite":
# For a sqlite database, we can just make a copy of the database
# file
assert self.url.database is not None
shutil.copyfile(
self.url.database,
dump_file,
)
return
with open(dump_file, "w") as f:
def json_dump(obj: Dict[str, Any]) -> None:
"""Dump a JSON object to the dump file.
Args:
obj: The JSON object to dump.
"""
# Write the data to the JSON file. Use an encoder that
# can handle datetime, Decimal and other types.
json.dump(
obj,
f,
indent=4,
default=pydantic_encoder,
)
f.write("\n")
# Call the generic backup method with a function that dumps the
# JSON objects to the dump file
self.backup_database_to_storage(json_dump)
logger.debug(f"Database backed up to {dump_file}")
backup_database_to_memory(self)
Backup the database in memory.
Returns:
Type | Description |
---|---|
List[Dict[str, Any]] | The in-memory representation of the database backup. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the database cannot be backed up successfully. |
Source code in zenml/zen_stores/migrations/utils.py
def backup_database_to_memory(self) -> List[Dict[str, Any]]:
"""Backup the database in memory.
Returns:
The in-memory representation of the database backup.
Raises:
RuntimeError: If the database cannot be backed up successfully.
"""
if self.url.drivername == "sqlite":
# For a sqlite database, this is not supported.
raise RuntimeError(
"In-memory backup is not supported for sqlite databases."
)
db_dump: List[Dict[str, Any]] = []
def store_in_mem(obj: Dict[str, Any]) -> None:
"""Store a JSON object in the in-memory database backup.
Args:
obj: The JSON object to store.
"""
db_dump.append(obj)
# Call the generic backup method with a function that stores the
# JSON objects in the in-memory database backup
self.backup_database_to_storage(store_in_mem)
logger.debug("Database backed up in memory")
return db_dump
backup_database_to_storage(self, store_db_info)
Backup the database to a storage location.
Backup the database to an abstract storage location. The storage location is specified by a function that is called repeatedly to store the database information. The function is called with a single argument, which is a dictionary containing either the table schema or table data. The dictionary contains the following keys:
* `table`: The name of the table.
* `create_stmt`: The table creation statement.
* `data`: A list of rows in the table.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
store_db_info | Callable[[Dict[str, Any]], NoneType] | The function to call to store the database information. | required |
Source code in zenml/zen_stores/migrations/utils.py
def backup_database_to_storage(
self, store_db_info: Callable[[Dict[str, Any]], None]
) -> None:
"""Backup the database to a storage location.
Backup the database to an abstract storage location. The storage
location is specified by a function that is called repeatedly to
store the database information. The function is called with a single
argument, which is a dictionary containing either the table schema or
table data. The dictionary contains the following keys:
* `table`: The name of the table.
* `create_stmt`: The table creation statement.
* `data`: A list of rows in the table.
Args:
store_db_info: The function to call to store the database
information.
"""
metadata = MetaData()
metadata.reflect(bind=self.engine)
with self.engine.connect() as conn:
for table in metadata.sorted_tables:
# 1. extract the table creation statements
create_table_construct = CreateTable(table)
create_table_stmt = str(create_table_construct).strip()
for column in create_table_construct.columns:
# enclosing all column names in backticks. This is because
# some column names are reserved keywords in MySQL. For
# example, keys and values. So, instead of tracking all
# keywords, we just enclose all column names in backticks.
# enclose the first word in the column definition in
# backticks
words = str(column).split()
words[0] = f"`{words[0]}`"
create_table_stmt = create_table_stmt.replace(
f"\n\t{str(column)}", " ".join(words)
)
# if any double quotes are used for column names, replace them
# with backticks
create_table_stmt = create_table_stmt.replace('"', "") + ";"
# enclose all table names in backticks. This is because some
# table names are reserved keywords in MySQL (e.g key
# and trigger).
create_table_stmt = create_table_stmt.replace(
f"CREATE TABLE {table.name}",
f"CREATE TABLE `{table.name}`",
)
# do the same for references to other tables
# (i.e. foreign key constraints) by replacing REFERENCES <word>
# with REFERENCES `<word>`
# use a regular expression for this
create_table_stmt = re.sub(
r"REFERENCES\s+(\w+)",
r"REFERENCES `\1`",
create_table_stmt,
)
# In SQLAlchemy, the CreateTable statement may not always
# include unique constraints explicitly if they are implemented
# as unique indexes instead. To make sure we get all unique
# constraints, including those implemented as indexes, we
# extract the unique constraints from the table schema and add
# them to the create table statement.
# Extract the unique constraints from the table schema
unique_constraints = []
for index in table.indexes:
if index.unique:
unique_columns = [
f"`{column.name}`" for column in index.columns
]
unique_constraints.append(
f"UNIQUE KEY `{index.name}` ({', '.join(unique_columns)})"
)
# Add the unique constraints to the create table statement
if unique_constraints:
# Remove the closing parenthesis, semicolon and any
# whitespaces at the end of the create table statement
create_table_stmt = re.sub(
r"\s*\)\s*;\s*$", "", create_table_stmt
)
create_table_stmt = (
create_table_stmt
+ ", \n\t"
+ ", \n\t".join(unique_constraints)
+ "\n);"
)
# Store the table schema
store_db_info(
dict(table=table.name, create_stmt=create_table_stmt)
)
# 2. extract the table data in batches
# If the table has a `created` column, we use it to sort
# the rows in the table starting with the oldest rows.
# This is to ensure that the rows are inserted in the
# correct order, since some tables have inner foreign key
# constraints.
if "created" in table.columns:
order_by = [table.columns["created"]]
else:
order_by = []
if "id" in table.columns:
# If the table has an `id` column, we also use it to sort
# the rows in the table, even if we already use "created"
# to sort the rows. We need a unique field to sort the rows,
# to break the tie between rows with the same "created"
# date, otherwise the same entry might end up multiple times
# in subsequent pages.
order_by.append(table.columns["id"])
# Fetch the number of rows in the table
row_count = conn.scalar(
select(func.count()).select_from(table)
)
# Fetch the data from the table in batches
if row_count is not None:
batch_size = 50
for i in range(0, row_count, batch_size):
rows = conn.execute(
table.select()
.order_by(*order_by)
.limit(batch_size)
.offset(i)
).fetchall()
store_db_info(
dict(
table=table.name,
data=[row._asdict() for row in rows],
),
)
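To illustrate the callback contract described above, the following hedged sketch streams the dump into a gzip-compressed, newline-delimited JSON file; the file name and the `default=str` fallback (a crude stand-in for the pydantic encoder used by `backup_database_to_file`) are illustrative choices:

```python
import gzip
import json
from typing import Any, Dict

# `utils` is the hypothetical MigrationUtils instance from the earlier sketch.
with gzip.open("zenml-backup.ndjson.gz", "wt", encoding="utf-8") as f:

    def store_db_info(obj: Dict[str, Any]) -> None:
        # Each `obj` holds either {"table", "create_stmt"} or {"table", "data"}.
        f.write(json.dumps(obj, default=str) + "\n")

    utils.backup_database_to_storage(store_db_info)
```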
create_database(self, database=None, drop=False)
Creates a MySQL database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
database | Optional[str] | The name of the database to create. If not set, the database name from the configuration will be used. | None |
drop | bool | Whether to drop the database if it already exists. | False |
Source code in zenml/zen_stores/migrations/utils.py
def create_database(
self,
database: Optional[str] = None,
drop: bool = False,
) -> None:
"""Creates a mysql database.
Args:
database: The name of the database to create. If not set, the
database name from the configuration will be used.
drop: Whether to drop the database if it already exists.
"""
database = database or self.url.database
if drop:
self.drop_database(database=database)
with self.master_engine.connect() as conn:
logger.info(f"Creating database '{database}'")
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`"))
create_engine(self, database=None)
Get the SQLAlchemy engine for a database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
database | Optional[str] | The name of the database. If not set, a master engine will be returned. | None |
Returns:
Type | Description |
---|---|
Engine | The SQLAlchemy engine. |
Source code in zenml/zen_stores/migrations/utils.py
def create_engine(self, database: Optional[str] = None) -> Engine:
"""Get the SQLAlchemy engine for a database.
Args:
database: The name of the database. If not set, a master engine
will be returned.
Returns:
The SQLAlchemy engine.
"""
url = self.url._replace(database=database)
return create_engine(
url=url,
connect_args=self.connect_args,
**self.engine_args,
)
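A small sketch of the distinction between the two engines, again using the hypothetical `utils` instance: omitting the database name yields a master engine suitable for server-level statements.

```python
# Engine bound to a specific database.
db_engine = utils.create_engine(database="zenml")

# Master engine (no database selected), used for server-level operations
# such as CREATE DATABASE / DROP DATABASE.
admin_engine = utils.create_engine()
```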
database_exists(self, database=None)
Check if a database exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
database | Optional[str] | The name of the database to check. If not set, the database name from the configuration will be used. | None |
Returns:
Type | Description |
---|---|
bool | Whether the database exists. |
Exceptions:
Type | Description |
---|---|
OperationalError | If connecting to the database failed. |
Source code in zenml/zen_stores/migrations/utils.py
def database_exists(
self,
database: Optional[str] = None,
) -> bool:
"""Check if a database exists.
Args:
database: The name of the database to check. If not set, the
database name from the configuration will be used.
Returns:
Whether the database exists.
Raises:
OperationalError: If connecting to the database failed.
"""
database = database or self.url.database
engine = self.create_engine(database=database)
try:
engine.connect()
except OperationalError as e:
if self.is_mysql_missing_database_error(e):
return False
else:
logger.exception(
f"Failed to connect to mysql database `{database}`.",
)
raise
else:
return True
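A typical guard built on this check, using the hypothetical `utils` instance, avoids the RuntimeError that `restore_database_from_db` raises for missing backup databases:

```python
backup_name = "zenml_backup"

# Only attempt a restore if the backup database actually exists.
if utils.database_exists(database=backup_name):
    utils.restore_database_from_db(backup_db_name=backup_name)
else:
    print(f"No backup database named {backup_name!r} found.")
```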
drop_database(self, database=None)
Drops a MySQL database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
database | Optional[str] | The name of the database to drop. If not set, the database name from the configuration will be used. | None |
Source code in zenml/zen_stores/migrations/utils.py
def drop_database(
self,
database: Optional[str] = None,
) -> None:
"""Drops a mysql database.
Args:
database: The name of the database to drop. If not set, the
database name from the configuration will be used.
"""
database = database or self.url.database
with self.master_engine.connect() as conn:
# drop the database if it exists
logger.info(f"Dropping database '{database}'")
conn.execute(text(f"DROP DATABASE IF EXISTS `{database}`"))
is_mysql_missing_database_error(error)
classmethod
Checks if the given error is due to a missing database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
error | OperationalError | The error to check. | required |
Returns:
Type | Description |
---|---|
bool | Whether the error occurred because the MySQL database doesn't exist. |
Source code in zenml/zen_stores/migrations/utils.py
@classmethod
def is_mysql_missing_database_error(cls, error: OperationalError) -> bool:
"""Checks if the given error is due to a missing database.
Args:
error: The error to check.
Returns:
        Whether the error occurred because the MySQL database doesn't exist.
"""
from pymysql.constants.ER import BAD_DB_ERROR
if not isinstance(error.orig, pymysql.err.OperationalError):
return False
error_code = cast(int, error.orig.args[0])
return error_code == BAD_DB_ERROR
model_post_init(self, __context)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self | BaseModel | The BaseModel instance. | required |
__context | Any | The context. | required |
Source code in zenml/zen_stores/migrations/utils.py
def init_private_attributes(self: BaseModel, __context: Any) -> None:
"""This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args:
self: The BaseModel instance.
__context: The context.
"""
if getattr(self, '__pydantic_private__', None) is None:
pydantic_private = {}
for name, private_attr in self.__private_attributes__.items():
default = private_attr.get_default()
if default is not PydanticUndefined:
pydantic_private[name] = default
object_setattr(self, '__pydantic_private__', pydantic_private)
restore_database_from_db(self, backup_db_name)
Restore the database from the backup database.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
backup_db_name | str | Backup database name to restore from. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the backup database does not exist. |
Source code in zenml/zen_stores/migrations/utils.py
def restore_database_from_db(self, backup_db_name: str) -> None:
"""Restore the database from the backup database.
Args:
backup_db_name: Backup database name to restore from.
Raises:
RuntimeError: If the backup database does not exist.
"""
if not self.database_exists(database=backup_db_name):
raise RuntimeError(
f"Backup database `{backup_db_name}` does not exist."
)
backup_engine = self.create_engine(database=backup_db_name)
# Drop and re-create the primary database
self.create_database(
drop=True,
)
self._copy_database(backup_engine, self.engine)
logger.debug(
f"Database restored from the `{backup_db_name}` "
"backup database."
)
restore_database_from_file(self, dump_file)
Restore the database from a backup dump file.
See the documentation of the backup_database_to_file method for details on the format of the dump file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dump_file | str | The path to the dump file. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the database cannot be restored successfully. |
Source code in zenml/zen_stores/migrations/utils.py
def restore_database_from_file(self, dump_file: str) -> None:
"""Restore the database from a backup dump file.
See the documentation of the `backup_database_to_file` method for
details on the format of the dump file.
Args:
dump_file: The path to the dump file.
Raises:
RuntimeError: If the database cannot be restored successfully.
"""
if not os.path.exists(dump_file):
raise RuntimeError(
f"Database backup file '{dump_file}' does not "
f"exist or is not accessible."
)
if self.url.drivername == "sqlite":
# For a sqlite database, we just overwrite the database file
# with the backup file
assert self.url.database is not None
shutil.copyfile(
dump_file,
self.url.database,
)
return
# read the DB dump file one JSON object at a time
with open(dump_file, "r") as f:
def json_load() -> Generator[Dict[str, Any], None, None]:
"""Generator that loads the JSON objects in the dump file.
Yields:
The loaded JSON objects.
"""
buffer = ""
while True:
chunk = f.readline()
if not chunk:
break
buffer += chunk
if chunk.rstrip() == "}":
yield json.loads(buffer)
buffer = ""
# Call the generic restore method with a function that loads the
# JSON objects from the dump file
self.restore_database_from_storage(json_load)
logger.info(f"Database successfully restored from '{dump_file}'")
restore_database_from_memory(self, db_dump)
Restore the database from an in-memory backup.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
db_dump | List[Dict[str, Any]] | The in-memory database backup to restore from, generated by the backup_database_to_memory method. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the database cannot be restored successfully. |
Source code in zenml/zen_stores/migrations/utils.py
def restore_database_from_memory(
self, db_dump: List[Dict[str, Any]]
) -> None:
"""Restore the database from an in-memory backup.
Args:
db_dump: The in-memory database backup to restore from generated
by the `backup_database_to_memory` method.
Raises:
RuntimeError: If the database cannot be restored successfully.
"""
if self.url.drivername == "sqlite":
# For a sqlite database, this is not supported.
raise RuntimeError(
"In-memory backup is not supported for sqlite databases."
)
def load_from_mem() -> Generator[Dict[str, Any], None, None]:
"""Generator that loads the JSON objects from the in-memory backup.
Yields:
The loaded JSON objects.
"""
for obj in db_dump:
yield obj
# Call the generic restore method with a function that loads the
# JSON objects from the in-memory database backup
self.restore_database_from_storage(load_from_mem)
logger.info("Database successfully restored from memory")
restore_database_from_storage(self, load_db_info)
Restore the database from a backup storage location.
Restores the database from an abstract storage location. The storage location is specified by a function that is called repeatedly to load the database information from the external storage chunk by chunk. The function must yield a dictionary containing either the table schema or table data. The dictionary contains the following keys:
* `table`: The name of the table.
* `create_stmt`: The table creation statement.
* `data`: A list of rows in the table.
The function must return None when there is no more data to load.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
load_db_info | Callable[[], Generator[Dict[str, Any], NoneType, NoneType]] | The function to call to load the database information. | required |
Source code in zenml/zen_stores/migrations/utils.py
def restore_database_from_storage(
self, load_db_info: Callable[[], Generator[Dict[str, Any], None, None]]
) -> None:
"""Restore the database from a backup storage location.
Restores the database from an abstract storage location. The storage
location is specified by a function that is called repeatedly to
load the database information from the external storage chunk by chunk.
The function must yield a dictionary containing either the table schema
or table data. The dictionary contains the following keys:
* `table`: The name of the table.
* `create_stmt`: The table creation statement.
* `data`: A list of rows in the table.
The function must return `None` when there is no more data to load.
Args:
load_db_info: The function to call to load the database
information.
"""
# Drop and re-create the primary database
self.create_database(drop=True)
metadata = MetaData()
with self.engine.begin() as connection:
# read the DB information one JSON object at a time
for table_dump in load_db_info():
table_name = table_dump["table"]
if "create_stmt" in table_dump:
# execute the table creation statement
connection.execute(text(table_dump["create_stmt"]))
# Reload the database metadata after creating the table
metadata.reflect(bind=self.engine)
if "data" in table_dump:
# insert the data into the database
table = metadata.tables[table_name]
for row in table_dump["data"]:
# Convert column values to the correct type
for column in table.columns:
# Blob columns are stored as binary strings
if column.type.python_type is bytes and isinstance(
row[column.name], str
):
# Convert the string to bytes
row[column.name] = bytes(
row[column.name], "utf-8"
)
# Insert the rows into the table
connection.execute(
table.insert().values(table_dump["data"])
)
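Mirroring the backup callback sketched earlier, a hedged generator that feeds `restore_database_from_storage` from the gzip-compressed NDJSON file (file name illustrative):

```python
import gzip
import json
from typing import Any, Dict, Generator

def load_db_info() -> Generator[Dict[str, Any], None, None]:
    """Yield one schema/data dictionary per NDJSON line."""
    with gzip.open("zenml-backup.ndjson.gz", "rt", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)

# `utils` is the hypothetical MigrationUtils instance from the earlier sketch.
# Note: the function itself is passed; the method calls it to obtain the
# generator.
utils.restore_database_from_storage(load_db_info)
```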
rest_zen_store
REST Zen Store implementation.
RestZenStore (BaseZenStore)
Store implementation for accessing data from a REST API.
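Before the full source listing, a hedged sketch of constructing this store directly. In practice the ZenML client machinery builds it for you; the `RestZenStoreConfiguration` fields shown here (`url`, `api_key`) are assumptions based on attributes referenced in the source below, and instantiation contacts the server immediately:

```python
from zenml.zen_stores.rest_zen_store import (
    RestZenStore,
    RestZenStoreConfiguration,
)

# Assumed configuration fields; a reachable ZenML server is required, since
# the store connects to it during initialization.
config = RestZenStoreConfiguration(
    url="https://zenml.example.com",
    api_key="<service-account-api-key>",
)
store = RestZenStore(config=config)
print(store.get_store_info().version)
```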
Source code in zenml/zen_stores/rest_zen_store.py
class RestZenStore(BaseZenStore):
"""Store implementation for accessing data from a REST API."""
config: RestZenStoreConfiguration
TYPE: ClassVar[StoreType] = StoreType.REST
CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] = RestZenStoreConfiguration
_api_token: Optional[str] = None
_session: Optional[requests.Session] = None
# ====================================
# ZenML Store interface implementation
# ====================================
# --------------------------------
# Initialization and configuration
# --------------------------------
def _initialize(self) -> None:
"""Initialize the REST store."""
client_version = zenml.__version__
server_version = self.get_store_info().version
if not DISABLE_CLIENT_SERVER_MISMATCH_WARNING and (
server_version != client_version
):
logger.warning(
"Your ZenML client version (%s) does not match the server "
"version (%s). This version mismatch might lead to errors or "
"unexpected behavior. \nTo disable this warning message, set "
"the environment variable `%s=True`",
client_version,
server_version,
ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING,
)
def get_store_info(self) -> ServerModel:
"""Get information about the server.
Returns:
Information about the server.
"""
body = self.get(INFO)
return ServerModel.model_validate(body)
def get_deployment_id(self) -> UUID:
"""Get the ID of the deployment.
Returns:
The ID of the deployment.
"""
return self.get_store_info().id
# -------------------- Server Settings --------------------
def get_server_settings(
self, hydrate: bool = True
) -> ServerSettingsResponse:
"""Get the server settings.
Args:
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The server settings.
"""
response_body = self.get(SERVER_SETTINGS, params={"hydrate": hydrate})
return ServerSettingsResponse.model_validate(response_body)
def update_server_settings(
self, settings_update: ServerSettingsUpdate
) -> ServerSettingsResponse:
"""Update the server settings.
Args:
settings_update: The server settings update.
Returns:
The updated server settings.
"""
response_body = self.put(SERVER_SETTINGS, body=settings_update)
return ServerSettingsResponse.model_validate(response_body)
# -------------------- Actions --------------------
def create_action(self, action: ActionRequest) -> ActionResponse:
"""Create an action.
Args:
action: The action to create.
Returns:
The created action.
"""
return self._create_resource(
resource=action,
route=ACTIONS,
response_model=ActionResponse,
)
def get_action(
self,
action_id: UUID,
hydrate: bool = True,
) -> ActionResponse:
"""Get an action by ID.
Args:
action_id: The ID of the action to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The action.
"""
return self._get_resource(
resource_id=action_id,
route=ACTIONS,
response_model=ActionResponse,
params={"hydrate": hydrate},
)
def list_actions(
self,
action_filter_model: ActionFilter,
hydrate: bool = False,
) -> Page[ActionResponse]:
"""List all actions matching the given filter criteria.
Args:
action_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all actions matching the filter criteria.
"""
return self._list_paginated_resources(
route=ACTIONS,
response_model=ActionResponse,
filter_model=action_filter_model,
params={"hydrate": hydrate},
)
def update_action(
self,
action_id: UUID,
action_update: ActionUpdate,
) -> ActionResponse:
"""Update an existing action.
Args:
action_id: The ID of the action to update.
action_update: The update to be applied to the action.
Returns:
The updated action.
"""
return self._update_resource(
resource_id=action_id,
resource_update=action_update,
route=ACTIONS,
response_model=ActionResponse,
)
def delete_action(self, action_id: UUID) -> None:
"""Delete an action.
Args:
action_id: The ID of the action to delete.
"""
self._delete_resource(
resource_id=action_id,
route=ACTIONS,
)
# ----------------------------- API Keys -----------------------------
def create_api_key(
self, service_account_id: UUID, api_key: APIKeyRequest
) -> APIKeyResponse:
"""Create a new API key for a service account.
Args:
service_account_id: The ID of the service account for which to
create the API key.
api_key: The API key to create.
Returns:
The created API key.
"""
return self._create_resource(
resource=api_key,
route=f"{SERVICE_ACCOUNTS}/{str(service_account_id)}{API_KEYS}",
response_model=APIKeyResponse,
)
def get_api_key(
self,
service_account_id: UUID,
api_key_name_or_id: Union[str, UUID],
hydrate: bool = True,
) -> APIKeyResponse:
"""Get an API key for a service account.
Args:
service_account_id: The ID of the service account for which to fetch
the API key.
api_key_name_or_id: The name or ID of the API key to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The API key with the given ID.
"""
return self._get_resource(
resource_id=api_key_name_or_id,
route=f"{SERVICE_ACCOUNTS}/{str(service_account_id)}{API_KEYS}",
response_model=APIKeyResponse,
params={"hydrate": hydrate},
)
def set_api_key(self, api_key: str) -> None:
"""Set the API key to use for authentication.
Args:
api_key: The API key to use for authentication.
"""
self.config.api_key = api_key
self.clear_session()
# TODO: find a way to persist the API key in the configuration file
# without calling _write_config() here.
# This is the only place where we need to explicitly call
# _write_config() to persist the global configuration.
GlobalConfiguration()._write_config()
def list_api_keys(
self,
service_account_id: UUID,
filter_model: APIKeyFilter,
hydrate: bool = False,
) -> Page[APIKeyResponse]:
"""List all API keys for a service account matching the given filter criteria.
Args:
service_account_id: The ID of the service account for which to list
the API keys.
filter_model: All filter parameters including pagination
params
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all API keys matching the filter criteria.
"""
return self._list_paginated_resources(
route=f"{SERVICE_ACCOUNTS}/{str(service_account_id)}{API_KEYS}",
response_model=APIKeyResponse,
filter_model=filter_model,
params={"hydrate": hydrate},
)
def update_api_key(
self,
service_account_id: UUID,
api_key_name_or_id: Union[str, UUID],
api_key_update: APIKeyUpdate,
) -> APIKeyResponse:
"""Update an API key for a service account.
Args:
service_account_id: The ID of the service account for which to
update the API key.
api_key_name_or_id: The name or ID of the API key to update.
api_key_update: The update request on the API key.
Returns:
The updated API key.
"""
return self._update_resource(
resource_id=api_key_name_or_id,
resource_update=api_key_update,
route=f"{SERVICE_ACCOUNTS}/{str(service_account_id)}{API_KEYS}",
response_model=APIKeyResponse,
)
def rotate_api_key(
self,
service_account_id: UUID,
api_key_name_or_id: Union[str, UUID],
rotate_request: APIKeyRotateRequest,
) -> APIKeyResponse:
"""Rotate an API key for a service account.
Args:
service_account_id: The ID of the service account for which to
rotate the API key.
api_key_name_or_id: The name or ID of the API key to rotate.
rotate_request: The rotate request on the API key.
Returns:
The updated API key.
"""
response_body = self.put(
f"{SERVICE_ACCOUNTS}/{str(service_account_id)}{API_KEYS}/{str(api_key_name_or_id)}{API_KEY_ROTATE}",
body=rotate_request,
)
return APIKeyResponse.model_validate(response_body)
def delete_api_key(
self,
service_account_id: UUID,
api_key_name_or_id: Union[str, UUID],
) -> None:
"""Delete an API key for a service account.
Args:
service_account_id: The ID of the service account for which to
delete the API key.
api_key_name_or_id: The name or ID of the API key to delete.
"""
self._delete_resource(
resource_id=api_key_name_or_id,
route=f"{SERVICE_ACCOUNTS}/{str(service_account_id)}{API_KEYS}",
)
# ----------------------------- Services -----------------------------
def create_service(
self, service_request: ServiceRequest
) -> ServiceResponse:
"""Create a new service.
Args:
service_request: The service to create.
Returns:
The created service.
"""
return self._create_resource(
resource=service_request,
response_model=ServiceResponse,
route=SERVICES,
)
def get_service(
self, service_id: UUID, hydrate: bool = True
) -> ServiceResponse:
"""Get a service.
Args:
service_id: The ID of the service to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The service.
"""
return self._get_resource(
resource_id=service_id,
route=SERVICES,
response_model=ServiceResponse,
params={"hydrate": hydrate},
)
def list_services(
self, filter_model: ServiceFilter, hydrate: bool = False
) -> Page[ServiceResponse]:
"""List all services matching the given filter criteria.
Args:
filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all services matching the filter criteria.
"""
return self._list_paginated_resources(
route=SERVICES,
response_model=ServiceResponse,
filter_model=filter_model,
params={"hydrate": hydrate},
)
def update_service(
self, service_id: UUID, update: ServiceUpdate
) -> ServiceResponse:
"""Update a service.
Args:
service_id: The ID of the service to update.
update: The update to be applied to the service.
Returns:
The updated service.
"""
return self._update_resource(
resource_id=service_id,
resource_update=update,
response_model=ServiceResponse,
route=SERVICES,
)
def delete_service(self, service_id: UUID) -> None:
"""Delete a service.
Args:
service_id: The ID of the service to delete.
"""
self._delete_resource(resource_id=service_id, route=SERVICES)
# ----------------------------- Artifacts -----------------------------
def create_artifact(self, artifact: ArtifactRequest) -> ArtifactResponse:
"""Creates a new artifact.
Args:
artifact: The artifact to create.
Returns:
The newly created artifact.
"""
return self._create_resource(
resource=artifact,
response_model=ArtifactResponse,
route=ARTIFACTS,
)
def get_artifact(
self, artifact_id: UUID, hydrate: bool = True
) -> ArtifactResponse:
"""Gets an artifact.
Args:
artifact_id: The ID of the artifact to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The artifact.
"""
return self._get_resource(
resource_id=artifact_id,
route=ARTIFACTS,
response_model=ArtifactResponse,
params={"hydrate": hydrate},
)
def list_artifacts(
self, filter_model: ArtifactFilter, hydrate: bool = False
) -> Page[ArtifactResponse]:
"""List all artifacts matching the given filter criteria.
Args:
filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all artifacts matching the filter criteria.
"""
return self._list_paginated_resources(
route=ARTIFACTS,
response_model=ArtifactResponse,
filter_model=filter_model,
params={"hydrate": hydrate},
)
def update_artifact(
self, artifact_id: UUID, artifact_update: ArtifactUpdate
) -> ArtifactResponse:
"""Updates an artifact.
Args:
artifact_id: The ID of the artifact to update.
artifact_update: The update to be applied to the artifact.
Returns:
The updated artifact.
"""
return self._update_resource(
resource_id=artifact_id,
resource_update=artifact_update,
response_model=ArtifactResponse,
route=ARTIFACTS,
)
def delete_artifact(self, artifact_id: UUID) -> None:
"""Deletes an artifact.
Args:
artifact_id: The ID of the artifact to delete.
"""
self._delete_resource(resource_id=artifact_id, route=ARTIFACTS)
# -------------------- Artifact Versions --------------------
def create_artifact_version(
self, artifact_version: ArtifactVersionRequest
) -> ArtifactVersionResponse:
"""Creates an artifact version.
Args:
artifact_version: The artifact version to create.
Returns:
The created artifact version.
"""
return self._create_resource(
resource=artifact_version,
response_model=ArtifactVersionResponse,
route=ARTIFACT_VERSIONS,
)
def get_artifact_version(
self, artifact_version_id: UUID, hydrate: bool = True
) -> ArtifactVersionResponse:
"""Gets an artifact.
Args:
artifact_version_id: The ID of the artifact version to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The artifact version.
"""
return self._get_resource(
resource_id=artifact_version_id,
route=ARTIFACT_VERSIONS,
response_model=ArtifactVersionResponse,
params={"hydrate": hydrate},
)
def list_artifact_versions(
self,
artifact_version_filter_model: ArtifactVersionFilter,
hydrate: bool = False,
) -> Page[ArtifactVersionResponse]:
"""List all artifact versions matching the given filter criteria.
Args:
artifact_version_filter_model: All filter parameters including
pagination params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all artifact versions matching the filter criteria.
"""
return self._list_paginated_resources(
route=ARTIFACT_VERSIONS,
response_model=ArtifactVersionResponse,
filter_model=artifact_version_filter_model,
params={"hydrate": hydrate},
)
def update_artifact_version(
self,
artifact_version_id: UUID,
artifact_version_update: ArtifactVersionUpdate,
) -> ArtifactVersionResponse:
"""Updates an artifact version.
Args:
artifact_version_id: The ID of the artifact version to update.
artifact_version_update: The update to be applied to the artifact
version.
Returns:
The updated artifact version.
"""
return self._update_resource(
resource_id=artifact_version_id,
resource_update=artifact_version_update,
response_model=ArtifactVersionResponse,
route=ARTIFACT_VERSIONS,
)
def delete_artifact_version(self, artifact_version_id: UUID) -> None:
"""Deletes an artifact version.
Args:
artifact_version_id: The ID of the artifact version to delete.
"""
self._delete_resource(
resource_id=artifact_version_id, route=ARTIFACT_VERSIONS
)
def prune_artifact_versions(
self,
only_versions: bool = True,
) -> None:
"""Prunes unused artifact versions and their artifacts.
Args:
        only_versions: Only delete artifact versions, keeping artifacts.
"""
self.delete(
path=ARTIFACT_VERSIONS, params={"only_versions": only_versions}
)
# ------------------------ Artifact Visualizations ------------------------
def get_artifact_visualization(
self, artifact_visualization_id: UUID, hydrate: bool = True
) -> ArtifactVisualizationResponse:
"""Gets an artifact visualization.
Args:
artifact_visualization_id: The ID of the artifact visualization to
get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The artifact visualization.
"""
return self._get_resource(
resource_id=artifact_visualization_id,
route=ARTIFACT_VISUALIZATIONS,
response_model=ArtifactVisualizationResponse,
params={"hydrate": hydrate},
)
# ------------------------ Code References ------------------------
def get_code_reference(
self, code_reference_id: UUID, hydrate: bool = True
) -> CodeReferenceResponse:
"""Gets a code reference.
Args:
code_reference_id: The ID of the code reference to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The code reference.
"""
return self._get_resource(
resource_id=code_reference_id,
route=CODE_REFERENCES,
response_model=CodeReferenceResponse,
params={"hydrate": hydrate},
)
# --------------------------- Code Repositories ---------------------------
def create_code_repository(
self, code_repository: CodeRepositoryRequest
) -> CodeRepositoryResponse:
"""Creates a new code repository.
Args:
code_repository: Code repository to be created.
Returns:
The newly created code repository.
"""
return self._create_workspace_scoped_resource(
resource=code_repository,
response_model=CodeRepositoryResponse,
route=CODE_REPOSITORIES,
)
def get_code_repository(
self, code_repository_id: UUID, hydrate: bool = True
) -> CodeRepositoryResponse:
"""Gets a specific code repository.
Args:
code_repository_id: The ID of the code repository to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The requested code repository, if it was found.
"""
return self._get_resource(
resource_id=code_repository_id,
route=CODE_REPOSITORIES,
response_model=CodeRepositoryResponse,
params={"hydrate": hydrate},
)
def list_code_repositories(
self,
filter_model: CodeRepositoryFilter,
hydrate: bool = False,
) -> Page[CodeRepositoryResponse]:
"""List all code repositories.
Args:
filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all code repositories.
"""
return self._list_paginated_resources(
route=CODE_REPOSITORIES,
response_model=CodeRepositoryResponse,
filter_model=filter_model,
params={"hydrate": hydrate},
)
def update_code_repository(
self, code_repository_id: UUID, update: CodeRepositoryUpdate
) -> CodeRepositoryResponse:
"""Updates an existing code repository.
Args:
code_repository_id: The ID of the code repository to update.
update: The update to be applied to the code repository.
Returns:
The updated code repository.
"""
return self._update_resource(
resource_id=code_repository_id,
resource_update=update,
response_model=CodeRepositoryResponse,
route=CODE_REPOSITORIES,
)
def delete_code_repository(self, code_repository_id: UUID) -> None:
"""Deletes a code repository.
Args:
code_repository_id: The ID of the code repository to delete.
"""
self._delete_resource(
resource_id=code_repository_id, route=CODE_REPOSITORIES
)
# ----------------------------- Components -----------------------------
def create_stack_component(
self,
component: ComponentRequest,
) -> ComponentResponse:
"""Create a stack component.
Args:
component: The stack component to create.
Returns:
The created stack component.
"""
return self._create_workspace_scoped_resource(
resource=component,
route=STACK_COMPONENTS,
response_model=ComponentResponse,
)
def get_stack_component(
self, component_id: UUID, hydrate: bool = True
) -> ComponentResponse:
"""Get a stack component by ID.
Args:
component_id: The ID of the stack component to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The stack component.
"""
return self._get_resource(
resource_id=component_id,
route=STACK_COMPONENTS,
response_model=ComponentResponse,
params={"hydrate": hydrate},
)
def list_stack_components(
self,
component_filter_model: ComponentFilter,
hydrate: bool = False,
) -> Page[ComponentResponse]:
"""List all stack components matching the given filter criteria.
Args:
component_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all stack components matching the filter criteria.
"""
return self._list_paginated_resources(
route=STACK_COMPONENTS,
response_model=ComponentResponse,
filter_model=component_filter_model,
params={"hydrate": hydrate},
)
def update_stack_component(
self,
component_id: UUID,
component_update: ComponentUpdate,
) -> ComponentResponse:
"""Update an existing stack component.
Args:
component_id: The ID of the stack component to update.
component_update: The update to be applied to the stack component.
Returns:
The updated stack component.
"""
return self._update_resource(
resource_id=component_id,
resource_update=component_update,
route=STACK_COMPONENTS,
response_model=ComponentResponse,
)
def delete_stack_component(self, component_id: UUID) -> None:
"""Delete a stack component.
Args:
component_id: The ID of the stack component to delete.
"""
self._delete_resource(
resource_id=component_id,
route=STACK_COMPONENTS,
)
# ----------------------------- Flavors -----------------------------
def create_flavor(self, flavor: FlavorRequest) -> FlavorResponse:
"""Creates a new stack component flavor.
Args:
flavor: The stack component flavor to create.
Returns:
The newly created flavor.
"""
return self._create_resource(
resource=flavor,
route=FLAVORS,
response_model=FlavorResponse,
)
def get_flavor(
self, flavor_id: UUID, hydrate: bool = True
) -> FlavorResponse:
"""Get a stack component flavor by ID.
Args:
flavor_id: The ID of the stack component flavor to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The stack component flavor.
"""
return self._get_resource(
resource_id=flavor_id,
route=FLAVORS,
response_model=FlavorResponse,
params={"hydrate": hydrate},
)
def list_flavors(
self,
flavor_filter_model: FlavorFilter,
hydrate: bool = False,
) -> Page[FlavorResponse]:
"""List all stack component flavors matching the given filter criteria.
Args:
flavor_filter_model: All filter parameters including pagination
params
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
List of all the stack component flavors matching the given criteria.
"""
return self._list_paginated_resources(
route=FLAVORS,
response_model=FlavorResponse,
filter_model=flavor_filter_model,
params={"hydrate": hydrate},
)
def update_flavor(
self, flavor_id: UUID, flavor_update: FlavorUpdate
) -> FlavorResponse:
"""Updates an existing user.
Args:
flavor_id: The id of the flavor to update.
flavor_update: The update to be applied to the flavor.
Returns:
The updated flavor.
"""
return self._update_resource(
resource_id=flavor_id,
resource_update=flavor_update,
route=FLAVORS,
response_model=FlavorResponse,
)
def delete_flavor(self, flavor_id: UUID) -> None:
"""Delete a stack component flavor.
Args:
flavor_id: The ID of the stack component flavor to delete.
"""
self._delete_resource(
resource_id=flavor_id,
route=FLAVORS,
)
# ------------------------ Logs ------------------------
def get_logs(self, logs_id: UUID, hydrate: bool = True) -> LogsResponse:
"""Gets logs with the given ID.
Args:
logs_id: The ID of the logs to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The logs.
"""
return self._get_resource(
resource_id=logs_id,
route=LOGS,
response_model=LogsResponse,
params={"hydrate": hydrate},
)
# ----------------------------- Pipelines -----------------------------
def create_pipeline(self, pipeline: PipelineRequest) -> PipelineResponse:
"""Creates a new pipeline in a workspace.
Args:
pipeline: The pipeline to create.
Returns:
The newly created pipeline.
"""
return self._create_workspace_scoped_resource(
resource=pipeline,
route=PIPELINES,
response_model=PipelineResponse,
)
def get_pipeline(
self, pipeline_id: UUID, hydrate: bool = True
) -> PipelineResponse:
"""Get a pipeline with a given ID.
Args:
pipeline_id: ID of the pipeline.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The pipeline.
"""
return self._get_resource(
resource_id=pipeline_id,
route=PIPELINES,
response_model=PipelineResponse,
params={"hydrate": hydrate},
)
def list_pipelines(
self,
pipeline_filter_model: PipelineFilter,
hydrate: bool = False,
) -> Page[PipelineResponse]:
"""List all pipelines matching the given filter criteria.
Args:
pipeline_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all pipelines matching the filter criteria.
"""
return self._list_paginated_resources(
route=PIPELINES,
response_model=PipelineResponse,
filter_model=pipeline_filter_model,
params={"hydrate": hydrate},
)
def update_pipeline(
self, pipeline_id: UUID, pipeline_update: PipelineUpdate
) -> PipelineResponse:
"""Updates a pipeline.
Args:
pipeline_id: The ID of the pipeline to be updated.
pipeline_update: The update to be applied.
Returns:
The updated pipeline.
"""
return self._update_resource(
resource_id=pipeline_id,
resource_update=pipeline_update,
route=PIPELINES,
response_model=PipelineResponse,
)
def delete_pipeline(self, pipeline_id: UUID) -> None:
"""Deletes a pipeline.
Args:
pipeline_id: The ID of the pipeline to delete.
"""
self._delete_resource(
resource_id=pipeline_id,
route=PIPELINES,
)
# --------------------------- Pipeline Builds ---------------------------
def create_build(
self,
build: PipelineBuildRequest,
) -> PipelineBuildResponse:
"""Creates a new build in a workspace.
Args:
build: The build to create.
Returns:
The newly created build.
"""
return self._create_workspace_scoped_resource(
resource=build,
route=PIPELINE_BUILDS,
response_model=PipelineBuildResponse,
)
def get_build(
self, build_id: UUID, hydrate: bool = True
) -> PipelineBuildResponse:
"""Get a build with a given ID.
Args:
build_id: ID of the build.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The build.
"""
return self._get_resource(
resource_id=build_id,
route=PIPELINE_BUILDS,
response_model=PipelineBuildResponse,
params={"hydrate": hydrate},
)
def list_builds(
self,
build_filter_model: PipelineBuildFilter,
hydrate: bool = False,
) -> Page[PipelineBuildResponse]:
"""List all builds matching the given filter criteria.
Args:
build_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all builds matching the filter criteria.
"""
return self._list_paginated_resources(
route=PIPELINE_BUILDS,
response_model=PipelineBuildResponse,
filter_model=build_filter_model,
params={"hydrate": hydrate},
)
def delete_build(self, build_id: UUID) -> None:
"""Deletes a build.
Args:
build_id: The ID of the build to delete.
"""
self._delete_resource(
resource_id=build_id,
route=PIPELINE_BUILDS,
)
# -------------------------- Pipeline Deployments --------------------------
def create_deployment(
self,
deployment: PipelineDeploymentRequest,
) -> PipelineDeploymentResponse:
"""Creates a new deployment in a workspace.
Args:
deployment: The deployment to create.
Returns:
The newly created deployment.
"""
return self._create_workspace_scoped_resource(
resource=deployment,
route=PIPELINE_DEPLOYMENTS,
response_model=PipelineDeploymentResponse,
)
def get_deployment(
self, deployment_id: UUID, hydrate: bool = True
) -> PipelineDeploymentResponse:
"""Get a deployment with a given ID.
Args:
deployment_id: ID of the deployment.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The deployment.
"""
return self._get_resource(
resource_id=deployment_id,
route=PIPELINE_DEPLOYMENTS,
response_model=PipelineDeploymentResponse,
params={"hydrate": hydrate},
)
def list_deployments(
self,
deployment_filter_model: PipelineDeploymentFilter,
hydrate: bool = False,
) -> Page[PipelineDeploymentResponse]:
"""List all deployments matching the given filter criteria.
Args:
deployment_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all deployments matching the filter criteria.
"""
return self._list_paginated_resources(
route=PIPELINE_DEPLOYMENTS,
response_model=PipelineDeploymentResponse,
filter_model=deployment_filter_model,
params={"hydrate": hydrate},
)
def delete_deployment(self, deployment_id: UUID) -> None:
"""Deletes a deployment.
Args:
deployment_id: The ID of the deployment to delete.
"""
self._delete_resource(
resource_id=deployment_id,
route=PIPELINE_DEPLOYMENTS,
)
# -------------------- Run templates --------------------
def create_run_template(
self,
template: RunTemplateRequest,
) -> RunTemplateResponse:
"""Create a new run template.
Args:
template: The template to create.
Returns:
The newly created template.
"""
return self._create_workspace_scoped_resource(
resource=template,
route=RUN_TEMPLATES,
response_model=RunTemplateResponse,
)
def get_run_template(
self, template_id: UUID, hydrate: bool = True
) -> RunTemplateResponse:
"""Get a run template with a given ID.
Args:
template_id: ID of the template.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The template.
"""
return self._get_resource(
resource_id=template_id,
route=RUN_TEMPLATES,
response_model=RunTemplateResponse,
params={"hydrate": hydrate},
)
def list_run_templates(
self,
template_filter_model: RunTemplateFilter,
hydrate: bool = False,
) -> Page[RunTemplateResponse]:
"""List all run templates matching the given filter criteria.
Args:
template_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all templates matching the filter criteria.
"""
return self._list_paginated_resources(
route=RUN_TEMPLATES,
response_model=RunTemplateResponse,
filter_model=template_filter_model,
params={"hydrate": hydrate},
)
def update_run_template(
self,
template_id: UUID,
template_update: RunTemplateUpdate,
) -> RunTemplateResponse:
"""Updates a run template.
Args:
template_id: The ID of the template to update.
template_update: The update to apply.
Returns:
The updated template.
"""
return self._update_resource(
resource_id=template_id,
resource_update=template_update,
route=RUN_TEMPLATES,
response_model=RunTemplateResponse,
)
def delete_run_template(self, template_id: UUID) -> None:
"""Delete a run template.
Args:
template_id: The ID of the template to delete.
"""
self._delete_resource(
resource_id=template_id,
route=RUN_TEMPLATES,
)
def run_template(
self,
template_id: UUID,
run_configuration: Optional[PipelineRunConfiguration] = None,
) -> PipelineRunResponse:
"""Run a template.
Args:
template_id: The ID of the template to run.
run_configuration: Configuration for the run.
Raises:
RuntimeError: If the server does not support running a template.
Returns:
Model of the pipeline run.
"""
run_configuration = run_configuration or PipelineRunConfiguration()
try:
response_body = self.post(
f"{RUN_TEMPLATES}/{template_id}/runs",
body=run_configuration,
)
except MethodNotAllowedError as e:
raise RuntimeError(
"Running a template is not supported for this server."
) from e
return PipelineRunResponse.model_validate(response_body)
# -------------------- Event Sources --------------------
def create_event_source(
self, event_source: EventSourceRequest
) -> EventSourceResponse:
"""Create an event_source.
Args:
event_source: The event_source to create.
Returns:
The created event_source.
"""
return self._create_resource(
resource=event_source,
route=EVENT_SOURCES,
response_model=EventSourceResponse,
)
def get_event_source(
self,
event_source_id: UUID,
hydrate: bool = True,
) -> EventSourceResponse:
"""Get an event_source by ID.
Args:
event_source_id: The ID of the event_source to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The event_source.
"""
return self._get_resource(
resource_id=event_source_id,
route=EVENT_SOURCES,
response_model=EventSourceResponse,
params={"hydrate": hydrate},
)
def list_event_sources(
self,
event_source_filter_model: EventSourceFilter,
hydrate: bool = False,
) -> Page[EventSourceResponse]:
"""List all event_sources matching the given filter criteria.
Args:
event_source_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all event_sources matching the filter criteria.
"""
return self._list_paginated_resources(
route=EVENT_SOURCES,
response_model=EventSourceResponse,
filter_model=event_source_filter_model,
params={"hydrate": hydrate},
)
def update_event_source(
self,
event_source_id: UUID,
event_source_update: EventSourceUpdate,
) -> EventSourceResponse:
"""Update an existing event_source.
Args:
event_source_id: The ID of the event_source to update.
event_source_update: The update to be applied to the event_source.
Returns:
The updated event_source.
"""
return self._update_resource(
resource_id=event_source_id,
resource_update=event_source_update,
route=EVENT_SOURCES,
response_model=EventSourceResponse,
)
def delete_event_source(self, event_source_id: UUID) -> None:
"""Delete an event_source.
Args:
event_source_id: The ID of the event_source to delete.
"""
self._delete_resource(
resource_id=event_source_id,
route=EVENT_SOURCES,
)
# ----------------------------- Pipeline runs -----------------------------
def create_run(
self, pipeline_run: PipelineRunRequest
) -> PipelineRunResponse:
"""Creates a pipeline run.
Args:
pipeline_run: The pipeline run to create.
Returns:
The created pipeline run.
"""
return self._create_workspace_scoped_resource(
resource=pipeline_run,
response_model=PipelineRunResponse,
route=RUNS,
)
def get_run(
self, run_name_or_id: Union[UUID, str], hydrate: bool = True
) -> PipelineRunResponse:
"""Gets a pipeline run.
Args:
run_name_or_id: The name or ID of the pipeline run to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The pipeline run.
"""
return self._get_resource(
resource_id=run_name_or_id,
route=RUNS,
response_model=PipelineRunResponse,
params={"hydrate": hydrate},
)
def list_runs(
self,
runs_filter_model: PipelineRunFilter,
hydrate: bool = False,
) -> Page[PipelineRunResponse]:
"""List all pipeline runs matching the given filter criteria.
Args:
runs_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all pipeline runs matching the filter criteria.
"""
return self._list_paginated_resources(
route=RUNS,
response_model=PipelineRunResponse,
filter_model=runs_filter_model,
params={"hydrate": hydrate},
)
def update_run(
self, run_id: UUID, run_update: PipelineRunUpdate
) -> PipelineRunResponse:
"""Updates a pipeline run.
Args:
run_id: The ID of the pipeline run to update.
run_update: The update to be applied to the pipeline run.
Returns:
The updated pipeline run.
"""
return self._update_resource(
resource_id=run_id,
resource_update=run_update,
response_model=PipelineRunResponse,
route=RUNS,
)
def delete_run(self, run_id: UUID) -> None:
"""Deletes a pipeline run.
Args:
run_id: The ID of the pipeline run to delete.
"""
self._delete_resource(
resource_id=run_id,
route=RUNS,
)
def get_or_create_run(
self, pipeline_run: PipelineRunRequest
) -> Tuple[PipelineRunResponse, bool]:
"""Gets or creates a pipeline run.
If a run with the same ID or name already exists, it is returned.
Otherwise, a new run is created.
Args:
pipeline_run: The pipeline run to get or create.
Returns:
The pipeline run, and a boolean indicating whether the run was
created or not.
"""
return self._get_or_create_workspace_scoped_resource(
resource=pipeline_run,
route=RUNS,
response_model=PipelineRunResponse,
)
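    # A minimal sketch of the idempotent creation pattern above: assuming
    # `run_request` is a prepared `PipelineRunRequest`, the boolean in the
    # returned tuple distinguishes a freshly created run from a reused one:
    #
    #     run, created = store.get_or_create_run(run_request)
    #     if created:
    #         logger.info(f"Started new run {run.id}")
    #     else:
    #         logger.info(f"Resuming existing run {run.id}")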
# ----------------------------- Run Metadata -----------------------------
def create_run_metadata(
self, run_metadata: RunMetadataRequest
) -> List[RunMetadataResponse]:
"""Creates run metadata.
Args:
run_metadata: The run metadata to create.
Returns:
The created run metadata.
"""
route = f"{WORKSPACES}/{str(run_metadata.workspace)}{RUN_METADATA}"
response_body = self.post(f"{route}", body=run_metadata)
result: List[RunMetadataResponse] = []
if isinstance(response_body, list):
for metadata in response_body or []:
result.append(RunMetadataResponse.model_validate(metadata))
return result
def get_run_metadata(
self, run_metadata_id: UUID, hydrate: bool = True
) -> RunMetadataResponse:
"""Gets run metadata with the given ID.
Args:
run_metadata_id: The ID of the run metadata to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The run metadata.
"""
return self._get_resource(
resource_id=run_metadata_id,
route=RUN_METADATA,
response_model=RunMetadataResponse,
params={"hydrate": hydrate},
)
def list_run_metadata(
self,
run_metadata_filter_model: RunMetadataFilter,
hydrate: bool = False,
) -> Page[RunMetadataResponse]:
"""List run metadata.
Args:
run_metadata_filter_model: All filter parameters including
pagination params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The run metadata.
"""
return self._list_paginated_resources(
route=RUN_METADATA,
response_model=RunMetadataResponse,
filter_model=run_metadata_filter_model,
params={"hydrate": hydrate},
)
# ----------------------------- Schedules -----------------------------
def create_schedule(self, schedule: ScheduleRequest) -> ScheduleResponse:
"""Creates a new schedule.
Args:
schedule: The schedule to create.
Returns:
The newly created schedule.
"""
return self._create_workspace_scoped_resource(
resource=schedule,
route=SCHEDULES,
response_model=ScheduleResponse,
)
def get_schedule(
self, schedule_id: UUID, hydrate: bool = True
) -> ScheduleResponse:
"""Get a schedule with a given ID.
Args:
schedule_id: ID of the schedule.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The schedule.
"""
return self._get_resource(
resource_id=schedule_id,
route=SCHEDULES,
response_model=ScheduleResponse,
params={"hydrate": hydrate},
)
def list_schedules(
self,
schedule_filter_model: ScheduleFilter,
hydrate: bool = False,
) -> Page[ScheduleResponse]:
"""List all schedules in the workspace.
Args:
schedule_filter_model: All filter parameters including pagination
params
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of schedules.
"""
return self._list_paginated_resources(
route=SCHEDULES,
response_model=ScheduleResponse,
filter_model=schedule_filter_model,
params={"hydrate": hydrate},
)
def update_schedule(
self,
schedule_id: UUID,
schedule_update: ScheduleUpdate,
) -> ScheduleResponse:
"""Updates a schedule.
Args:
schedule_id: The ID of the schedule to be updated.
schedule_update: The update to be applied.
Returns:
The updated schedule.
"""
return self._update_resource(
resource_id=schedule_id,
resource_update=schedule_update,
route=SCHEDULES,
response_model=ScheduleResponse,
)
def delete_schedule(self, schedule_id: UUID) -> None:
"""Deletes a schedule.
Args:
schedule_id: The ID of the schedule to delete.
"""
self._delete_resource(
resource_id=schedule_id,
route=SCHEDULES,
)
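    # Sketch: pruning schedules with the list/delete pair above, assuming a
    # default-constructed `ScheduleFilter` matches all schedules:
    #
    #     from zenml.models import ScheduleFilter
    #
    #     page = store.list_schedules(ScheduleFilter(), hydrate=False)
    #     for schedule in page.items:
    #         store.delete_schedule(schedule.id)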
# --------------------------- Secrets ---------------------------
def create_secret(self, secret: SecretRequest) -> SecretResponse:
"""Creates a new secret.
The new secret is also validated against the scoping rules enforced in
the secrets store:
- only one workspace-scoped secret with the given name can exist
in the target workspace.
- only one user-scoped secret with the given name can exist in the
target workspace for the target user.
Args:
secret: The secret to create.
Returns:
The newly created secret.
"""
return self._create_workspace_scoped_resource(
resource=secret,
route=SECRETS,
response_model=SecretResponse,
)
def get_secret(
self, secret_id: UUID, hydrate: bool = True
) -> SecretResponse:
"""Get a secret by ID.
Args:
secret_id: The ID of the secret to fetch.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The secret.
"""
return self._get_resource(
resource_id=secret_id,
route=SECRETS,
response_model=SecretResponse,
params={"hydrate": hydrate},
)
def list_secrets(
self, secret_filter_model: SecretFilter, hydrate: bool = False
) -> Page[SecretResponse]:
"""List all secrets matching the given filter criteria.
Note that returned secrets do not include any secret values. To fetch
the secret values, use `get_secret`.
Args:
secret_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all secrets matching the filter criteria, with pagination
information and sorted according to the filter criteria. The
returned secrets do not include any secret values, only metadata. To
fetch the secret values, use `get_secret` individually with each
secret.
"""
return self._list_paginated_resources(
route=SECRETS,
response_model=SecretResponse,
filter_model=secret_filter_model,
params={"hydrate": hydrate},
)
def update_secret(
self, secret_id: UUID, secret_update: SecretUpdate
) -> SecretResponse:
"""Updates a secret.
Secret values that are specified as `None` in the update that are
present in the existing secret are removed from the existing secret.
Values that are present in both secrets are overwritten. All other
values in both the existing secret and the update are kept (merged).
If the update includes a change of name or scope, the scoping rules
enforced in the secrets store are used to validate the update:
- only one workspace-scoped secret with the given name can exist
in the target workspace.
- only one user-scoped secret with the given name can exist in the
target workspace for the target user.
Args:
secret_id: The ID of the secret to be updated.
secret_update: The update to be applied.
Returns:
The updated secret.
"""
return self._update_resource(
resource_id=secret_id,
resource_update=secret_update,
route=SECRETS,
response_model=SecretResponse,
# The default endpoint behavior is to replace all secret values
# with the values in the update. We want to merge the values
# instead.
params=dict(patch_values=True),
)
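    # Worked example of the merge semantics described above (a sketch; the
    # `values` field name on `SecretUpdate` is an assumption):
    #
    #     existing values: {"user": "abc", "token": "old", "region": "eu"}
    #     update values:   {"token": "new", "region": None}
    #     merged result:   {"user": "abc", "token": "new"}
    #
    #     store.update_secret(
    #         secret_id,
    #         SecretUpdate(values={"token": "new", "region": None}),
    #     )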
def delete_secret(self, secret_id: UUID) -> None:
"""Delete a secret.
Args:
            secret_id: The ID of the secret to delete.
"""
self._delete_resource(
resource_id=secret_id,
route=SECRETS,
)
def backup_secrets(
self, ignore_errors: bool = True, delete_secrets: bool = False
) -> None:
"""Backs up all secrets to the configured backup secrets store.
Args:
ignore_errors: Whether to ignore individual errors during the backup
process and attempt to backup all secrets.
delete_secrets: Whether to delete the secrets that have been
successfully backed up from the primary secrets store. Setting
this flag effectively moves all secrets from the primary secrets
store to the backup secrets store.
"""
params: Dict[str, Any] = {
"ignore_errors": ignore_errors,
"delete_secrets": delete_secrets,
}
self.put(
f"{SECRETS_OPERATIONS}{SECRETS_BACKUP}",
params=params,
)
def restore_secrets(
self, ignore_errors: bool = False, delete_secrets: bool = False
) -> None:
"""Restore all secrets from the configured backup secrets store.
Args:
ignore_errors: Whether to ignore individual errors during the
restore process and attempt to restore all secrets.
delete_secrets: Whether to delete the secrets that have been
successfully restored from the backup secrets store. Setting
this flag effectively moves all secrets from the backup secrets
store to the primary secrets store.
"""
params: Dict[str, Any] = {
"ignore_errors": ignore_errors,
"delete_secrets": delete_secrets,
}
self.put(
f"{SECRETS_OPERATIONS}{SECRETS_RESTORE}",
params=params,
)
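    # Sketch: because `delete_secrets=True` removes each successfully
    # copied secret from the source store, the two calls below behave as
    # moves rather than copies:
    #
    #     store.backup_secrets(ignore_errors=False, delete_secrets=True)
    #     # ... and later, to move everything back:
    #     store.restore_secrets(ignore_errors=False, delete_secrets=True)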
# --------------------------- Service Accounts ---------------------------
def create_service_account(
self, service_account: ServiceAccountRequest
) -> ServiceAccountResponse:
"""Creates a new service account.
Args:
service_account: Service account to be created.
Returns:
The newly created service account.
"""
return self._create_resource(
resource=service_account,
route=SERVICE_ACCOUNTS,
response_model=ServiceAccountResponse,
)
def get_service_account(
self,
service_account_name_or_id: Union[str, UUID],
hydrate: bool = True,
) -> ServiceAccountResponse:
"""Gets a specific service account.
Args:
service_account_name_or_id: The name or ID of the service account to
get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The requested service account, if it was found.
"""
return self._get_resource(
resource_id=service_account_name_or_id,
route=SERVICE_ACCOUNTS,
response_model=ServiceAccountResponse,
params={"hydrate": hydrate},
)
def list_service_accounts(
self, filter_model: ServiceAccountFilter, hydrate: bool = False
) -> Page[ServiceAccountResponse]:
"""List all service accounts.
Args:
filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of filtered service accounts.
"""
return self._list_paginated_resources(
route=SERVICE_ACCOUNTS,
response_model=ServiceAccountResponse,
filter_model=filter_model,
params={"hydrate": hydrate},
)
def update_service_account(
self,
service_account_name_or_id: Union[str, UUID],
service_account_update: ServiceAccountUpdate,
) -> ServiceAccountResponse:
"""Updates an existing service account.
Args:
service_account_name_or_id: The name or the ID of the service
account to update.
service_account_update: The update to be applied to the service
account.
Returns:
The updated service account.
"""
return self._update_resource(
resource_id=service_account_name_or_id,
resource_update=service_account_update,
route=SERVICE_ACCOUNTS,
response_model=ServiceAccountResponse,
)
def delete_service_account(
self,
service_account_name_or_id: Union[str, UUID],
) -> None:
"""Delete a service account.
Args:
service_account_name_or_id: The name or the ID of the service
account to delete.
"""
self._delete_resource(
resource_id=service_account_name_or_id,
route=SERVICE_ACCOUNTS,
)
# --------------------------- Service Connectors ---------------------------
def create_service_connector(
self, service_connector: ServiceConnectorRequest
) -> ServiceConnectorResponse:
"""Creates a new service connector.
Args:
service_connector: Service connector to be created.
Returns:
The newly created service connector.
"""
connector_model = self._create_workspace_scoped_resource(
resource=service_connector,
route=SERVICE_CONNECTORS,
response_model=ServiceConnectorResponse,
)
self._populate_connector_type(connector_model)
return connector_model
def get_service_connector(
self, service_connector_id: UUID, hydrate: bool = True
) -> ServiceConnectorResponse:
"""Gets a specific service connector.
Args:
service_connector_id: The ID of the service connector to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The requested service connector, if it was found.
"""
connector_model = self._get_resource(
resource_id=service_connector_id,
route=SERVICE_CONNECTORS,
response_model=ServiceConnectorResponse,
params={"expand_secrets": False, "hydrate": hydrate},
)
self._populate_connector_type(connector_model)
return connector_model
def list_service_connectors(
self,
filter_model: ServiceConnectorFilter,
hydrate: bool = False,
) -> Page[ServiceConnectorResponse]:
"""List all service connectors.
Args:
filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all service connectors.
"""
connector_models = self._list_paginated_resources(
route=SERVICE_CONNECTORS,
response_model=ServiceConnectorResponse,
filter_model=filter_model,
params={"expand_secrets": False, "hydrate": hydrate},
)
self._populate_connector_type(*connector_models.items)
return connector_models
def update_service_connector(
self, service_connector_id: UUID, update: ServiceConnectorUpdate
) -> ServiceConnectorResponse:
"""Updates an existing service connector.
The update model contains the fields to be updated. If a field value is
set to None in the model, the field is not updated, but there are
special rules concerning some fields:
* the `configuration` and `secrets` fields together represent a full
valid configuration update, not just a partial update. If either is
set (i.e. not None) in the update, their values are merged together and
will replace the existing configuration and secrets values.
* the `resource_id` field value is also a full replacement value: if set
to `None`, the resource ID is removed from the service connector.
* the `expiration_seconds` field value is also a full replacement value:
if set to `None`, the expiration is removed from the service connector.
* the `secret_id` field value in the update is ignored, given that
secrets are managed internally by the ZenML store.
* the `labels` field is also a full labels update: if set (i.e. not
`None`), all existing labels are removed and replaced by the new labels
in the update.
Args:
service_connector_id: The ID of the service connector to update.
update: The update to be applied to the service connector.
Returns:
The updated service connector.
"""
connector_model = self._update_resource(
resource_id=service_connector_id,
resource_update=update,
response_model=ServiceConnectorResponse,
route=SERVICE_CONNECTORS,
)
self._populate_connector_type(connector_model)
return connector_model
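    # Sketch of the replacement semantics documented above (field names on
    # `ServiceConnectorUpdate` are assumptions): `configuration` and
    # `labels` are full replacements, so omitting a key drops it rather
    # than keeping its previous value:
    #
    #     update = ServiceConnectorUpdate(
    #         configuration={"region": "us-east-1"},  # replaces the config
    #         labels={"env": "prod"},                 # replaces all labels
    #     )
    #     connector = store.update_service_connector(connector_id, update)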
def delete_service_connector(self, service_connector_id: UUID) -> None:
"""Deletes a service connector.
Args:
service_connector_id: The ID of the service connector to delete.
"""
self._delete_resource(
resource_id=service_connector_id, route=SERVICE_CONNECTORS
)
def _populate_connector_type(
self,
*connector_models: Union[
ServiceConnectorResponse, ServiceConnectorResourcesModel
],
) -> None:
"""Populates or updates the connector type of the given connector or resource models.
If the connector type is not locally available, the connector type
field is left as is. The local and remote flags of the connector type
are updated accordingly.
Args:
connector_models: The service connector or resource models to
populate.
"""
for service_connector in connector_models:
# Mark the remote connector type as being only remotely available
if not isinstance(service_connector.connector_type, str):
service_connector.connector_type.local = False
service_connector.connector_type.remote = True
if not service_connector_registry.is_registered(
service_connector.type
):
continue
connector_type = (
service_connector_registry.get_service_connector_type(
service_connector.type
)
)
connector_type.local = True
if not isinstance(service_connector.connector_type, str):
connector_type.remote = True
# TODO: Normally, this could have been handled with setter
# functions over the connector type property in the response
# model. However, pydantic breaks property setter functions.
# We can find a more elegant solution here.
if isinstance(service_connector, ServiceConnectorResponse):
service_connector.set_connector_type(connector_type)
elif isinstance(service_connector, ServiceConnectorResourcesModel):
service_connector.connector_type = connector_type
            else:
                raise TypeError(
                    "The service connector must be an instance of either "
                    "`ServiceConnectorResponse` or "
                    "`ServiceConnectorResourcesModel`."
                )
def verify_service_connector_config(
self,
service_connector: ServiceConnectorRequest,
list_resources: bool = True,
) -> ServiceConnectorResourcesModel:
"""Verifies if a service connector configuration has access to resources.
Args:
service_connector: The service connector configuration to verify.
            list_resources: If True, the list of all resources accessible
                through the service connector and matching the supplied
                resource type and ID is returned.
Returns:
The list of resources that the service connector configuration has
access to.
"""
response_body = self.post(
f"{SERVICE_CONNECTORS}{SERVICE_CONNECTOR_VERIFY}",
body=service_connector,
params={"list_resources": list_resources},
timeout=max(
self.config.http_timeout,
SERVICE_CONNECTOR_VERIFY_REQUEST_TIMEOUT,
),
)
resources = ServiceConnectorResourcesModel.model_validate(
response_body
)
self._populate_connector_type(resources)
return resources
def verify_service_connector(
self,
service_connector_id: UUID,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
list_resources: bool = True,
) -> ServiceConnectorResourcesModel:
"""Verifies if a service connector instance has access to one or more resources.
Args:
service_connector_id: The ID of the service connector to verify.
resource_type: The type of resource to verify access to.
resource_id: The ID of the resource to verify access to.
            list_resources: If True, the list of all resources accessible
                through the service connector and matching the supplied
                resource type and ID is returned.
Returns:
The list of resources that the service connector has access to,
scoped to the supplied resource type and ID, if provided.
"""
params: Dict[str, Any] = {"list_resources": list_resources}
if resource_type:
params["resource_type"] = resource_type
if resource_id:
params["resource_id"] = resource_id
response_body = self.put(
f"{SERVICE_CONNECTORS}/{str(service_connector_id)}{SERVICE_CONNECTOR_VERIFY}",
params=params,
timeout=max(
self.config.http_timeout,
SERVICE_CONNECTOR_VERIFY_REQUEST_TIMEOUT,
),
)
resources = ServiceConnectorResourcesModel.model_validate(
response_body
)
self._populate_connector_type(resources)
return resources
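    # Sketch: checking whether a registered connector can reach one
    # specific resource before wiring it into a stack (the resource type
    # and ID values are illustrative):
    #
    #     resources = store.verify_service_connector(
    #         connector_id,
    #         resource_type="s3-bucket",
    #         resource_id="my-bucket",
    #         list_resources=False,
    #     )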
def get_service_connector_client(
self,
service_connector_id: UUID,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
) -> ServiceConnectorResponse:
"""Get a service connector client for a service connector and given resource.
Args:
service_connector_id: The ID of the base service connector to use.
resource_type: The type of resource to get a client for.
resource_id: The ID of the resource to get a client for.
Returns:
A service connector client that can be used to access the given
resource.
"""
params = {}
if resource_type:
params["resource_type"] = resource_type
if resource_id:
params["resource_id"] = resource_id
response_body = self.get(
f"{SERVICE_CONNECTORS}/{str(service_connector_id)}{SERVICE_CONNECTOR_CLIENT}",
params=params,
)
connector = ServiceConnectorResponse.model_validate(response_body)
self._populate_connector_type(connector)
return connector
def list_service_connector_resources(
self,
workspace_name_or_id: Union[str, UUID],
connector_type: Optional[str] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
) -> List[ServiceConnectorResourcesModel]:
"""List resources that can be accessed by service connectors.
Args:
workspace_name_or_id: The name or ID of the workspace to scope to.
connector_type: The type of service connector to scope to.
resource_type: The type of resource to scope to.
resource_id: The ID of the resource to scope to.
Returns:
            The matching list of resources that the available service
            connectors have access to.
"""
params = {}
if connector_type:
params["connector_type"] = connector_type
if resource_type:
params["resource_type"] = resource_type
if resource_id:
params["resource_id"] = resource_id
response_body = self.get(
f"{WORKSPACES}/{workspace_name_or_id}{SERVICE_CONNECTORS}{SERVICE_CONNECTOR_RESOURCES}",
params=params,
timeout=max(
self.config.http_timeout,
SERVICE_CONNECTOR_VERIFY_REQUEST_TIMEOUT,
),
)
assert isinstance(response_body, list)
resource_list = [
ServiceConnectorResourcesModel.model_validate(item)
for item in response_body
]
self._populate_connector_type(*resource_list)
# For service connectors with types that are only locally available,
# we need to retrieve the resource list locally
for idx, resources in enumerate(resource_list):
if isinstance(resources.connector_type, str):
# Skip connector types that are neither locally nor remotely
# available
continue
if resources.connector_type.remote:
# Skip connector types that are remotely available
continue
# Retrieve the resource list locally
assert resources.id is not None
connector = self.get_service_connector(resources.id)
connector_instance = (
service_connector_registry.instantiate_connector(
model=connector
)
)
try:
local_resources = connector_instance.verify(
resource_type=resource_type,
resource_id=resource_id,
)
except (ValueError, AuthorizationException) as e:
logger.error(
f'Failed to fetch {resource_type or "available"} '
f"resources from service connector {connector.name}/"
f"{connector.id}: {e}"
)
continue
resource_list[idx] = local_resources
return resource_list
def list_service_connector_types(
self,
connector_type: Optional[str] = None,
resource_type: Optional[str] = None,
auth_method: Optional[str] = None,
) -> List[ServiceConnectorTypeModel]:
"""Get a list of service connector types.
Args:
connector_type: Filter by connector type.
resource_type: Filter by resource type.
auth_method: Filter by authentication method.
Returns:
List of service connector types.
"""
params = {}
if connector_type:
params["connector_type"] = connector_type
if resource_type:
params["resource_type"] = resource_type
if auth_method:
params["auth_method"] = auth_method
response_body = self.get(
SERVICE_CONNECTOR_TYPES,
params=params,
)
assert isinstance(response_body, list)
remote_connector_types = [
ServiceConnectorTypeModel.model_validate(item)
for item in response_body
]
# Mark the remote connector types as being only remotely available
for c in remote_connector_types:
c.local = False
c.remote = True
local_connector_types = (
service_connector_registry.list_service_connector_types(
connector_type=connector_type,
resource_type=resource_type,
auth_method=auth_method,
)
)
# Add the connector types in the local registry to the list of
# connector types available remotely. Overwrite those that have
# the same connector type but mark them as being remotely available.
connector_types_map = {
connector_type.connector_type: connector_type
for connector_type in remote_connector_types
}
for connector in local_connector_types:
if connector.connector_type in connector_types_map:
connector.remote = True
connector_types_map[connector.connector_type] = connector
return list(connector_types_map.values())
def get_service_connector_type(
self,
connector_type: str,
) -> ServiceConnectorTypeModel:
"""Returns the requested service connector type.
Args:
connector_type: the service connector type identifier.
Returns:
The requested service connector type.
"""
# Use the local registry to get the service connector type, if it
# exists.
local_connector_type: Optional[ServiceConnectorTypeModel] = None
if service_connector_registry.is_registered(connector_type):
local_connector_type = (
service_connector_registry.get_service_connector_type(
connector_type
)
)
try:
response_body = self.get(
f"{SERVICE_CONNECTOR_TYPES}/{connector_type}",
)
remote_connector_type = ServiceConnectorTypeModel.model_validate(
response_body
)
if local_connector_type:
# If locally available, return the local connector type but
# mark it as being remotely available.
local_connector_type.remote = True
return local_connector_type
# Mark the remote connector type as being only remotely available
remote_connector_type.local = False
remote_connector_type.remote = True
return remote_connector_type
except KeyError:
# If the service connector type is not found, check the local
# registry.
return service_connector_registry.get_service_connector_type(
connector_type
)
# ----------------------------- Stacks -----------------------------
def create_stack(self, stack: StackRequest) -> StackResponse:
"""Register a new stack.
Args:
stack: The stack to register.
Returns:
The registered stack.
"""
assert stack.workspace is not None
return self._create_resource(
resource=stack,
response_model=StackResponse,
route=f"{WORKSPACES}/{str(stack.workspace)}{STACKS}",
)
def get_stack(self, stack_id: UUID, hydrate: bool = True) -> StackResponse:
"""Get a stack by its unique ID.
Args:
stack_id: The ID of the stack to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The stack with the given ID.
"""
return self._get_resource(
resource_id=stack_id,
route=STACKS,
response_model=StackResponse,
params={"hydrate": hydrate},
)
def list_stacks(
self, stack_filter_model: StackFilter, hydrate: bool = False
) -> Page[StackResponse]:
"""List all stacks matching the given filter criteria.
Args:
stack_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all stacks matching the filter criteria.
"""
return self._list_paginated_resources(
route=STACKS,
response_model=StackResponse,
filter_model=stack_filter_model,
params={"hydrate": hydrate},
)
def update_stack(
self, stack_id: UUID, stack_update: StackUpdate
) -> StackResponse:
"""Update a stack.
Args:
            stack_id: The ID of the stack to update.
stack_update: The update request on the stack.
Returns:
The updated stack.
"""
return self._update_resource(
resource_id=stack_id,
resource_update=stack_update,
route=STACKS,
response_model=StackResponse,
)
def delete_stack(self, stack_id: UUID) -> None:
"""Delete a stack.
Args:
stack_id: The ID of the stack to delete.
"""
self._delete_resource(
resource_id=stack_id,
route=STACKS,
)
# ---------------- Stack deployments-----------------
def get_stack_deployment_info(
self,
provider: StackDeploymentProvider,
) -> StackDeploymentInfo:
"""Get information about a stack deployment provider.
Args:
provider: The stack deployment provider.
Returns:
Information about the stack deployment provider.
"""
body = self.get(
f"{STACK_DEPLOYMENT}{INFO}",
params={"provider": provider.value},
)
return StackDeploymentInfo.model_validate(body)
def get_stack_deployment_config(
self,
provider: StackDeploymentProvider,
stack_name: str,
location: Optional[str] = None,
) -> StackDeploymentConfig:
"""Return the cloud provider console URL and configuration needed to deploy the ZenML stack.
Args:
provider: The stack deployment provider.
stack_name: The name of the stack.
location: The location where the stack should be deployed.
Returns:
The cloud provider console URL and configuration needed to deploy
the ZenML stack to the specified cloud provider.
"""
params = {
"provider": provider.value,
"stack_name": stack_name,
}
if location:
params["location"] = location
body = self.get(f"{STACK_DEPLOYMENT}{CONFIG}", params=params)
return StackDeploymentConfig.model_validate(body)
def get_stack_deployment_stack(
self,
provider: StackDeploymentProvider,
stack_name: str,
location: Optional[str] = None,
date_start: Optional[datetime] = None,
) -> Optional[DeployedStack]:
"""Return a matching ZenML stack that was deployed and registered.
Args:
provider: The stack deployment provider.
stack_name: The name of the stack.
location: The location where the stack should be deployed.
date_start: The date when the deployment started.
Returns:
The ZenML stack that was deployed and registered or None if the
stack was not found.
"""
params = {
"provider": provider.value,
"stack_name": stack_name,
}
if location:
params["location"] = location
if date_start:
params["date_start"] = str(date_start)
body = self.get(
f"{STACK_DEPLOYMENT}{STACK}",
params=params,
)
if body:
return DeployedStack.model_validate(body)
return None
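    # Sketch of the deployment flow these methods support: fetch the
    # provider configuration, send the user to the cloud console, then poll
    # for the stack registered after the recorded start time. The `AWS`
    # enum member and the `deployment_url` attribute are assumptions:
    #
    #     from datetime import datetime, timezone
    #
    #     start = datetime.now(timezone.utc)
    #     config = store.get_stack_deployment_config(
    #         StackDeploymentProvider.AWS, stack_name="prod"
    #     )
    #     print(config.deployment_url)
    #     stack = store.get_stack_deployment_stack(
    #         StackDeploymentProvider.AWS, stack_name="prod", date_start=start
    #     )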
# ----------------------------- Step runs -----------------------------
def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
"""Creates a step run.
Args:
step_run: The step run to create.
Returns:
The created step run.
"""
return self._create_resource(
resource=step_run,
response_model=StepRunResponse,
route=STEPS,
)
def get_run_step(
self, step_run_id: UUID, hydrate: bool = True
) -> StepRunResponse:
"""Get a step run by ID.
Args:
step_run_id: The ID of the step run to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The step run.
"""
return self._get_resource(
resource_id=step_run_id,
route=STEPS,
response_model=StepRunResponse,
params={"hydrate": hydrate},
)
def list_run_steps(
self,
step_run_filter_model: StepRunFilter,
hydrate: bool = False,
) -> Page[StepRunResponse]:
"""List all step runs matching the given filter criteria.
Args:
step_run_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all step runs matching the filter criteria.
"""
return self._list_paginated_resources(
route=STEPS,
response_model=StepRunResponse,
filter_model=step_run_filter_model,
params={"hydrate": hydrate},
)
def update_run_step(
self,
step_run_id: UUID,
step_run_update: StepRunUpdate,
) -> StepRunResponse:
"""Updates a step run.
Args:
step_run_id: The ID of the step to update.
step_run_update: The update to be applied to the step.
Returns:
The updated step run.
"""
return self._update_resource(
resource_id=step_run_id,
resource_update=step_run_update,
response_model=StepRunResponse,
route=STEPS,
)
# -------------------- Triggers --------------------
def create_trigger(self, trigger: TriggerRequest) -> TriggerResponse:
"""Create an trigger.
Args:
trigger: The trigger to create.
Returns:
The created trigger.
"""
return self._create_resource(
resource=trigger,
route=TRIGGERS,
response_model=TriggerResponse,
)
def get_trigger(
self,
trigger_id: UUID,
hydrate: bool = True,
) -> TriggerResponse:
"""Get a trigger by ID.
Args:
trigger_id: The ID of the trigger to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The trigger.
"""
return self._get_resource(
resource_id=trigger_id,
route=TRIGGERS,
response_model=TriggerResponse,
params={"hydrate": hydrate},
)
def list_triggers(
self,
trigger_filter_model: TriggerFilter,
hydrate: bool = False,
) -> Page[TriggerResponse]:
"""List all triggers matching the given filter criteria.
Args:
trigger_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all triggers matching the filter criteria.
"""
return self._list_paginated_resources(
route=TRIGGERS,
response_model=TriggerResponse,
filter_model=trigger_filter_model,
params={"hydrate": hydrate},
)
def update_trigger(
self,
trigger_id: UUID,
trigger_update: TriggerUpdate,
) -> TriggerResponse:
"""Update an existing trigger.
Args:
trigger_id: The ID of the trigger to update.
trigger_update: The update to be applied to the trigger.
Returns:
The updated trigger.
"""
return self._update_resource(
resource_id=trigger_id,
resource_update=trigger_update,
route=TRIGGERS,
response_model=TriggerResponse,
)
def delete_trigger(self, trigger_id: UUID) -> None:
"""Delete an trigger.
Args:
trigger_id: The ID of the trigger to delete.
"""
self._delete_resource(
resource_id=trigger_id,
route=TRIGGERS,
)
# -------------------- Trigger Executions --------------------
def get_trigger_execution(
self,
trigger_execution_id: UUID,
hydrate: bool = True,
) -> TriggerExecutionResponse:
"""Get an trigger execution by ID.
Args:
trigger_execution_id: The ID of the trigger execution to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The trigger execution.
"""
return self._get_resource(
resource_id=trigger_execution_id,
route=TRIGGER_EXECUTIONS,
response_model=TriggerExecutionResponse,
params={"hydrate": hydrate},
)
def list_trigger_executions(
self,
trigger_execution_filter_model: TriggerExecutionFilter,
hydrate: bool = False,
) -> Page[TriggerExecutionResponse]:
"""List all trigger executions matching the given filter criteria.
Args:
trigger_execution_filter_model: All filter parameters including
pagination params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all trigger executions matching the filter criteria.
"""
return self._list_paginated_resources(
route=TRIGGER_EXECUTIONS,
response_model=TriggerExecutionResponse,
filter_model=trigger_execution_filter_model,
params={"hydrate": hydrate},
)
def delete_trigger_execution(self, trigger_execution_id: UUID) -> None:
"""Delete a trigger execution.
Args:
trigger_execution_id: The ID of the trigger execution to delete.
"""
self._delete_resource(
resource_id=trigger_execution_id,
route=TRIGGER_EXECUTIONS,
)
# ----------------------------- Users -----------------------------
def create_user(self, user: UserRequest) -> UserResponse:
"""Creates a new user.
Args:
user: User to be created.
Returns:
The newly created user.
"""
return self._create_resource(
resource=user,
route=USERS,
response_model=UserResponse,
)
def get_user(
self,
user_name_or_id: Optional[Union[str, UUID]] = None,
include_private: bool = False,
hydrate: bool = True,
) -> UserResponse:
"""Gets a specific user, when no id is specified get the active user.
The `include_private` parameter is ignored here as it is handled
implicitly by the /current-user endpoint that is queried when no
user_name_or_id is set. Raises a KeyError in case a user with that id
does not exist.
Args:
user_name_or_id: The name or ID of the user to get.
include_private: Whether to include private user information.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The requested user, if it was found.
"""
if user_name_or_id:
return self._get_resource(
resource_id=user_name_or_id,
route=USERS,
response_model=UserResponse,
params={"hydrate": hydrate},
)
else:
body = self.get(CURRENT_USER, params={"hydrate": hydrate})
return UserResponse.model_validate(body)
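    # Sketch: calling the method without arguments resolves to the
    # /current-user endpoint, so the active user can be fetched without
    # knowing its name or ID:
    #
    #     me = store.get_user()
    #     alice = store.get_user("alice", hydrate=False)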
def list_users(
self,
user_filter_model: UserFilter,
hydrate: bool = False,
) -> Page[UserResponse]:
"""List all users.
Args:
user_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A list of all users.
"""
return self._list_paginated_resources(
route=USERS,
response_model=UserResponse,
filter_model=user_filter_model,
params={"hydrate": hydrate},
)
def update_user(
self, user_id: UUID, user_update: UserUpdate
) -> UserResponse:
"""Updates an existing user.
Args:
            user_id: The ID of the user to update.
user_update: The update to be applied to the user.
Returns:
The updated user.
"""
return self._update_resource(
resource_id=user_id,
resource_update=user_update,
route=USERS,
response_model=UserResponse,
)
def deactivate_user(
self, user_name_or_id: Union[str, UUID]
) -> UserResponse:
"""Deactivates a user.
Args:
            user_name_or_id: The name or ID of the user to deactivate.
Returns:
The deactivated user containing the activation token.
"""
response_body = self.put(
f"{USERS}/{str(user_name_or_id)}{DEACTIVATE}",
)
return UserResponse.model_validate(response_body)
def delete_user(self, user_name_or_id: Union[str, UUID]) -> None:
"""Deletes a user.
Args:
user_name_or_id: The name or ID of the user to delete.
"""
self._delete_resource(
resource_id=user_name_or_id,
route=USERS,
)
# ----------------------------- Workspaces -----------------------------
def create_workspace(
self, workspace: WorkspaceRequest
) -> WorkspaceResponse:
"""Creates a new workspace.
Args:
workspace: The workspace to create.
Returns:
The newly created workspace.
"""
return self._create_resource(
resource=workspace,
route=WORKSPACES,
response_model=WorkspaceResponse,
)
def get_workspace(
self, workspace_name_or_id: Union[UUID, str], hydrate: bool = True
) -> WorkspaceResponse:
"""Get an existing workspace by name or ID.
Args:
workspace_name_or_id: Name or ID of the workspace to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The requested workspace.
"""
return self._get_resource(
resource_id=workspace_name_or_id,
route=WORKSPACES,
response_model=WorkspaceResponse,
params={"hydrate": hydrate},
)
def list_workspaces(
self,
workspace_filter_model: WorkspaceFilter,
hydrate: bool = False,
) -> Page[WorkspaceResponse]:
"""List all workspace matching the given filter criteria.
Args:
workspace_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
            A list of all workspaces matching the filter criteria.
"""
return self._list_paginated_resources(
route=WORKSPACES,
response_model=WorkspaceResponse,
filter_model=workspace_filter_model,
params={"hydrate": hydrate},
)
def update_workspace(
self, workspace_id: UUID, workspace_update: WorkspaceUpdate
) -> WorkspaceResponse:
"""Update an existing workspace.
Args:
workspace_id: The ID of the workspace to be updated.
workspace_update: The update to be applied to the workspace.
Returns:
The updated workspace.
"""
return self._update_resource(
resource_id=workspace_id,
resource_update=workspace_update,
route=WORKSPACES,
response_model=WorkspaceResponse,
)
def delete_workspace(self, workspace_name_or_id: Union[str, UUID]) -> None:
"""Deletes a workspace.
Args:
workspace_name_or_id: Name or ID of the workspace to delete.
"""
self._delete_resource(
resource_id=workspace_name_or_id,
route=WORKSPACES,
)
# --------------------------- Model ---------------------------
def create_model(self, model: ModelRequest) -> ModelResponse:
"""Creates a new model.
Args:
model: the Model to be created.
Returns:
The newly created model.
"""
return self._create_workspace_scoped_resource(
resource=model,
response_model=ModelResponse,
route=MODELS,
)
def delete_model(self, model_name_or_id: Union[str, UUID]) -> None:
"""Deletes a model.
Args:
model_name_or_id: name or id of the model to be deleted.
"""
self._delete_resource(resource_id=model_name_or_id, route=MODELS)
def update_model(
self,
model_id: UUID,
model_update: ModelUpdate,
) -> ModelResponse:
"""Updates an existing model.
Args:
model_id: UUID of the model to be updated.
model_update: the Model to be updated.
Returns:
The updated model.
"""
return self._update_resource(
resource_id=model_id,
resource_update=model_update,
route=MODELS,
response_model=ModelResponse,
)
def get_model(
self, model_name_or_id: Union[str, UUID], hydrate: bool = True
) -> ModelResponse:
"""Get an existing model.
Args:
model_name_or_id: name or id of the model to be retrieved.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The model of interest.
"""
return self._get_resource(
resource_id=model_name_or_id,
route=MODELS,
response_model=ModelResponse,
params={"hydrate": hydrate},
)
def list_models(
self,
model_filter_model: ModelFilter,
hydrate: bool = False,
) -> Page[ModelResponse]:
"""Get all models by filter.
Args:
model_filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all models.
"""
return self._list_paginated_resources(
route=MODELS,
response_model=ModelResponse,
filter_model=model_filter_model,
params={"hydrate": hydrate},
)
# ----------------------------- Model Versions -----------------------------
def create_model_version(
self, model_version: ModelVersionRequest
) -> ModelVersionResponse:
"""Creates a new model version.
Args:
model_version: the Model Version to be created.
Returns:
The newly created model version.
"""
return self._create_workspace_scoped_resource(
resource=model_version,
response_model=ModelVersionResponse,
route=f"{MODELS}/{model_version.model}{MODEL_VERSIONS}",
)
def delete_model_version(
self,
model_version_id: UUID,
) -> None:
"""Deletes a model version.
Args:
            model_version_id: The ID of the model version to be deleted.
"""
self._delete_resource(
resource_id=model_version_id,
route=f"{MODEL_VERSIONS}",
)
def get_model_version(
self, model_version_id: UUID, hydrate: bool = True
) -> ModelVersionResponse:
"""Get an existing model version.
Args:
            model_version_id: The ID of the model version to be retrieved.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The model version of interest.
"""
return self._get_resource(
resource_id=model_version_id,
route=MODEL_VERSIONS,
response_model=ModelVersionResponse,
params={"hydrate": hydrate},
)
def list_model_versions(
self,
model_version_filter_model: ModelVersionFilter,
model_name_or_id: Optional[Union[str, UUID]] = None,
hydrate: bool = False,
) -> Page[ModelVersionResponse]:
"""Get all model versions by filter.
Args:
model_name_or_id: name or id of the model containing the model
versions.
model_version_filter_model: All filter parameters including
pagination params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all model versions.
"""
if model_name_or_id:
return self._list_paginated_resources(
route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}",
response_model=ModelVersionResponse,
filter_model=model_version_filter_model,
params={"hydrate": hydrate},
)
else:
return self._list_paginated_resources(
route=MODEL_VERSIONS,
response_model=ModelVersionResponse,
filter_model=model_version_filter_model,
params={"hydrate": hydrate},
)
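    # Sketch: the same filter can be applied globally or scoped to a single
    # model, assuming a default-constructed `ModelVersionFilter` matches
    # all versions:
    #
    #     all_versions = store.list_model_versions(ModelVersionFilter())
    #     scoped = store.list_model_versions(
    #         ModelVersionFilter(), model_name_or_id="my_model"
    #     )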
def update_model_version(
self,
model_version_id: UUID,
model_version_update_model: ModelVersionUpdate,
) -> ModelVersionResponse:
"""Get all model versions by filter.
Args:
model_version_id: The ID of model version to be updated.
model_version_update_model: The model version to be updated.
Returns:
An updated model version.
"""
return self._update_resource(
resource_id=model_version_id,
resource_update=model_version_update_model,
route=MODEL_VERSIONS,
response_model=ModelVersionResponse,
)
# ------------------------ Model Versions Artifacts ------------------------
def create_model_version_artifact_link(
self, model_version_artifact_link: ModelVersionArtifactRequest
) -> ModelVersionArtifactResponse:
"""Creates a new model version link.
Args:
model_version_artifact_link: the Model Version to Artifact Link
to be created.
Returns:
The newly created model version to artifact link.
"""
return self._create_workspace_scoped_resource(
resource=model_version_artifact_link,
response_model=ModelVersionArtifactResponse,
route=f"{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}",
)
def list_model_version_artifact_links(
self,
model_version_artifact_link_filter_model: ModelVersionArtifactFilter,
hydrate: bool = False,
) -> Page[ModelVersionArtifactResponse]:
"""Get all model version to artifact links by filter.
Args:
model_version_artifact_link_filter_model: All filter parameters
including pagination params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all model version to artifact links.
"""
return self._list_paginated_resources(
route=MODEL_VERSION_ARTIFACTS,
response_model=ModelVersionArtifactResponse,
filter_model=model_version_artifact_link_filter_model,
params={"hydrate": hydrate},
)
def delete_model_version_artifact_link(
self,
model_version_id: UUID,
model_version_artifact_link_name_or_id: Union[str, UUID],
) -> None:
"""Deletes a model version to artifact link.
Args:
model_version_id: ID of the model version containing the link.
model_version_artifact_link_name_or_id: name or ID of the model
version to artifact link to be deleted.
"""
self._delete_resource(
resource_id=model_version_artifact_link_name_or_id,
route=f"{MODEL_VERSIONS}/{model_version_id}{ARTIFACTS}",
)
def delete_all_model_version_artifact_links(
self,
model_version_id: UUID,
only_links: bool = True,
) -> None:
"""Deletes all links between model version and an artifact.
Args:
model_version_id: ID of the model version containing the link.
            only_links: Flag deciding whether to delete only the links or
                the linked artifacts as well.
"""
self.delete(
f"{MODEL_VERSIONS}/{model_version_id}{ARTIFACTS}",
params={"only_links": only_links},
)
# ---------------------- Model Versions Pipeline Runs ----------------------
def create_model_version_pipeline_run_link(
self,
model_version_pipeline_run_link: ModelVersionPipelineRunRequest,
) -> ModelVersionPipelineRunResponse:
"""Creates a new model version to pipeline run link.
Args:
model_version_pipeline_run_link: the Model Version to Pipeline Run
Link to be created.
Returns:
- If Model Version to Pipeline Run Link already exists - returns
the existing link.
- Otherwise, returns the newly created model version to pipeline
run link.
"""
return self._create_workspace_scoped_resource(
resource=model_version_pipeline_run_link,
response_model=ModelVersionPipelineRunResponse,
route=f"{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}",
)
def list_model_version_pipeline_run_links(
self,
model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilter,
hydrate: bool = False,
) -> Page[ModelVersionPipelineRunResponse]:
"""Get all model version to pipeline run links by filter.
Args:
model_version_pipeline_run_link_filter_model: All filter parameters
including pagination params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all model version to pipeline run links.
"""
return self._list_paginated_resources(
route=MODEL_VERSION_PIPELINE_RUNS,
response_model=ModelVersionPipelineRunResponse,
filter_model=model_version_pipeline_run_link_filter_model,
params={"hydrate": hydrate},
)
def delete_model_version_pipeline_run_link(
self,
model_version_id: UUID,
model_version_pipeline_run_link_name_or_id: Union[str, UUID],
) -> None:
"""Deletes a model version to pipeline run link.
Args:
model_version_id: ID of the model version containing the link.
model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted.
"""
self._delete_resource(
resource_id=model_version_pipeline_run_link_name_or_id,
route=f"{MODEL_VERSIONS}/{model_version_id}{RUNS}",
)
# ---------------------------- Devices ----------------------------
def get_authorized_device(
self, device_id: UUID, hydrate: bool = True
) -> OAuthDeviceResponse:
"""Gets a specific OAuth 2.0 authorized device.
Args:
device_id: The ID of the device to get.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The requested device, if it was found.
"""
return self._get_resource(
resource_id=device_id,
route=DEVICES,
response_model=OAuthDeviceResponse,
params={"hydrate": hydrate},
)
def list_authorized_devices(
self, filter_model: OAuthDeviceFilter, hydrate: bool = False
) -> Page[OAuthDeviceResponse]:
"""List all OAuth 2.0 authorized devices for a user.
Args:
filter_model: All filter parameters including pagination
params.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
A page of all matching OAuth 2.0 authorized devices.
"""
return self._list_paginated_resources(
route=DEVICES,
response_model=OAuthDeviceResponse,
filter_model=filter_model,
params={"hydrate": hydrate},
)
def update_authorized_device(
self, device_id: UUID, update: OAuthDeviceUpdate
) -> OAuthDeviceResponse:
"""Updates an existing OAuth 2.0 authorized device for internal use.
Args:
device_id: The ID of the device to update.
update: The update to be applied to the device.
Returns:
The updated OAuth 2.0 authorized device.
"""
return self._update_resource(
resource_id=device_id,
resource_update=update,
response_model=OAuthDeviceResponse,
route=DEVICES,
)
def delete_authorized_device(self, device_id: UUID) -> None:
"""Deletes an OAuth 2.0 authorized device.
Args:
device_id: The ID of the device to delete.
"""
self._delete_resource(resource_id=device_id, route=DEVICES)
# -------------------
# Pipeline API Tokens
# -------------------
def get_api_token(
self,
pipeline_id: Optional[UUID] = None,
schedule_id: Optional[UUID] = None,
expires_minutes: Optional[int] = None,
) -> str:
"""Get an API token for a workload.
Args:
pipeline_id: The ID of the pipeline to get a token for.
schedule_id: The ID of the schedule to get a token for.
expires_minutes: The number of minutes for which the token should
be valid. If not provided, the token will be valid indefinitely.
Returns:
The API token.
Raises:
ValueError: if the server response is not valid.
"""
params: Dict[str, Any] = {}
if pipeline_id:
params["pipeline_id"] = pipeline_id
if schedule_id:
params["schedule_id"] = schedule_id
if expires_minutes:
params["expires_minutes"] = expires_minutes
response_body = self.get(API_TOKEN, params=params)
if not isinstance(response_body, str):
raise ValueError(
f"Bad API Response. Expected API token, got "
f"{type(response_body)}"
)
return response_body
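    # Sketch: minting a short-lived token scoped to a schedule and using it
    # as a bearer token (the header format shown is typical for REST APIs
    # and is an assumption here):
    #
    #     token = store.get_api_token(
    #         schedule_id=schedule.id, expires_minutes=60
    #     )
    #     headers = {"Authorization": f"Bearer {token}"}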
#################
# Tags
#################
def create_tag(self, tag: TagRequest) -> TagResponse:
"""Creates a new tag.
Args:
tag: the tag to be created.
Returns:
The newly created tag.
"""
return self._create_resource(
resource=tag,
response_model=TagResponse,
route=TAGS,
)
def delete_tag(
self,
tag_name_or_id: Union[str, UUID],
) -> None:
"""Deletes a tag.
Args:
tag_name_or_id: name or id of the tag to delete.
"""
        self._delete_resource(resource_id=tag_name_or_id, route=TAGS)