Zen Stores
zenml.zen_stores
special
ZenStores define ways to store ZenML-relevant data locally or remotely.
base_zen_store
BaseZenStore (ABC)
Base class for accessing data in ZenML Repository and new Service.
Source code in zenml/zen_stores/base_zen_store.py
class BaseZenStore(ABC):
    """Base class for accessing data in ZenML Repository and new Service.

    Concrete stores implement the abstract storage primitives (mostly the
    underscore-prefixed methods). This base class builds the user-facing API
    on top of them and adds analytics tracking, which is only performed when
    the store was initialized with `track_analytics=True`.
    """

    def initialize(
        self,
        url: str,
        skip_default_registrations: bool = False,
        track_analytics: bool = True,
        skip_migration: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> "BaseZenStore":
        """Initialize the store.

        NOTE(review): the base implementation itself does not use `url`,
        `*args` or `**kwargs`; presumably concrete stores override this
        method, consume them and then call into this one.

        Args:
            url: The URL of the store.
            skip_default_registrations: If `True`, the creation of the default
                stack and user will be skipped.
            track_analytics: Only send analytics if set to `True`.
            skip_migration: If `True`, no migration will be performed.
            *args: Additional arguments to pass to the concrete store
                implementation.
            **kwargs: Additional keyword arguments to pass to the concrete
                store implementation.

        Returns:
            The initialized concrete store instance.
        """
        # Must be set first: the default registrations below emit analytics
        # events through `_track_event`, which reads this flag.
        self._track_analytics = track_analytics
        if not skip_default_registrations:
            if self.stacks_empty:
                logger.info("Registering default stack...")
                self.register_default_stack()
            self.create_default_user()

        if not skip_migration:
            self._migrate_store()

        return self

    def _migrate_store(self) -> None:
        """Migrates the store to the latest version."""
        # Older versions of ZenML didn't have users in the zen store, so we
        # create the default user if it doesn't exist.
        self.create_default_user()

    # Static methods:

    @staticmethod
    @abstractmethod
    def get_path_from_url(url: str) -> Optional[Path]:
        """Get the path from a URL, if it points or is backed by a local file.

        Args:
            url: The URL to get the path from.

        Returns:
            The local path backed by the URL, or None if the URL is not backed
            by a local file or directory
        """

    @staticmethod
    @abstractmethod
    def get_local_url(path: str) -> str:
        """Get a local URL for a given local path.

        Args:
            path: the path string to build a URL out of.

        Returns:
            Url pointing to the path for the store type.
        """

    @staticmethod
    @abstractmethod
    def is_valid_url(url: str) -> bool:
        """Check if the given url is valid.

        Args:
            url: The URL to check.

        Returns:
            `True` if the URL is valid, `False` otherwise.
        """

    # Public Interface:

    @property
    @abstractmethod
    def type(self) -> StoreType:
        """The type of zen store.

        Returns:
            The `StoreType` of this store.
        """

    @property
    @abstractmethod
    def url(self) -> str:
        """Get the repository URL.

        Returns:
            The URL of this store.
        """

    @property
    @abstractmethod
    def stacks_empty(self) -> bool:
        """Check if the store is empty (no stacks are configured).

        The implementation of this method should check if the store is empty
        without having to load all the stacks from the persistent storage.

        Returns:
            `True` if no stacks are configured, `False` otherwise.
        """

    @abstractmethod
    def get_stack_configuration(
        self, name: str
    ) -> Dict[StackComponentType, str]:
        """Fetches a stack configuration by name.

        Args:
            name: The name of the stack to fetch.

        Returns:
            Dict[StackComponentType, str] for the requested stack name.

        Raises:
            KeyError: If no stack exists for the given name.
        """

    @property
    @abstractmethod
    def stack_configurations(self) -> Dict[str, Dict[StackComponentType, str]]:
        """Configurations for all stacks registered in this zen store.

        Returns:
            Dictionary mapping stack names to Dict[StackComponentType, str]'s
        """

    # Private interface (must be implemented, not to be called by user):

    @abstractmethod
    def _register_stack_component(
        self,
        component: ComponentWrapper,
    ) -> None:
        """Register a stack component.

        Args:
            component: The component to register.

        Raises:
            StackComponentExistsError: If a stack component with the same type
                and name already exists.
        """

    @abstractmethod
    def _update_stack_component(
        self,
        name: str,
        component_type: StackComponentType,
        component: ComponentWrapper,
    ) -> Dict[str, str]:
        """Update a stack component.

        Args:
            name: The original name of the stack component.
            component_type: The type of the stack component to update.
            component: The new component to update with.

        Returns:
            A string-to-string mapping whose contents are defined by the
            concrete implementation (not specified by this base class).

        Raises:
            KeyError: If no stack component exists with the given name.
        """

    @abstractmethod
    def _deregister_stack(self, name: str) -> None:
        """Delete a stack from storage.

        Args:
            name: The name of the stack to be deleted.

        Raises:
            KeyError: If no stack exists for the given name.
        """

    @abstractmethod
    def _save_stack(
        self,
        name: str,
        stack_configuration: Dict[StackComponentType, str],
    ) -> None:
        """Add a stack to storage.

        Args:
            name: The name to save the stack as.
            stack_configuration: Dict[StackComponentType, str] to persist.
        """

    @abstractmethod
    def _get_component_flavor_and_config(
        self, component_type: StackComponentType, name: str
    ) -> Tuple[str, bytes]:
        """Fetch the flavor and configuration for a stack component.

        Args:
            component_type: The type of the component to fetch.
            name: The name of the component to fetch.

        Returns:
            Pair of (flavor, configuration) for stack component, as string and
            base64-encoded yaml document, respectively

        Raises:
            KeyError: If no stack component exists for the given type and name.
        """

    @abstractmethod
    def _get_stack_component_names(
        self, component_type: StackComponentType
    ) -> List[str]:
        """Get names of all registered stack components of a given type.

        Args:
            component_type: The type of the component to list names for.

        Returns:
            A list of names as strings.
        """

    @abstractmethod
    def _delete_stack_component(
        self, component_type: StackComponentType, name: str
    ) -> None:
        """Remove a StackComponent from storage.

        Args:
            component_type: The type of component to delete.
            name: Then name of the component to delete.

        Raises:
            KeyError: If no component exists for given type and name.
        """

    # User, project and role management

    @property
    @abstractmethod
    def users(self) -> List[User]:
        """All registered users.

        Returns:
            A list of all registered users.
        """

    @abstractmethod
    def _get_user(self, user_name: str) -> User:
        """Get a specific user by name.

        Args:
            user_name: Name of the user to get.

        Returns:
            The requested user, if it was found.

        Raises:
            KeyError: If no user with the given name exists.
        """

    @abstractmethod
    def _create_user(self, user_name: str) -> User:
        """Creates a new user.

        Args:
            user_name: Unique username.

        Returns:
            The newly created user.

        Raises:
            EntityExistsError: If a user with the given name already exists.
        """

    @abstractmethod
    def _delete_user(self, user_name: str) -> None:
        """Deletes a user.

        Args:
            user_name: Name of the user to delete.

        Raises:
            KeyError: If no user with the given name exists.
        """

    @property
    @abstractmethod
    def teams(self) -> List[Team]:
        """All registered teams.

        Returns:
            A list of all registered teams.
        """

    @abstractmethod
    def _create_team(self, team_name: str) -> Team:
        """Creates a new team.

        Args:
            team_name: Unique team name.

        Returns:
            The newly created team.

        Raises:
            EntityExistsError: If a team with the given name already exists.
        """

    @abstractmethod
    def _get_team(self, team_name: str) -> Team:
        """Gets a specific team.

        Args:
            team_name: Name of the team to get.

        Returns:
            The requested team.

        Raises:
            KeyError: If no team with the given name exists.
        """

    @abstractmethod
    def _delete_team(self, team_name: str) -> None:
        """Deletes a team.

        Args:
            team_name: Name of the team to delete.

        Raises:
            KeyError: If no team with the given name exists.
        """

    @abstractmethod
    def add_user_to_team(self, team_name: str, user_name: str) -> None:
        """Adds a user to a team.

        Args:
            team_name: Name of the team.
            user_name: Name of the user.

        Raises:
            KeyError: If the team or the user with the given name does not
                exist.
        """

    @abstractmethod
    def remove_user_from_team(self, team_name: str, user_name: str) -> None:
        """Removes a user from a team.

        Args:
            team_name: Name of the team.
            user_name: Name of the user.

        Raises:
            KeyError: If the team or the user with the given name does not
                exist.
        """

    @property
    @abstractmethod
    def projects(self) -> List[Project]:
        """All registered projects.

        Returns:
            A list of all registered projects.
        """

    @abstractmethod
    def _get_project(self, project_name: str) -> Project:
        """Get an existing project by name.

        Args:
            project_name: Name of the project to get.

        Returns:
            The requested project if one was found.

        Raises:
            KeyError: If there is no such project.
        """

    @abstractmethod
    def _create_project(
        self, project_name: str, description: Optional[str] = None
    ) -> Project:
        """Creates a new project.

        Args:
            project_name: Unique project name.
            description: Optional project description.

        Returns:
            The newly created project.

        Raises:
            EntityExistsError: If a project with the given name already exists.
        """

    @abstractmethod
    def _delete_project(self, project_name: str) -> None:
        """Deletes a project.

        Args:
            project_name: Name of the project to delete.

        Raises:
            KeyError: If no project with the given name exists.
        """

    @property
    @abstractmethod
    def roles(self) -> List[Role]:
        """All registered roles.

        Returns:
            A list of all registered roles.
        """

    @property
    @abstractmethod
    def role_assignments(self) -> List[RoleAssignment]:
        """All registered role assignments.

        Returns:
            A list of all registered role assignments.
        """

    @abstractmethod
    def _get_role(self, role_name: str) -> Role:
        """Gets a specific role.

        Args:
            role_name: Name of the role to get.

        Returns:
            The requested role.

        Raises:
            KeyError: If no role with the given name exists.
        """

    @abstractmethod
    def _create_role(self, role_name: str) -> Role:
        """Creates a new role.

        Args:
            role_name: Unique role name.

        Returns:
            The newly created role.

        Raises:
            EntityExistsError: If a role with the given name already exists.
        """

    @abstractmethod
    def _delete_role(self, role_name: str) -> None:
        """Deletes a role.

        Args:
            role_name: Name of the role to delete.

        Raises:
            KeyError: If no role with the given name exists.
        """

    @abstractmethod
    def assign_role(
        self,
        role_name: str,
        entity_name: str,
        project_name: Optional[str] = None,
        is_user: bool = True,
    ) -> None:
        """Assigns a role to a user or team.

        Args:
            role_name: Name of the role to assign.
            entity_name: User or team name.
            project_name: Optional project name.
            is_user: Boolean indicating whether the given `entity_name` refers
                to a user.

        Raises:
            KeyError: If no role, entity or project with the given names exists.
        """

    @abstractmethod
    def revoke_role(
        self,
        role_name: str,
        entity_name: str,
        project_name: Optional[str] = None,
        is_user: bool = True,
    ) -> None:
        """Revokes a role from a user or team.

        Args:
            role_name: Name of the role to revoke.
            entity_name: User or team name.
            project_name: Optional project name.
            is_user: Boolean indicating whether the given `entity_name` refers
                to a user.

        Raises:
            KeyError: If no role, entity or project with the given names exists.
        """

    @abstractmethod
    def get_users_for_team(self, team_name: str) -> List[User]:
        """Fetches all users of a team.

        Args:
            team_name: Name of the team.

        Returns:
            List of users that are part of the team.

        Raises:
            KeyError: If no team with the given name exists.
        """

    @abstractmethod
    def get_teams_for_user(self, user_name: str) -> List[Team]:
        """Fetches all teams for a user.

        Args:
            user_name: Name of the user.

        Returns:
            List of teams that the user is part of.

        Raises:
            KeyError: If no user with the given name exists.
        """

    @abstractmethod
    def get_role_assignments_for_user(
        self,
        user_name: str,
        project_name: Optional[str] = None,
        include_team_roles: bool = True,
    ) -> List[RoleAssignment]:
        """Fetches all role assignments for a user.

        Args:
            user_name: Name of the user.
            project_name: Optional filter to only return roles assigned for
                this project.
            include_team_roles: If `True`, includes roles for all teams that
                the user is part of.

        Returns:
            List of role assignments for this user.

        Raises:
            KeyError: If no user or project with the given names exists.
        """

    @abstractmethod
    def get_role_assignments_for_team(
        self,
        team_name: str,
        project_name: Optional[str] = None,
    ) -> List[RoleAssignment]:
        """Fetches all role assignments for a team.

        Args:
            team_name: Name of the team.
            project_name: Optional filter to only return roles assigned for
                this project.

        Returns:
            List of role assignments for this team.

        Raises:
            KeyError: If no team or project with the given names exists.
        """

    # Pipelines and pipeline runs

    @abstractmethod
    def get_pipeline_run(
        self,
        pipeline_name: str,
        run_name: str,
        project_name: Optional[str] = None,
    ) -> PipelineRunWrapper:
        """Gets a pipeline run.

        Args:
            pipeline_name: Name of the pipeline for which to get the run.
            run_name: Name of the pipeline run to get.
            project_name: Optional name of the project from which to get the
                pipeline run.

        Returns:
            The requested pipeline run.

        Raises:
            KeyError: If no pipeline run (or project) with the given name
                exists.
        """

    @abstractmethod
    def get_pipeline_runs(
        self, pipeline_name: str, project_name: Optional[str] = None
    ) -> List[PipelineRunWrapper]:
        """Gets pipeline runs.

        Args:
            pipeline_name: Name of the pipeline for which to get runs.
            project_name: Optional name of the project from which to get the
                pipeline runs.

        Returns:
            The pipeline runs of the given pipeline.
        """

    @abstractmethod
    def register_pipeline_run(
        self,
        pipeline_run: PipelineRunWrapper,
    ) -> None:
        """Registers a pipeline run.

        Args:
            pipeline_run: The pipeline run to register.

        Raises:
            EntityExistsError: If a pipeline run with the same name already
                exists.
        """

    # Stack component flavors

    @property
    @abstractmethod
    def flavors(self) -> List[FlavorWrapper]:
        """All registered flavors.

        Returns:
            A list of all registered flavors.
        """

    @abstractmethod
    def _create_flavor(
        self,
        source: str,
        name: str,
        stack_component_type: StackComponentType,
    ) -> FlavorWrapper:
        """Creates a new flavor.

        Args:
            source: the source path to the implemented flavor.
            name: the name of the flavor.
            stack_component_type: the corresponding StackComponentType.

        Returns:
            The newly created flavor.

        Raises:
            EntityExistsError: If a flavor with the given name and type
                already exists.
        """

    @abstractmethod
    def get_flavors_by_type(
        self, component_type: StackComponentType
    ) -> List[FlavorWrapper]:
        """Fetch all flavor defined for a specific stack component type.

        Args:
            component_type: The type of the stack component.

        Returns:
            List of all the flavors for the given stack component type.
        """

    @abstractmethod
    def get_flavor_by_name_and_type(
        self,
        flavor_name: str,
        component_type: StackComponentType,
    ) -> FlavorWrapper:
        """Fetch a flavor by a given name and type.

        Args:
            flavor_name: The name of the flavor.
            component_type: The type of the component.

        Returns:
            Flavor instance if it exists

        Raises:
            KeyError: If no flavor exists with the given name and type
                or there are more than one instances
        """

    # Common code (user facing):

    @property
    def stacks(self) -> List[StackWrapper]:
        """All stacks registered in this zen store.

        Returns:
            One `StackWrapper` per stored stack configuration, with all of
            its components loaded.
        """
        return [
            self._stack_from_dict(name, conf)
            for name, conf in self.stack_configurations.items()
        ]

    def get_stack(self, name: str) -> StackWrapper:
        """Fetch a stack by name.

        Args:
            name: The name of the stack to retrieve.

        Returns:
            StackWrapper instance if the stack exists.

        Raises:
            KeyError: If no stack exists for the given name.
        """
        return self._stack_from_dict(name, self.get_stack_configuration(name))

    def _register_stack(self, stack: StackWrapper) -> None:
        """Register a stack and its components.

        If any of the stack's components aren't registered in the zen store
        yet, this method will try to register them as well.

        Args:
            stack: The stack to register.

        Raises:
            StackExistsError: If a stack with the same name already exists.
            StackComponentExistsError: If a component of the stack wasn't
                registered and a different component with the same name
                already exists.
        """
        try:
            self.get_stack(stack.name)
        except KeyError:
            pass
        else:
            raise StackExistsError(
                f"Unable to register stack with name '{stack.name}': Found "
                f"existing stack with this name."
            )

        def _check_component(
            component: ComponentWrapper,
        ) -> Tuple[StackComponentType, str]:
            """Try to register a stack component, if it doesn't exist.

            Args:
                component: StackComponentWrapper to register.

            Returns:
                The (type, name) pair identifying the component.

            Raises:
                StackComponentExistsError: If a different component (other
                    UUID) with the same type and name exists.
            """
            try:
                existing_component = self.get_stack_component(
                    component_type=component.type, name=component.name
                )
            except KeyError:
                self._register_stack_component(component)
            else:
                # The same name is only acceptable if it is the very same
                # component (identified by UUID).
                if existing_component.uuid != component.uuid:
                    raise StackComponentExistsError(
                        f"Unable to register one of the stacks components: "
                        f"A component of type '{component.type}' and name "
                        f"'{component.name}' already exists."
                    )
            return component.type, component.name

        stack_configuration = dict(
            _check_component(component) for component in stack.components
        )
        self._save_stack(stack.name, stack_configuration)
        logger.info("Registered stack with name '%s'.", stack.name)

    def _update_stack(self, name: str, stack: StackWrapper) -> None:
        """Update a stack and its components.

        If any of the stack's components aren't registered in the stack store
        yet, this method will try to register them as well. If the stack is
        renamed, the stack stored under the old name is deregistered after
        the new one has been saved.

        Args:
            name: The original name of the stack.
            stack: The new stack to use in the update.

        Raises:
            KeyError: If no stack exists with the given name.
            StackExistsError: If the stack is renamed and a stack with the
                new name already exists.
        """
        try:
            self.get_stack(name)
        except KeyError:
            raise KeyError(
                f"Unable to update stack with name '{name}': No existing "
                f"stack found with this name."
            ) from None

        if name != stack.name:
            try:
                self.get_stack(stack.name)
            except KeyError:
                pass
            else:
                raise StackExistsError(
                    f"Unable to update stack with name '{stack.name}': Found "
                    f"existing stack with this name."
                )

        def _ensure_component(
            component: ComponentWrapper,
        ) -> Tuple[StackComponentType, str]:
            """Register `component` if missing; return its (type, name)."""
            try:
                self.get_stack_component(
                    component_type=component.type, name=component.name
                )
            except KeyError:
                self._register_stack_component(component)
            return component.type, component.name

        stack_configuration = dict(
            _ensure_component(component) for component in stack.components
        )
        self._save_stack(stack.name, stack_configuration)
        logger.info("Updated stack with name '%s'.", name)
        if name != stack.name:
            self.deregister_stack(name)

    def get_stack_component(
        self, component_type: StackComponentType, name: str
    ) -> ComponentWrapper:
        """Get a registered stack component.

        Args:
            component_type: The type of the component to fetch.
            name: The name of the component to fetch.

        Returns:
            The component with its flavor, UUID and (still base64-encoded)
            configuration.

        Raises:
            KeyError: If no component with the requested type and name exists.
        """
        flavor, config = self._get_component_flavor_and_config(
            component_type, name=name
        )
        # `config` is a base64-encoded YAML document; only the UUID is
        # extracted here, the raw config is passed through unchanged.
        uuid = yaml.safe_load(base64.b64decode(config).decode())["uuid"]
        return ComponentWrapper(
            type=component_type,
            flavor=flavor,
            name=name,
            uuid=uuid,
            config=config,
        )

    def get_stack_components(
        self, component_type: StackComponentType
    ) -> List[ComponentWrapper]:
        """Fetches all registered stack components of the given type.

        Args:
            component_type: StackComponentType to list members of

        Returns:
            A list of `ComponentWrapper` instances.
        """
        return [
            self.get_stack_component(component_type=component_type, name=name)
            for name in self._get_stack_component_names(component_type)
        ]

    def deregister_stack_component(
        self, component_type: StackComponentType, name: str
    ) -> None:
        """Deregisters a stack component.

        Args:
            component_type: The type of the component to deregister.
            name: The name of the component to deregister.

        Raises:
            ValueError: if trying to deregister a component that's part
                of a stack.
        """
        for stack_name, stack_config in self.stack_configurations.items():
            if stack_config.get(component_type) == name:
                raise ValueError(
                    f"Unable to deregister stack component (type: "
                    f"{component_type}, name: {name}) that is part of a "
                    f"registered stack (stack name: '{stack_name}')."
                )
        self._delete_stack_component(component_type, name=name)

    def register_default_stack(self) -> None:
        """Populates the store with the default Stack.

        The default stack contains a local orchestrator,
        a local artifact store and a local SQLite metadata store.
        """
        stack = Stack.default_local_stack()
        sw = StackWrapper.from_stack(stack)
        self._register_stack(sw)
        metadata = {c.type.value: c.flavor for c in sw.components}
        metadata["store_type"] = self.type.value
        self._track_event(
            AnalyticsEvent.REGISTERED_DEFAULT_STACK, metadata=metadata
        )

    def create_default_user(self) -> None:
        """Creates the default user if it does not exist yet."""
        try:
            self.get_user(user_name=DEFAULT_USERNAME)
        except KeyError:
            # Use private interface and send custom tracking event
            self._track_event(AnalyticsEvent.CREATED_DEFAULT_USER)
            self._create_user(user_name=DEFAULT_USERNAME)

    # Common code (internal implementations, private):

    def _track_event(
        self,
        event: Union[str, AnalyticsEvent],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Track an analytics event if analytics are enabled for this store.

        Args:
            event: The event to track.
            metadata: Optional metadata to attach to the event.

        Returns:
            The result of `track_event`, or `False` if analytics tracking is
            disabled for this store.
        """
        if self._track_analytics:
            return track_event(event, metadata)
        return False

    def _stack_from_dict(
        self, name: str, stack_configuration: Dict[StackComponentType, str]
    ) -> StackWrapper:
        """Build a StackWrapper from a stored stack configuration.

        Args:
            name: The name of the stack.
            stack_configuration: Mapping of component type to component name.

        Returns:
            A `StackWrapper` whose components are fetched from this store.

        Raises:
            KeyError: If one of the referenced components does not exist.
        """
        stack_components = [
            self.get_stack_component(
                component_type=component_type, name=component_name
            )
            for component_type, component_name in stack_configuration.items()
        ]
        return StackWrapper(name=name, components=stack_components)

    # Public facing APIs
    # TODO [ENG-894]: Refactor these with the proxy pattern, as noted in
    #  the [review comment](https://github.com/zenml-io/zenml/pull/589#discussion_r875003334)

    def register_stack_component(
        self,
        component: ComponentWrapper,
    ) -> None:
        """Register a stack component and track the registration.

        Args:
            component: The component to register.

        Raises:
            StackComponentExistsError: If a stack component with the same type
                and name already exists.
        """
        analytics_metadata = {
            "type": component.type.value,
            "flavor": component.flavor,
        }
        self._track_event(
            AnalyticsEvent.REGISTERED_STACK_COMPONENT,
            metadata=analytics_metadata,
        )
        return self._register_stack_component(component)

    def update_stack_component(
        self,
        name: str,
        component_type: StackComponentType,
        component: ComponentWrapper,
    ) -> Dict[str, str]:
        """Update a stack component and track the update.

        Args:
            name: The original name of the stack component.
            component_type: The type of the stack component to update.
            component: The new component to update with.

        Returns:
            Whatever the concrete `_update_stack_component` returns.

        Raises:
            KeyError: If no stack component exists with the given name.
        """
        analytics_metadata = {
            "type": component.type.value,
            "flavor": component.flavor,
        }
        self._track_event(
            AnalyticsEvent.UPDATED_STACK_COMPONENT,
            metadata=analytics_metadata,
        )
        return self._update_stack_component(name, component_type, component)

    def deregister_stack(self, name: str) -> None:
        """Delete a stack from storage.

        Args:
            name: The name of the stack to be deleted.

        Raises:
            KeyError: If no stack exists for the given name.
        """
        # No tracking events, here for consistency
        return self._deregister_stack(name)

    def create_user(self, user_name: str) -> User:
        """Creates a new user.

        Args:
            user_name: Unique username.

        Returns:
            The newly created user.

        Raises:
            EntityExistsError: If a user with the given name already exists.
        """
        self._track_event(AnalyticsEvent.CREATED_USER)
        return self._create_user(user_name)

    def delete_user(self, user_name: str) -> None:
        """Deletes a user.

        Args:
            user_name: Name of the user to delete.

        Raises:
            KeyError: If no user with the given name exists.
        """
        self._track_event(AnalyticsEvent.DELETED_USER)
        return self._delete_user(user_name)

    def get_user(self, user_name: str) -> User:
        """Gets a specific user.

        Args:
            user_name: Name of the user to get.

        Returns:
            The requested user.

        Raises:
            KeyError: If no user with the given name exists.
        """
        # No tracking events, here for consistency
        return self._get_user(user_name)

    def create_team(self, team_name: str) -> Team:
        """Creates a new team.

        Args:
            team_name: Unique team name.

        Returns:
            The newly created team.

        Raises:
            EntityExistsError: If a team with the given name already exists.
        """
        self._track_event(AnalyticsEvent.CREATED_TEAM)
        return self._create_team(team_name)

    def get_team(self, team_name: str) -> Team:
        """Gets a specific team.

        Args:
            team_name: Name of the team to get.

        Returns:
            The requested team.

        Raises:
            KeyError: If no team with the given name exists.
        """
        # No tracking events, here for consistency
        return self._get_team(team_name)

    def delete_team(self, team_name: str) -> None:
        """Deletes a team.

        Args:
            team_name: Name of the team to delete.

        Raises:
            KeyError: If no team with the given name exists.
        """
        self._track_event(AnalyticsEvent.DELETED_TEAM)
        return self._delete_team(team_name)

    def get_project(self, project_name: str) -> Project:
        """Gets a specific project.

        Args:
            project_name: Name of the project to get.

        Returns:
            The requested project.

        Raises:
            KeyError: If no project with the given name exists.
        """
        # No tracking events, here for consistency
        return self._get_project(project_name)

    def create_project(
        self, project_name: str, description: Optional[str] = None
    ) -> Project:
        """Creates a new project.

        Args:
            project_name: Unique project name.
            description: Optional project description.

        Returns:
            The newly created project.

        Raises:
            EntityExistsError: If a project with the given name already exists.
        """
        self._track_event(AnalyticsEvent.CREATED_PROJECT)
        return self._create_project(project_name, description)

    def delete_project(self, project_name: str) -> None:
        """Deletes a project.

        Args:
            project_name: Name of the project to delete.

        Raises:
            KeyError: If no project with the given name exists.
        """
        self._track_event(AnalyticsEvent.DELETED_PROJECT)
        return self._delete_project(project_name)

    def get_role(self, role_name: str) -> Role:
        """Gets a specific role.

        Args:
            role_name: Name of the role to get.

        Returns:
            The requested role.

        Raises:
            KeyError: If no role with the given name exists.
        """
        # No tracking events, here for consistency
        return self._get_role(role_name)

    def create_role(self, role_name: str) -> Role:
        """Creates a new role.

        Args:
            role_name: Unique role name.

        Returns:
            The newly created role.

        Raises:
            EntityExistsError: If a role with the given name already exists.
        """
        self._track_event(AnalyticsEvent.CREATED_ROLE)
        return self._create_role(role_name)

    def delete_role(self, role_name: str) -> None:
        """Deletes a role.

        Args:
            role_name: Name of the role to delete.

        Raises:
            KeyError: If no role with the given name exists.
        """
        self._track_event(AnalyticsEvent.DELETED_ROLE)
        return self._delete_role(role_name)

    def create_flavor(
        self,
        source: str,
        name: str,
        stack_component_type: StackComponentType,
    ) -> FlavorWrapper:
        """Creates a new flavor.

        Args:
            source: the source path to the implemented flavor.
            name: the name of the flavor.
            stack_component_type: the corresponding StackComponentType.

        Returns:
            The newly created flavor.

        Raises:
            EntityExistsError: If a flavor with the given name and type
                already exists.
        """
        analytics_metadata = {
            "type": stack_component_type.value,
        }
        # Go through `_track_event` so the store's analytics opt-out
        # (`track_analytics=False`) is honoured.
        self._track_event(
            AnalyticsEvent.CREATED_FLAVOR,
            metadata=analytics_metadata,
        )
        return self._create_flavor(source, name, stack_component_type)

    def register_stack(self, stack: StackWrapper) -> None:
        """Register a stack and its components.

        If any of the stack's components aren't registered in the zen store
        yet, this method will try to register them as well.

        Args:
            stack: The stack to register.

        Raises:
            StackExistsError: If a stack with the same name already exists.
            StackComponentExistsError: If a component of the stack wasn't
                registered and a different component with the same name
                already exists.
        """
        metadata = {c.type.value: c.flavor for c in stack.components}
        metadata["store_type"] = self.type.value
        # Go through `_track_event` so the store's analytics opt-out is
        # honoured.
        self._track_event(AnalyticsEvent.REGISTERED_STACK, metadata=metadata)
        return self._register_stack(stack)

    def update_stack(self, name: str, stack: StackWrapper) -> None:
        """Update a stack and its components.

        If any of the stack's components aren't registered in the stack store
        yet, this method will try to register them as well.

        Args:
            name: The original name of the stack.
            stack: The new stack to use in the update.

        Raises:
            KeyError: If no stack exists with the given name.
            StackExistsError: If the stack is renamed and a stack with the
                new name already exists.
        """
        metadata = {c.type.value: c.flavor for c in stack.components}
        metadata["store_type"] = self.type.value
        # Go through `_track_event` so the store's analytics opt-out is
        # honoured.
        self._track_event(AnalyticsEvent.UPDATED_STACK, metadata=metadata)
        return self._update_stack(name, stack)
flavors: List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]
property
readonly
All registered flavors.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
A list of all registered flavors. |
projects: List[zenml.zen_stores.models.user_management_models.Project]
property
readonly
All registered projects.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Project] |
A list of all registered projects. |
role_assignments: List[zenml.zen_stores.models.user_management_models.RoleAssignment]
property
readonly
All registered role assignments.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
A list of all registered role assignments. |
roles: List[zenml.zen_stores.models.user_management_models.Role]
property
readonly
All registered roles.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Role] |
A list of all registered roles. |
stack_configurations: Dict[str, Dict[zenml.enums.StackComponentType, str]]
property
readonly
Configurations for all stacks registered in this zen store.
Returns:
Type | Description |
---|---|
Dict[str, Dict[zenml.enums.StackComponentType, str]] |
Dictionary mapping stack names to Dict[StackComponentType, str]'s |
stacks: List[zenml.zen_stores.models.stack_wrapper.StackWrapper]
property
readonly
All stacks registered in this zen store.
stacks_empty: bool
property
readonly
Check if the store is empty (no stacks are configured).
The implementation of this method should check if the store is empty without having to load all the stacks from the persistent storage.
teams: List[zenml.zen_stores.models.user_management_models.Team]
property
readonly
All registered teams.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Team] |
A list of all registered teams. |
type: StoreType
property
readonly
The type of zen store.
url: str
property
readonly
Get the repository URL.
users: List[zenml.zen_stores.models.user_management_models.User]
property
readonly
All registered users.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.User] |
A list of all registered users. |
add_user_to_team(self, team_name, user_name)
Adds a user to a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
user_name |
str |
Name of the user. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user and team with the given names exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def add_user_to_team(self, team_name: str, user_name: str) -> None:
    """Make an existing user a member of an existing team.

    Abstract: concrete zen stores provide the implementation.

    Args:
        team_name: Name of the team to add the user to.
        user_name: Name of the user to add.

    Raises:
        KeyError: If the team or the user with the given name does not exist.
    """
assign_role(self, role_name, entity_name, project_name=None, is_user=True)
Assigns a role to a user or team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to assign. |
required |
entity_name |
str |
User or team name. |
required |
project_name |
Optional[str] |
Optional project name. |
None |
is_user |
bool |
Boolean indicating whether the given `entity_name` refers to a user. |
True |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role, entity or project with the given names exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def assign_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Grant a role to a user or a team, optionally scoped to a project.

    Abstract: concrete zen stores provide the implementation.

    Args:
        role_name: Name of the role to grant.
        entity_name: Name of the user or team receiving the role.
        project_name: Optional name of the project the assignment applies to.
        is_user: `True` if `entity_name` names a user, `False` if it names a
            team.

    Raises:
        KeyError: If no role, entity or project with the given names exists.
    """
create_default_user(self)
Creates a default user.
Source code in zenml/zen_stores/base_zen_store.py
def create_default_user(self) -> None:
    """Ensure the default user exists, creating it only when it is missing.

    The private `_create_user` is called instead of the public `create_user`,
    so that the dedicated `CREATED_DEFAULT_USER` analytics event is tracked
    instead of the generic user-creation event.
    """
    try:
        self.get_user(user_name=DEFAULT_USERNAME)
    except KeyError:
        # Default user not found: record the dedicated event, then create it.
        self._track_event(AnalyticsEvent.CREATED_DEFAULT_USER)
        self._create_user(user_name=DEFAULT_USERNAME)
create_flavor(self, source, name, stack_component_type)
Creates a new flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
the source path to the implemented flavor. |
required |
name |
str |
the name of the flavor. |
required |
stack_component_type |
StackComponentType |
the corresponding StackComponentType. |
required |
Returns:
Type | Description |
---|---|
FlavorWrapper |
The newly created flavor. |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a flavor with the given name and type already exists. |
Source code in zenml/zen_stores/base_zen_store.py
def create_flavor(
    self,
    source: str,
    name: str,
    stack_component_type: StackComponentType,
) -> FlavorWrapper:
    """Creates a new flavor.

    Args:
        source: the source path to the implemented flavor.
        name: the name of the flavor.
        stack_component_type: the corresponding StackComponentType.

    Returns:
        The newly created flavor.

    Raises:
        EntityExistsError: If a flavor with the given name and type
            already exists.
    """
    # Track through the store's own `_track_event` hook, as every other
    # mutating method of this class does, instead of the module-level
    # `track_event`. The hook is presumably where the `track_analytics`
    # flag stored by `initialize` is applied.
    self._track_event(
        AnalyticsEvent.CREATED_FLAVOR,
        metadata={"type": stack_component_type.value},
    )
    return self._create_flavor(source, name, stack_component_type)
create_project(self, project_name, description=None)
Creates a new project.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
project_name |
str |
Unique project name. |
required |
description |
Optional[str] |
Optional project description. |
None |
Returns:
Type | Description |
---|---|
Project |
The newly created project. |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a project with the given name already exists. |
Source code in zenml/zen_stores/base_zen_store.py
def create_project(
    self, project_name: str, description: Optional[str] = None
) -> Project:
    """Create a new project and record an analytics event for it.

    Args:
        project_name: Name of the project; must be unique.
        description: Optional free-text description of the project.

    Returns:
        The project that was just created.

    Raises:
        EntityExistsError: If a project with the given name already exists.
    """
    self._track_event(AnalyticsEvent.CREATED_PROJECT)
    new_project = self._create_project(project_name, description)
    return new_project
create_role(self, role_name)
Creates a new role.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Unique role name. |
required |
Returns:
Type | Description |
---|---|
Role |
The newly created role. |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a role with the given name already exists. |
Source code in zenml/zen_stores/base_zen_store.py
def create_role(self, role_name: str) -> Role:
    """Create a new role and record an analytics event for it.

    Args:
        role_name: Name of the role; must be unique.

    Returns:
        The role that was just created.

    Raises:
        EntityExistsError: If a role with the given name already exists.
    """
    self._track_event(AnalyticsEvent.CREATED_ROLE)
    new_role = self._create_role(role_name)
    return new_role
create_team(self, team_name)
Creates a new team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Unique team name. |
required |
Returns:
Type | Description |
---|---|
Team |
The newly created team. |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a team with the given name already exists. |
Source code in zenml/zen_stores/base_zen_store.py
def create_team(self, team_name: str) -> Team:
    """Create a new team and record an analytics event for it.

    Args:
        team_name: Name of the team; must be unique.

    Returns:
        The team that was just created.

    Raises:
        EntityExistsError: If a team with the given name already exists.
    """
    self._track_event(AnalyticsEvent.CREATED_TEAM)
    new_team = self._create_team(team_name)
    return new_team
create_user(self, user_name)
Creates a new user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Unique username. |
required |
Returns:
Type | Description |
---|---|
User |
The newly created user. |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a user with the given name already exists. |
Source code in zenml/zen_stores/base_zen_store.py
def create_user(self, user_name: str) -> User:
    """Create a new user and record an analytics event for it.

    Args:
        user_name: Name of the user; must be unique.

    Returns:
        The user that was just created.

    Raises:
        EntityExistsError: If a user with the given name already exists.
    """
    self._track_event(AnalyticsEvent.CREATED_USER)
    new_user = self._create_user(user_name)
    return new_user
delete_project(self, project_name)
Deletes a project.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
project_name |
str |
Name of the project to delete. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no project with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def delete_project(self, project_name: str) -> None:
    """Remove a project after recording an analytics event.

    Args:
        project_name: Name of the project to remove.

    Raises:
        KeyError: If no project with the given name exists.
    """
    self._track_event(AnalyticsEvent.DELETED_PROJECT)
    outcome = self._delete_project(project_name)
    return outcome
delete_role(self, role_name)
Deletes a role.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to delete. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def delete_role(self, role_name: str) -> None:
    """Remove a role after recording an analytics event.

    Args:
        role_name: Name of the role to remove.

    Raises:
        KeyError: If no role with the given name exists.
    """
    self._track_event(AnalyticsEvent.DELETED_ROLE)
    outcome = self._delete_role(role_name)
    return outcome
delete_team(self, team_name)
Deletes a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team to delete. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def delete_team(self, team_name: str) -> None:
    """Remove a team after recording an analytics event.

    Args:
        team_name: Name of the team to remove.

    Raises:
        KeyError: If no team with the given name exists.
    """
    self._track_event(AnalyticsEvent.DELETED_TEAM)
    outcome = self._delete_team(team_name)
    return outcome
delete_user(self, user_name)
Deletes a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user to delete. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def delete_user(self, user_name: str) -> None:
    """Remove a user after recording an analytics event.

    Args:
        user_name: Name of the user to remove.

    Raises:
        KeyError: If no user with the given name exists.
    """
    self._track_event(AnalyticsEvent.DELETED_USER)
    outcome = self._delete_user(user_name)
    return outcome
deregister_stack(self, name)
Delete a stack from storage.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the stack to be deleted. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no stack exists for the given name. |
Source code in zenml/zen_stores/base_zen_store.py
def deregister_stack(self, name: str) -> None:
    """Remove the stack called `name` from storage.

    Args:
        name: Name of the stack to remove.

    Raises:
        KeyError: If no stack exists for the given name.
    """
    # Intentionally sends no analytics event; the public wrapper only
    # exists to mirror the public/private split of the other methods.
    outcome = self._deregister_stack(name)
    return outcome
deregister_stack_component(self, component_type, name)
Deregisters a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
The type of the component to deregister. |
required |
name |
str |
The name of the component to deregister. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if trying to deregister a component that's part of a stack. |
Source code in zenml/zen_stores/base_zen_store.py
def deregister_stack_component(
    self, component_type: StackComponentType, name: str
) -> None:
    """Remove a stack component, unless a registered stack still uses it.

    Args:
        component_type: Type of the component to remove.
        name: Name of the component to remove.

    Raises:
        ValueError: if the component is referenced by any registered stack.
    """
    referencing_stacks = [
        stack_name
        for stack_name, config in self.stack_configurations.items()
        if config.get(component_type) == name
    ]
    if referencing_stacks:
        # Report the first stack (in iteration order) that uses it.
        raise ValueError(
            f"Unable to deregister stack component (type: "
            f"{component_type}, name: {name}) that is part of a "
            f"registered stack (stack name: '{referencing_stacks[0]}')."
        )
    self._delete_stack_component(component_type, name=name)
get_flavor_by_name_and_type(self, flavor_name, component_type)
Fetch a flavor by a given name and type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_name |
str |
The name of the flavor. |
required |
component_type |
StackComponentType |
The type of the component. |
required |
Returns:
Type | Description |
---|---|
FlavorWrapper |
Flavor instance if it exists |
Exceptions:
Type | Description |
---|---|
KeyError |
If no flavor exists with the given name and type, or if more than one matching flavor exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_flavor_by_name_and_type(
    self,
    flavor_name: str,
    component_type: StackComponentType,
) -> FlavorWrapper:
    """Fetch a flavor by a given name and type.

    Args:
        flavor_name: The name of the flavor.
        component_type: The type of the stack component the flavor
            belongs to.

    Returns:
        Flavor instance if it exists.

    Raises:
        KeyError: If no flavor exists with the given name and type, or if
            more than one matching flavor exists.
    """
get_flavors_by_type(self, component_type)
Fetch all flavors defined for a specific stack component type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
The type of the stack component. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of all the flavors for the given stack component type. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_flavors_by_type(
    self, component_type: StackComponentType
) -> List[FlavorWrapper]:
    """List every flavor registered for one stack component type.

    Args:
        component_type: The stack component type whose flavors to list.

    Returns:
        All flavors defined for `component_type`.
    """
get_local_url(path)
staticmethod
Get a local URL for a given local path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
the path string to build a URL out of. |
required |
Returns:
Type | Description |
---|---|
str |
Url pointing to the path for the store type. |
Source code in zenml/zen_stores/base_zen_store.py
@staticmethod
@abstractmethod
def get_local_url(path: str) -> str:
    """Turn a local filesystem path into a URL understood by this store type.

    Args:
        path: The local path to convert.

    Returns:
        A URL for the store type that points at `path`.
    """
get_path_from_url(url)
staticmethod
Get the path from a URL, if it points or is backed by a local file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
The URL to get the path from. |
required |
Returns:
Type | Description |
---|---|
Optional[pathlib.Path] |
The local path backed by the URL, or None if the URL is not backed by a local file or directory |
Source code in zenml/zen_stores/base_zen_store.py
@staticmethod
@abstractmethod
def get_path_from_url(url: str) -> Optional[Path]:
    """Resolve a store URL to a local path, when it is backed by one.

    Args:
        url: The store URL to resolve.

    Returns:
        The local file or directory backing `url`, or `None` if the URL is
        not backed by anything local.
    """
get_pipeline_run(self, pipeline_name, run_name, project_name=None)
Gets a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to get the run. |
required |
run_name |
str |
Name of the pipeline run to get. |
required |
project_name |
Optional[str] |
Optional name of the project from which to get the pipeline run. |
None |
Exceptions:
Type | Description |
---|---|
KeyError |
If no pipeline run (or project) with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_pipeline_run(
    self,
    pipeline_name: str,
    run_name: str,
    project_name: Optional[str] = None,
) -> PipelineRunWrapper:
    """Look up one run of a pipeline.

    Args:
        pipeline_name: Name of the pipeline the run belongs to.
        run_name: Name of the run to look up.
        project_name: Optional name of the project to look the run up in.

    Returns:
        The matching pipeline run.

    Raises:
        KeyError: If no pipeline run (or project) with the given name
            exists.
    """
get_pipeline_runs(self, pipeline_name, project_name=None)
Gets pipeline runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to get runs. |
required |
project_name |
Optional[str] |
Optional name of the project from which to get the pipeline runs. |
None |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_pipeline_runs(
    self, pipeline_name: str, project_name: Optional[str] = None
) -> List[PipelineRunWrapper]:
    """List the runs of a pipeline.

    Args:
        pipeline_name: Name of the pipeline whose runs to list.
        project_name: Optional name of the project to list the runs from.

    Returns:
        The pipeline runs that were found.
    """
get_project(self, project_name)
Gets a specific project.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
project_name |
str |
Name of the project to get. |
required |
Returns:
Type | Description |
---|---|
Project |
The requested project. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no project with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def get_project(self, project_name: str) -> Project:
    """Look up a single project by name.

    Args:
        project_name: Name of the project to look up.

    Returns:
        The matching project.

    Raises:
        KeyError: If no project with the given name exists.
    """
    # Reads send no analytics; the wrapper only delegates.
    fetch = self._get_project
    return fetch(project_name)
get_role(self, role_name)
Gets a specific role.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to get. |
required |
Returns:
Type | Description |
---|---|
Role |
The requested role. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def get_role(self, role_name: str) -> Role:
    """Look up a single role by name.

    Args:
        role_name: Name of the role to look up.

    Returns:
        The matching role.

    Raises:
        KeyError: If no role with the given name exists.
    """
    # Reads send no analytics; the wrapper only delegates.
    fetch = self._get_role
    return fetch(role_name)
get_role_assignments_for_team(self, team_name, project_name=None)
Fetches all role assignments for a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
project_name |
Optional[str] |
Optional filter to only return roles assigned for this project. |
None |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
List of role assignments for this team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team or project with the given names exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_role_assignments_for_team(
    self,
    team_name: str,
    project_name: Optional[str] = None,
) -> List[RoleAssignment]:
    """Fetches all role assignments for a team.

    Args:
        team_name: Name of the team.
        project_name: Optional filter to only return roles assigned for
            this project.

    Returns:
        List of role assignments for this team.

    Raises:
        KeyError: If no team or project with the given names exists.
    """
get_role_assignments_for_user(self, user_name, project_name=None, include_team_roles=True)
Fetches all role assignments for a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user. |
required |
project_name |
Optional[str] |
Optional filter to only return roles assigned for this project. |
None |
include_team_roles |
bool |
If `True`, includes roles for all teams that the user is part of. |
True |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
List of role assignments for this user. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user or project with the given names exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_role_assignments_for_user(
    self,
    user_name: str,
    project_name: Optional[str] = None,
    include_team_roles: bool = True,
) -> List[RoleAssignment]:
    """List the role assignments that apply to one user.

    Args:
        user_name: Name of the user.
        project_name: If given, only assignments for this project are
            returned.
        include_team_roles: When `True`, roles of every team the user is a
            member of are included as well.

    Returns:
        The role assignments of this user.

    Raises:
        KeyError: If no user or project with the given names exists.
    """
get_stack(self, name)
Fetch a stack by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the stack to retrieve. |
required |
Returns:
Type | Description |
---|---|
StackWrapper |
StackWrapper instance if the stack exists. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no stack exists for the given name. |
Source code in zenml/zen_stores/base_zen_store.py
def get_stack(self, name: str) -> StackWrapper:
    """Build a StackWrapper for the stack registered under `name`.

    Args:
        name: Name of the stack to load.

    Returns:
        The stack, rebuilt from its stored configuration.

    Raises:
        KeyError: If no stack exists for the given name.
    """
    configuration = self.get_stack_configuration(name)
    return self._stack_from_dict(name, configuration)
get_stack_component(self, component_type, name)
Get a registered stack component.
Exceptions:
Type | Description |
---|---|
KeyError |
If no component with the requested type and name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def get_stack_component(
    self, component_type: StackComponentType, name: str
) -> ComponentWrapper:
    """Get a registered stack component.

    Args:
        component_type: Type of the component to look up.
        name: Name of the component to look up.

    Returns:
        A ComponentWrapper with the component's type, flavor, name, UUID and
        its configuration in the stored (still encoded) form.

    Raises:
        KeyError: If no component with the requested type and name exists.
    """
    flavor, encoded_config = self._get_component_flavor_and_config(
        component_type, name=name
    )
    # The stored config is base64-encoded YAML; decode it only to read the
    # component UUID, and keep the encoded form in the wrapper.
    decoded_config = yaml.safe_load(base64.b64decode(encoded_config).decode())
    return ComponentWrapper(
        type=component_type,
        flavor=flavor,
        name=name,
        uuid=decoded_config["uuid"],
        config=encoded_config,
    )
get_stack_components(self, component_type)
Fetches all registered stack components of the given type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
StackComponentType to list members of |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.component_wrapper.ComponentWrapper] |
A list of ComponentWrapper instances. |
Source code in zenml/zen_stores/base_zen_store.py
def get_stack_components(
    self, component_type: StackComponentType
) -> List[ComponentWrapper]:
    """Fetches all registered stack components of the given type.

    Args:
        component_type: StackComponentType to list members of.

    Returns:
        A list of ComponentWrapper instances, one per registered component
        of `component_type`.
    """
    return [
        self.get_stack_component(component_type=component_type, name=name)
        for name in self._get_stack_component_names(component_type)
    ]
get_stack_configuration(self, name)
Fetches a stack configuration by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the stack to fetch. |
required |
Returns:
Type | Description |
---|---|
Dict[zenml.enums.StackComponentType, str] |
Dict[StackComponentType, str] for the requested stack name. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no stack exists for the given name. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_stack_configuration(self, name: str) -> Dict[StackComponentType, str]:
    """Return the stored configuration of one stack.

    Args:
        name: Name of the stack.

    Returns:
        Mapping of StackComponentType to component name for that stack.

    Raises:
        KeyError: If no stack exists for the given name.
    """
get_team(self, team_name)
Gets a specific team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team to get. |
required |
Returns:
Type | Description |
---|---|
Team |
The requested team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def get_team(self, team_name: str) -> Team:
    """Look up a single team by name.

    Args:
        team_name: Name of the team to look up.

    Returns:
        The matching team.

    Raises:
        KeyError: If no team with the given name exists.
    """
    # Reads send no analytics; the wrapper only delegates.
    fetch = self._get_team
    return fetch(team_name)
get_teams_for_user(self, user_name)
Fetches all teams for a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Team] |
List of teams that the user is part of. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_teams_for_user(self, user_name: str) -> List[Team]:
    """List the teams a user belongs to.

    Args:
        user_name: Name of the user.

    Returns:
        Every team that has the user as a member.

    Raises:
        KeyError: If no user with the given name exists.
    """
get_user(self, user_name)
Gets a specific user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user to get. |
required |
Returns:
Type | Description |
---|---|
User |
The requested user. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
def get_user(self, user_name: str) -> User:
    """Look up a single user by name.

    Args:
        user_name: Name of the user to look up.

    Returns:
        The matching user.

    Raises:
        KeyError: If no user with the given name exists.
    """
    # Reads send no analytics; the wrapper only delegates.
    fetch = self._get_user
    return fetch(user_name)
get_users_for_team(self, team_name)
Fetches all users of a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.User] |
List of users that are part of the team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team with the given name exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def get_users_for_team(self, team_name: str) -> List[User]:
    """List the members of a team.

    Args:
        team_name: Name of the team.

    Returns:
        Every user that is a member of the team.

    Raises:
        KeyError: If no team with the given name exists.
    """
initialize(self, url, skip_default_registrations=False, track_analytics=True, skip_migration=False, *args, **kwargs)
Initialize the store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
The URL of the store. |
required |
skip_default_registrations |
bool |
If `True`, the creation of the default stack and user will be skipped. |
False |
track_analytics |
bool |
Only send analytics if set to `True`. |
True |
skip_migration |
bool |
If `True`, no migration will be performed. |
False |
*args |
Any |
Additional arguments to pass to the concrete store implementation. |
() |
**kwargs |
Any |
Additional keyword arguments to pass to the concrete store implementation. |
{} |
Returns:
Type | Description |
---|---|
BaseZenStore |
The initialized concrete store instance. |
Source code in zenml/zen_stores/base_zen_store.py
def initialize(
    self,
    url: str,
    skip_default_registrations: bool = False,
    track_analytics: bool = True,
    skip_migration: bool = False,
    *args: Any,
    **kwargs: Any,
) -> "BaseZenStore":
    """Run the store-independent part of store initialization.

    Args:
        url: The URL of the store.
        skip_default_registrations: When `True`, neither the default stack
            nor the default user is created.
        track_analytics: Analytics are only sent when this is `True`.
        skip_migration: When `True`, `_migrate_store` is not run.
        *args: Additional arguments for the concrete store implementation.
        **kwargs: Additional keyword arguments for the concrete store
            implementation.

    Returns:
        This store instance.
    """
    self._track_analytics = track_analytics

    register_defaults = not skip_default_registrations
    # The default stack is only added to a store that has no stacks yet.
    if register_defaults and self.stacks_empty:
        logger.info("Registering default stack...")
        self.register_default_stack()
    if register_defaults:
        self.create_default_user()

    if skip_migration:
        return self
    self._migrate_store()
    return self
is_valid_url(url)
staticmethod
Check if the given url is valid.
Source code in zenml/zen_stores/base_zen_store.py
@staticmethod
@abstractmethod
def is_valid_url(url: str) -> bool:
    """Tell whether `url` is acceptable for this store type.

    Args:
        url: The URL to check.

    Returns:
        Whether the URL is valid for the store type.
    """
register_default_stack(self)
Populates the store with the default Stack.
The default stack contains a local orchestrator, a local artifact store and a local SQLite metadata store.
Source code in zenml/zen_stores/base_zen_store.py
def register_default_stack(self) -> None:
    """Register the default local stack in this store.

    The default stack contains a local orchestrator, a local artifact store
    and a local SQLite metadata store.
    """
    default_stack = StackWrapper.from_stack(Stack.default_local_stack())
    self._register_stack(default_stack)

    # Analytics payload: component type -> flavor, plus the store type.
    tracking_metadata = {
        component.type.value: component.flavor
        for component in default_stack.components
    }
    tracking_metadata["store_type"] = self.type.value
    self._track_event(
        AnalyticsEvent.REGISTERED_DEFAULT_STACK, metadata=tracking_metadata
    )
register_pipeline_run(self, pipeline_run)
Registers a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run |
PipelineRunWrapper |
The pipeline run to register. |
required |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a pipeline run with the same name already exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def register_pipeline_run(self, pipeline_run: PipelineRunWrapper) -> None:
    """Persist a new pipeline run.

    Args:
        pipeline_run: The run to persist.

    Raises:
        EntityExistsError: If a pipeline run with the same name already
            exists.
    """
register_stack(self, stack)
Register a stack and its components.
If any of the stack's components aren't registered in the zen store yet, this method will try to register them as well.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack |
StackWrapper |
The stack to register. |
required |
Exceptions:
Type | Description |
---|---|
StackExistsError |
If a stack with the same name already exists. |
StackComponentExistsError |
If a component of the stack wasn't registered and a different component with the same name already exists. |
Source code in zenml/zen_stores/base_zen_store.py
def register_stack(self, stack: StackWrapper) -> None:
    """Register a stack and its components.

    If any of the stack's components aren't registered in the zen store
    yet, this method will try to register them as well.

    Args:
        stack: The stack to register.

    Raises:
        StackExistsError: If a stack with the same name already exists.
        StackComponentExistsError: If a component of the stack wasn't
            registered and a different component with the same name
            already exists.
    """
    metadata = {c.type.value: c.flavor for c in stack.components}
    metadata["store_type"] = self.type.value
    # Use the store's own `_track_event` hook (as `register_default_stack`
    # and `register_stack_component` do) rather than the module-level
    # `track_event`. The module-level call bypasses the store instance and
    # therefore, presumably, the `track_analytics` flag set in `initialize`.
    self._track_event(AnalyticsEvent.REGISTERED_STACK, metadata=metadata)
    return self._register_stack(stack)
register_stack_component(self, component)
Register a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component |
ComponentWrapper |
The component to register. |
required |
Exceptions:
Type | Description |
---|---|
StackComponentExistsError |
If a stack component with the same type and name already exists. |
Source code in zenml/zen_stores/base_zen_store.py
def register_stack_component(
    self,
    component: ComponentWrapper,
) -> None:
    """Register one stack component and record an analytics event.

    Args:
        component: The component to register.

    Raises:
        StackComponentExistsError: If a stack component with the same type
            and name already exists.
    """
    self._track_event(
        AnalyticsEvent.REGISTERED_STACK_COMPONENT,
        metadata={
            "type": component.type.value,
            "flavor": component.flavor,
        },
    )
    outcome = self._register_stack_component(component)
    return outcome
remove_user_from_team(self, team_name, user_name)
Removes a user from a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
user_name |
str |
Name of the user. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user or team with the given names exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def remove_user_from_team(self, team_name: str, user_name: str) -> None:
    """Take a user out of a team.

    Args:
        team_name: Name of the team.
        user_name: Name of the user to remove from the team.

    Raises:
        KeyError: If no user or team with the given names exists.
    """
revoke_role(self, role_name, entity_name, project_name=None, is_user=True)
Revokes a role from a user or team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to revoke. |
required |
entity_name |
str |
User or team name. |
required |
project_name |
Optional[str] |
Optional project name. |
None |
is_user |
bool |
Boolean indicating whether the given `entity_name` refers to a user. |
True |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role, entity or project with the given names exists. |
Source code in zenml/zen_stores/base_zen_store.py
@abstractmethod
def revoke_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Withdraw a role from either a user or a team.

    Args:
        role_name: Name of the role to withdraw.
        entity_name: Name of the user or team losing the role.
        project_name: Optional name of a project.
        is_user: `True` when `entity_name` names a user, `False` when it
            names a team.

    Raises:
        KeyError: If no role, entity or project with the given names exists.
    """
update_stack(self, name, stack)
Update a stack and its components.
If any of the stack's components aren't registered in the stack store yet, this method will try to register them as well.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The original name of the stack. |
required |
stack |
StackWrapper |
The new stack to use in the update. |
required |
Exceptions:
Type | Description |
---|---|
DoesNotExistException |
If no stack exists with the given name. |
Source code in zenml/zen_stores/base_zen_store.py
def update_stack(self, name: str, stack: StackWrapper) -> None:
    """Update a stack and its components.

    If any of the stack's components aren't registered in the stack store
    yet, this method will try to register them as well.

    Args:
        name: The original name of the stack.
        stack: The new stack to use in the update.

    Raises:
        DoesNotExistException: If no stack exists with the given name.
    """
    metadata = {c.type.value: c.flavor for c in stack.components}
    metadata["store_type"] = self.type.value
    # Use the store's own `_track_event` hook, consistent with the other
    # mutating methods, instead of the module-level `track_event`. The
    # module-level call presumably ignores the store's `track_analytics`
    # flag.
    self._track_event(AnalyticsEvent.UPDATED_STACK, metadata=metadata)
    return self._update_stack(name, stack)
update_stack_component(self, name, component_type, component)
Update a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The original name of the stack component. |
required |
component_type |
StackComponentType |
The type of the stack component to update. |
required |
component |
ComponentWrapper |
The new component to update with. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no stack component exists with the given name. |
Source code in zenml/zen_stores/base_zen_store.py
def update_stack_component(
    self,
    name: str,
    component_type: StackComponentType,
    component: ComponentWrapper,
) -> Dict[str, str]:
    """Update a stack component and record an analytics event.

    Args:
        name: The current name of the component.
        component_type: The type of the component being updated.
        component: The replacement component.

    Returns:
        Whatever `_update_stack_component` returns (annotated as
        `Dict[str, str]`; exact contents depend on the concrete store).

    Raises:
        KeyError: If no stack component exists with the given name.
    """
    self._track_event(
        AnalyticsEvent.UPDATED_STACK_COMPONENT,
        metadata={
            "type": component.type.value,
            "flavor": component.flavor,
        },
    )
    outcome = self._update_stack_component(name, component_type, component)
    return outcome
local_zen_store
LocalZenStore (BaseZenStore)
Source code in zenml/zen_stores/local_zen_store.py
class LocalZenStore(BaseZenStore):
def initialize(
    self,
    url: str,
    *args: Any,
    store_data: Optional[ZenStoreModel] = None,
    **kwargs: Any,
) -> "LocalZenStore":
    """Initializes a local ZenStore instance.

    Args:
        url: URL of local directory of the repository to use for
            storage, either a plain path or a `file://` URL.
        store_data: optional store data object to pre-populate the
            zen store with; when omitted, the stack data is backed by
            `stacks.yaml` inside the store directory.
        args: additional positional arguments, forwarded to
            `BaseZenStore.initialize`.
        kwargs: additional keyword arguments, forwarded to
            `BaseZenStore.initialize`.

    Returns:
        The initialized ZenStore instance.

    Raises:
        ValueError: If `url` is not a valid local store URL.
    """
    if not self.is_valid_url(url):
        raise ValueError(f"Invalid URL for local store: {url}")
    self._root = self.get_path_from_url(url)
    self._url = f"file://{self._root}"
    utils.create_dir_recursive_if_not_exists(str(self._root))
    if store_data is not None:
        self.__store = store_data
    else:
        self.__store = ZenStoreModel(str(self.root / "stacks.yaml"))
    # Pipeline runs are always kept in their own file, even when
    # `store_data` is supplied.
    self.__pipeline_store = ZenStorePipelineModel(
        str(self.root / "pipeline_runs.yaml")
    )
    super().initialize(url, *args, **kwargs)
    return self
# Public interface implementations:
@property
def type(self) -> StoreType:
    """Store type identifier: always `StoreType.LOCAL` for this class."""
    store_type = StoreType.LOCAL
    return store_type
@property
def url(self) -> str:
"""URL of the repository."""
return self._url
# Static methods:
@staticmethod
def get_path_from_url(url: str) -> Optional[Path]:
"""Get the path from a URL.
Args:
url: The URL to get the path from.
Returns:
The path from the URL.
"""
if not LocalZenStore.is_valid_url(url):
raise ValueError(f"Invalid URL for local store: {url}")
url = url.replace("file://", "")
return Path(url)
@staticmethod
def get_local_url(path: str) -> str:
"""Get a local URL for a given local path."""
return f"file://{path}"
@staticmethod
def is_valid_url(url: str) -> bool:
"""Check if the given url is a valid local path."""
scheme = re.search("^([a-z0-9]+://)", url)
return not scheme or scheme.group() == "file://"
@property
def stacks_empty(self) -> bool:
"""Check if the zen store is empty."""
return len(self.__store.stacks) == 0
def get_stack_configuration(
self, name: str
) -> Dict[StackComponentType, str]:
"""Fetches a stack configuration by name.
Args:
name: The name of the stack to fetch.
Returns:
Dict[StackComponentType, str] for the requested stack name.
Raises:
KeyError: If no stack exists for the given name.
"""
logger.debug("Fetching stack with name '%s'.", name)
if name not in self.__store.stacks:
raise KeyError(
f"Unable to find stack with name '{name}'. Available names: "
f"{set(self.__store.stacks)}."
)
return self.__store.stacks[name]
@property
def stack_configurations(self) -> Dict[str, Dict[StackComponentType, str]]:
"""Configuration for all stacks registered in this zen store.
Returns:
Dictionary mapping stack names to Dict[StackComponentType, str]
"""
return self.__store.stacks.copy()
def _register_stack_component(
self,
component: ComponentWrapper,
) -> None:
"""Register a stack component.
Args:
component: The component to register.
Raises:
StackComponentExistsError: If a stack component with the same type
and name already exists.
"""
components = self.__store.stack_components[component.type]
if component.name in components:
raise StackComponentExistsError(
f"Unable to register stack component (type: {component.type}) "
f"with name '{component.name}': Found existing stack component "
f"with this name."
)
# write the component configuration file
component_config_path = self._get_stack_component_config_path(
component_type=component.type, name=component.name
)
utils.create_dir_recursive_if_not_exists(
os.path.dirname(component_config_path)
)
utils.write_file_contents_as_string(
component_config_path,
base64.b64decode(component.config).decode(),
)
# add the component to the zen store dict and write it to disk
components[component.name] = component.flavor
self.__store.write_config()
logger.info(
"Registered stack component with type '%s' and name '%s'.",
component.type,
component.name,
)
def _update_stack_component(
self,
name: str,
component_type: StackComponentType,
component: ComponentWrapper,
) -> Dict[str, str]:
"""Update a stack component.
Args:
name: The original name of the stack component.
component_type: The type of the stack component to update.
component: The new component to update with.
Raises:
KeyError: If no stack component exists with the given name.
"""
components = self.__store.stack_components[component_type]
if name not in components:
raise KeyError(
f"Unable to update stack component (type: {component_type}) "
f"with name '{name}': No existing stack component "
f"found with this name."
)
elif name != component.name and component.name in components:
raise StackComponentExistsError(
f"Unable to update stack component (type: {component_type}) "
f"with name '{component.name}': a stack component already "
f"is registered with this name."
)
component_config_path = self._get_stack_component_config_path(
component_type=component.type, name=component.name
)
utils.create_dir_recursive_if_not_exists(
os.path.dirname(component_config_path)
)
utils.write_file_contents_as_string(
component_config_path,
base64.b64decode(component.config).decode(),
)
if name != component.name:
self._delete_stack_component(component_type, name)
# add the component to the stack store dict and write it to disk
components[component.name] = component.flavor
for _, conf in self.stack_configurations.items():
for component_type, component_name in conf.items():
if component_name == name and component_type == component.type:
conf[component_type] = component.name
self.__store.write_config()
logger.info(
"Updated stack component with type '%s' and name '%s'.",
component_type,
component.name,
)
return {component.type.value: component.flavor}
def _deregister_stack(self, name: str) -> None:
"""Remove a stack from storage.
Args:
name: The name of the stack to be deleted.
Raises:
KeyError: If no stack exists for the given name.
"""
del self.__store.stacks[name]
self.__store.write_config()
# Private interface implementations:
def _save_stack(
self,
name: str,
stack_configuration: Dict[StackComponentType, str],
) -> None:
"""Save a stack.
Args:
name: The name to save the stack as.
stack_configuration: Dict[StackComponentType, str] to persist.
"""
self.__store.stacks[name] = stack_configuration
self.__store.write_config()
def _get_component_flavor_and_config(
self, component_type: StackComponentType, name: str
) -> Tuple[str, bytes]:
"""Fetch the flavor and configuration for a stack component.
Args:
component_type: The type of the component to fetch.
name: The name of the component to fetch.
Returns:
Pair of (flavor, configuration) for stack component, as string and
base64-encoded yaml document, respectively
Raises:
KeyError: If no stack component exists for the given type and name.
"""
components: Dict[str, str] = self.__store.stack_components[
component_type
]
if name not in components:
raise KeyError(
f"Unable to find stack component (type: {component_type}) "
f"with name '{name}'. Available names: {set(components)}."
)
component_config_path = self._get_stack_component_config_path(
component_type=component_type, name=name
)
flavor = components[name]
config = base64.b64encode(
utils.read_file_contents_as_string(component_config_path).encode()
)
return flavor, config
def _get_stack_component_names(
self, component_type: StackComponentType
) -> List[str]:
"""Get names of all registered stack components of a given type."""
return list(self.__store.stack_components[component_type])
def _delete_stack_component(
self, component_type: StackComponentType, name: str
) -> None:
"""Remove a StackComponent from storage.
Args:
component_type: The type of component to delete.
name: Then name of the component to delete.
Raises:
KeyError: If no component exists for given type and name.
"""
component_config_path = self._get_stack_component_config_path(
component_type=component_type, name=name
)
if fileio.exists(component_config_path):
fileio.remove(component_config_path)
components = self.__store.stack_components[component_type]
del components[name]
self.__store.write_config()
# User, project and role management
@property
def users(self) -> List[User]:
"""All registered users.
Returns:
A list of all registered users.
"""
return self.__store.users
def _get_user(self, user_name: str) -> User:
"""Get a specific user by name.
Args:
user_name: Name of the user to get.
Returns:
The requested user, if it was found.
Raises:
KeyError: If no user with the given name exists.
"""
return _get_unique_entity(user_name, collection=self.__store.users)
def _create_user(self, user_name: str) -> User:
"""Creates a new user.
Args:
user_name: Unique username.
Returns:
The newly created user.
Raises:
EntityExistsError: If a user with the given name already exists.
"""
if _get_unique_entity(
user_name, collection=self.__store.users, ensure_exists=False
):
raise EntityExistsError(
f"User with name '{user_name}' already exists."
)
user = User(name=user_name)
self.__store.users.append(user)
self.__store.write_config()
return user
def _delete_user(self, user_name: str) -> None:
"""Deletes a user.
Args:
user_name: Name of the user to delete.
Raises:
KeyError: If no user with the given name exists.
"""
user = _get_unique_entity(user_name, collection=self.__store.users)
self.__store.users.remove(user)
for user_names in self.__store.team_assignments.values():
user_names.discard(user.name)
self.__store.role_assignments = [
assignment
for assignment in self.__store.role_assignments
if assignment.user_id != user.id
]
self.__store.write_config()
logger.info("Deleted user %s.", user)
@property
def teams(self) -> List[Team]:
"""All registered teams.
Returns:
A list of all registered teams.
"""
return self.__store.teams
def _get_team(self, team_name: str) -> Team:
"""Gets a specific team.
Args:
team_name: Name of the team to get.
Returns:
The requested team.
Raises:
KeyError: If no team with the given name exists.
"""
return _get_unique_entity(team_name, collection=self.__store.teams)
def _create_team(self, team_name: str) -> Team:
"""Creates a new team.
Args:
team_name: Unique team name.
Returns:
The newly created team.
Raises:
EntityExistsError: If a team with the given name already exists.
"""
if _get_unique_entity(
team_name, collection=self.__store.teams, ensure_exists=False
):
raise EntityExistsError(
f"Team with name '{team_name}' already exists."
)
team = Team(name=team_name)
self.__store.teams.append(team)
self.__store.write_config()
return team
def _delete_team(self, team_name: str) -> None:
"""Deletes a team.
Args:
team_name: Name of the team to delete.
Raises:
KeyError: If no team with the given name exists.
"""
team = _get_unique_entity(team_name, collection=self.__store.teams)
self.__store.teams.remove(team)
self.__store.team_assignments.pop(team.name, None)
self.__store.role_assignments = [
assignment
for assignment in self.__store.role_assignments
if assignment.team_id != team.id
]
self.__store.write_config()
logger.info("Deleted team %s.", team)
def add_user_to_team(self, team_name: str, user_name: str) -> None:
"""Adds a user to a team.
Args:
team_name: Name of the team.
user_name: Name of the user.
Raises:
KeyError: If no user and team with the given names exists.
"""
team = _get_unique_entity(team_name, self.__store.teams)
user = _get_unique_entity(user_name, self.__store.users)
self.__store.team_assignments[team.name].add(user.name)
self.__store.write_config()
def remove_user_from_team(self, team_name: str, user_name: str) -> None:
"""Removes a user from a team.
Args:
team_name: Name of the team.
user_name: Name of the user.
Raises:
KeyError: If no user and team with the given names exists.
"""
team = _get_unique_entity(team_name, self.__store.teams)
user = _get_unique_entity(user_name, self.__store.users)
self.__store.team_assignments[team.name].remove(user.name)
self.__store.write_config()
@property
def projects(self) -> List[Project]:
"""All registered projects.
Returns:
A list of all registered projects.
"""
return self.__store.projects
def _get_project(self, project_name: str) -> Project:
"""Get an existing project by name.
Args:
project_name: Name of the project to get.
Returns:
The requested project if one was found.
Raises:
KeyError: If there is no such project.
"""
return _get_unique_entity(
project_name, collection=self.__store.projects
)
def _create_project(
self, project_name: str, description: Optional[str] = None
) -> Project:
"""Creates a new project.
Args:
project_name: Unique project name.
description: Optional project description.
Returns:
The newly created project.
Raises:
EntityExistsError: If a project with the given name already exists.
"""
if _get_unique_entity(
project_name, collection=self.__store.projects, ensure_exists=False
):
raise EntityExistsError(
f"Project with name '{project_name}' already exists."
)
project = Project(name=project_name, description=description)
self.__store.projects.append(project)
self.__store.write_config()
return project
def _delete_project(self, project_name: str) -> None:
"""Deletes a project.
Args:
project_name: Name of the project to delete.
Raises:
KeyError: If no project with the given name exists.
"""
project = _get_unique_entity(
project_name, collection=self.__store.projects
)
self.__store.projects.remove(project)
self.__store.role_assignments = [
assignment
for assignment in self.__store.role_assignments
if assignment.project_id != project.id
]
self.__store.write_config()
logger.info("Deleted project %s.", project)
@property
def roles(self) -> List[Role]:
"""All registered roles.
Returns:
A list of all registered roles.
"""
return self.__store.roles
@property
def role_assignments(self) -> List[RoleAssignment]:
"""All registered role assignments.
Returns:
A list of all registered role assignments.
"""
return self.__store.role_assignments
def _get_role(self, role_name: str) -> Role:
"""Gets a specific role.
Args:
role_name: Name of the role to get.
Returns:
The requested role.
Raises:
KeyError: If no role with the given name exists.
"""
return _get_unique_entity(role_name, collection=self.__store.roles)
def _create_role(self, role_name: str) -> Role:
"""Creates a new role.
Args:
role_name: Unique role name.
Returns:
The newly created role.
Raises:
EntityExistsError: If a role with the given name already exists.
"""
if _get_unique_entity(
role_name, collection=self.__store.roles, ensure_exists=False
):
raise EntityExistsError(
f"Role with name '{role_name}' already exists."
)
role = Role(name=role_name)
self.__store.roles.append(role)
self.__store.write_config()
return role
def _delete_role(self, role_name: str) -> None:
"""Deletes a role.
Args:
role_name: Name of the role to delete.
Raises:
KeyError: If no role with the given name exists.
"""
role = _get_unique_entity(role_name, collection=self.__store.roles)
self.__store.roles.remove(role)
self.__store.role_assignments = [
assignment
for assignment in self.__store.role_assignments
if assignment.role_id != role.id
]
self.__store.write_config()
logger.info("Deleted role %s.", role)
def assign_role(
self,
role_name: str,
entity_name: str,
project_name: Optional[str] = None,
is_user: bool = True,
) -> None:
"""Assigns a role to a user or team.
Args:
role_name: Name of the role to assign.
entity_name: User or team name.
project_name: Optional project name.
is_user: Boolean indicating whether the given `entity_name` refers
to a user.
Raises:
KeyError: If no role, entity or project with the given names exists.
"""
role = _get_unique_entity(role_name, collection=self.__store.roles)
project_id: Optional[UUID] = None
if project_name:
project_id = _get_unique_entity(
project_name, collection=self.__store.projects
).id
if is_user:
user = _get_unique_entity(entity_name, self.__store.users)
assignment = RoleAssignment(
role_id=role.id, project_id=project_id, user_id=user.id
)
else:
team = _get_unique_entity(entity_name, self.__store.teams)
assignment = RoleAssignment(
role_id=role.id, project_id=project_id, team_id=team.id
)
self.__store.role_assignments.append(assignment)
self.__store.write_config()
def revoke_role(
self,
role_name: str,
entity_name: str,
project_name: Optional[str] = None,
is_user: bool = True,
) -> None:
"""Revokes a role from a user or team.
Args:
role_name: Name of the role to revoke.
entity_name: User or team name.
project_name: Optional project name.
is_user: Boolean indicating whether the given `entity_name` refers
to a user.
Raises:
KeyError: If no role, entity or project with the given names exists.
"""
role = _get_unique_entity(role_name, collection=self.__store.roles)
user_id: Optional[UUID] = None
team_id: Optional[UUID] = None
project_id: Optional[UUID] = None
if is_user:
user_id = _get_unique_entity(entity_name, self.__store.users).id
else:
team_id = _get_unique_entity(entity_name, self.__store.teams).id
if project_name:
project_id = _get_unique_entity(
project_name, collection=self.__store.projects
).id
assignments = self._get_role_assignments(
role_id=role.id,
user_id=user_id,
team_id=team_id,
project_id=project_id,
)
if assignments:
self.__store.role_assignments.remove(
assignments[0]
) # there should only be one
self.__store.write_config()
def get_users_for_team(self, team_name: str) -> List[User]:
"""Fetches all users of a team.
Args:
team_name: Name of the team.
Returns:
List of users that are part of the team.
Raises:
KeyError: If no team with the given name exists.
"""
team = _get_unique_entity(team_name, collection=self.__store.teams)
user_names = self.__store.team_assignments[team.name]
return [user for user in self.users if user.name in user_names]
def get_teams_for_user(self, user_name: str) -> List[Team]:
"""Fetches all teams for a user.
Args:
user_name: Name of the user.
Returns:
List of teams that the user is part of.
Raises:
KeyError: If no user with the given name exists.
"""
user = _get_unique_entity(user_name, collection=self.__store.users)
team_names = [
team_name
for team_name, user_names in self.__store.team_assignments.items()
if user.name in user_names
]
return [team for team in self.teams if team.name in team_names]
def get_role_assignments_for_user(
self,
user_name: str,
project_name: Optional[str] = None,
include_team_roles: bool = True,
) -> List[RoleAssignment]:
"""Fetches all role assignments for a user.
Args:
user_name: Name of the user.
project_name: Optional filter to only return roles assigned for
this project.
include_team_roles: If `True`, includes roles for all teams that
the user is part of.
Returns:
List of role assignments for this user.
Raises:
KeyError: If no user or project with the given names exists.
"""
user = _get_unique_entity(user_name, collection=self.__store.users)
project_id = (
_get_unique_entity(
project_name, collection=self.__store.projects
).id
if project_name
else None
)
assignments = self._get_role_assignments(
user_id=user.id, project_id=project_id
)
if include_team_roles:
for team in self.get_teams_for_user(user_name):
assignments += self.get_role_assignments_for_team(
team.name, project_name=project_name
)
return assignments
def get_role_assignments_for_team(
self,
team_name: str,
project_name: Optional[str] = None,
) -> List[RoleAssignment]:
"""Fetches all role assignments for a team.
Args:
team_name: Name of the user.
project_name: Optional filter to only return roles assigned for
this project.
Returns:
List of role assignments for this team.
Raises:
KeyError: If no team or project with the given names exists.
"""
team = _get_unique_entity(team_name, collection=self.__store.teams)
project_id = (
_get_unique_entity(
project_name, collection=self.__store.projects
).id
if project_name
else None
)
return self._get_role_assignments(
team_id=team.id, project_id=project_id
)
# Pipelines and pipeline runs
def get_pipeline_run(
self,
pipeline_name: str,
run_name: str,
project_name: Optional[str] = None,
) -> PipelineRunWrapper:
"""Gets a pipeline run.
Args:
pipeline_name: Name of the pipeline for which to get the run.
run_name: Name of the pipeline run to get.
project_name: Optional name of the project from which to get the
pipeline run.
Raises:
KeyError: If no pipeline run (or project) with the given name
exists.
"""
runs = self.__pipeline_store.pipeline_runs[pipeline_name]
for run in runs:
if run.name != run_name:
continue
if project_name and run.project_name != project_name:
continue
return run
project_message = (
f" in project {project_name}." if project_name else "."
)
raise KeyError(
f"No pipeline run '{run_name}' found for pipeline "
f"'{pipeline_name}'{project_message}"
)
def get_pipeline_runs(
self, pipeline_name: str, project_name: Optional[str] = None
) -> List[PipelineRunWrapper]:
"""Gets pipeline runs.
Args:
pipeline_name: Name of the pipeline for which to get runs.
project_name: Optional name of the project from which to get the
pipeline runs.
"""
runs = self.__pipeline_store.pipeline_runs[pipeline_name]
if project_name:
runs = [run for run in runs if run.project_name == project_name]
return runs
def register_pipeline_run(
self,
pipeline_run: PipelineRunWrapper,
) -> None:
"""Registers a pipeline run.
Args:
pipeline_run: The pipeline run to register.
Raises:
EntityExistsError: If a pipeline run with the same name already
exists.
"""
all_runs = list(
itertools.chain.from_iterable(
self.__pipeline_store.pipeline_runs.values()
)
)
if _get_unique_entity(
entity_name=pipeline_run.name,
collection=all_runs,
ensure_exists=False,
):
raise EntityExistsError(
f"Pipeline run with name '{pipeline_run.name}' already exists. "
"Please make sure your pipeline run names are unique."
)
self.__pipeline_store.pipeline_runs[pipeline_run.pipeline.name].append(
pipeline_run
)
self.__pipeline_store.write_config()
# Handling stack component flavors
@property
def flavors(self) -> List[FlavorWrapper]:
"""All registered flavors.
Returns:
A list of all registered flavors.
"""
return self.__store.stack_component_flavors
def _create_flavor(
self,
source: str,
name: str,
stack_component_type: StackComponentType,
) -> FlavorWrapper:
"""Creates a new flavor.
Args:
source: the source path to the implemented flavor.
name: the name of the flavor.
stack_component_type: the corresponding StackComponentType.
Returns:
The newly created flavor.
Raises:
EntityExistsError: If a flavor with the given name and type
already exists.
"""
if _get_unique_entity(
name,
collection=self.get_flavors_by_type(stack_component_type),
ensure_exists=False,
):
raise EntityExistsError(
f"The flavor '{name}' for the stack component type "
f"'{stack_component_type.plural}' already exists."
)
flavor = FlavorWrapper(
name=name,
source=source,
type=stack_component_type,
)
self.__store.stack_component_flavors.append(flavor)
self.__store.write_config()
return flavor
def get_flavors_by_type(
self, component_type: StackComponentType
) -> List[FlavorWrapper]:
"""Fetch all flavor defined for a specific stack component type.
Args:
component_type: The type of the stack component.
Returns:
List of all the flavors for the given stack component type.
"""
return [
f
for f in self.__store.stack_component_flavors
if f.type == component_type
]
def get_flavor_by_name_and_type(
self,
flavor_name: str,
component_type: StackComponentType,
) -> FlavorWrapper:
"""Fetch a flavor by a given name and type.
Args:
flavor_name: The name of the flavor.
component_type: Optional, the type of the component.
Returns:
Flavor instance if it exists
Raises:
KeyError: If no flavor exists with the given name and type
or there are more than one instances
"""
matches = self.get_flavors_by_type(component_type)
return _get_unique_entity(
entity_name=flavor_name,
collection=matches,
ensure_exists=True,
)
# Implementation-specific internal methods:
@property
def root(self) -> Path:
"""The root directory of the zen store."""
if not self._root:
raise RuntimeError(
"Local zen store has not been initialized. Call `initialize` "
"before using the store."
)
return self._root
def _get_stack_component_config_path(
self, component_type: StackComponentType, name: str
) -> str:
"""Path to the configuration file of a stack component."""
path = self.root / component_type.plural / f"{name}.yaml"
return str(path)
def _get_role_assignments(
self,
role_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
user_id: Optional[UUID] = None,
team_id: Optional[UUID] = None,
) -> List[RoleAssignment]:
"""Gets all role assignments that match the criteria.
Args:
role_id: Only include role assignments associated with this role id.
project_id: Only include role assignments associated with this
project id.
user_id: Only include role assignments associated with this user id.
team_id: Only include role assignments associated with this team id.
Returns:
List of role assignments.
"""
return [
assignment
for assignment in self.__store.role_assignments
if not (
(role_id and assignment.role_id != role_id)
or (project_id and project_id != assignment.project_id)
or (user_id and user_id != assignment.user_id)
or (team_id and team_id != assignment.team_id)
)
]
flavors: List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]
property
readonly
All registered flavors.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
A list of all registered flavors. |
projects: List[zenml.zen_stores.models.user_management_models.Project]
property
readonly
All registered projects.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Project] |
A list of all registered projects. |
role_assignments: List[zenml.zen_stores.models.user_management_models.RoleAssignment]
property
readonly
All registered role assignments.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
A list of all registered role assignments. |
roles: List[zenml.zen_stores.models.user_management_models.Role]
property
readonly
All registered roles.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Role] |
A list of all registered roles. |
root: Path
property
readonly
The root directory of the zen store.
stack_configurations: Dict[str, Dict[zenml.enums.StackComponentType, str]]
property
readonly
Configuration for all stacks registered in this zen store.
Returns:
Type | Description |
---|---|
Dict[str, Dict[zenml.enums.StackComponentType, str]] |
Dictionary mapping stack names to Dict[StackComponentType, str] |
stacks_empty: bool
property
readonly
Check whether no stacks are registered in the zen store.
teams: List[zenml.zen_stores.models.user_management_models.Team]
property
readonly
All registered teams.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Team] |
A list of all registered teams. |
type: StoreType
property
readonly
The type of the zen store.
url: str
property
readonly
URL of the repository.
users: List[zenml.zen_stores.models.user_management_models.User]
property
readonly
All registered users.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.User] |
A list of all registered users. |
add_user_to_team(self, team_name, user_name)
Adds a user to a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
user_name |
str |
Name of the user. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If the team or the user with the given name does not exist. |
Source code in zenml/zen_stores/local_zen_store.py
def add_user_to_team(self, team_name: str, user_name: str) -> None:
    """Make a user a member of a team and persist the change.

    Args:
        team_name: Name of the team.
        user_name: Name of the user.

    Raises:
        KeyError: If the team or the user does not exist.
    """
    # Resolve both entities first; either lookup raises before any mutation.
    target_team = _get_unique_entity(team_name, self.__store.teams)
    new_member = _get_unique_entity(user_name, self.__store.users)

    members = self.__store.team_assignments[target_team.name]
    members.add(new_member.name)
    self.__store.write_config()
assign_role(self, role_name, entity_name, project_name=None, is_user=True)
Assigns a role to a user or team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to assign. |
required |
entity_name |
str |
User or team name. |
required |
project_name |
Optional[str] |
Optional project name. |
None |
is_user |
bool |
Boolean indicating whether the given `entity_name` refers to a user. |
True |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role, entity or project with the given names exists. |
Source code in zenml/zen_stores/local_zen_store.py
def assign_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Grant a role to a user or a team, optionally scoped to a project.

    Args:
        role_name: Name of the role to assign.
        entity_name: User or team name.
        project_name: Optional project name; when falsy the assignment is
            not scoped to a project.
        is_user: Whether `entity_name` names a user (`True`) or a
            team (`False`).

    Raises:
        KeyError: If no role, entity or project with the given names exists.
    """
    role = _get_unique_entity(role_name, collection=self.__store.roles)

    project_id: Optional[UUID] = (
        _get_unique_entity(project_name, collection=self.__store.projects).id
        if project_name
        else None
    )

    # The assignment references exactly one owner: a user or a team.
    if is_user:
        owner = {"user_id": _get_unique_entity(entity_name, self.__store.users).id}
    else:
        owner = {"team_id": _get_unique_entity(entity_name, self.__store.teams).id}

    self.__store.role_assignments.append(
        RoleAssignment(role_id=role.id, project_id=project_id, **owner)
    )
    self.__store.write_config()
get_flavor_by_name_and_type(self, flavor_name, component_type)
Fetch a flavor by a given name and type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_name |
str |
The name of the flavor. |
required |
component_type |
StackComponentType |
The type of the stack component the flavor belongs to. |
required |
Returns:
Type | Description |
---|---|
FlavorWrapper |
Flavor instance if it exists |
Exceptions:
Type | Description |
---|---|
KeyError |
If no flavor, or more than one flavor, exists with the given name and type. |
Source code in zenml/zen_stores/local_zen_store.py
def get_flavor_by_name_and_type(
    self,
    flavor_name: str,
    component_type: StackComponentType,
) -> FlavorWrapper:
    """Fetch a flavor by a given name and type.

    Args:
        flavor_name: The name of the flavor.
        component_type: The type of the stack component the flavor belongs
            to (required).

    Returns:
        Flavor instance if it exists

    Raises:
        KeyError: If no flavor, or more than one flavor, exists with the
            given name and type.
    """
    # Narrow to the requested component type first, then demand exactly one
    # match by name.
    matches = self.get_flavors_by_type(component_type)
    return _get_unique_entity(
        entity_name=flavor_name,
        collection=matches,
        ensure_exists=True,
    )
get_flavors_by_type(self, component_type)
Fetch all flavors defined for a specific stack component type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
The type of the stack component. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of all the flavors for the given stack component type. |
Source code in zenml/zen_stores/local_zen_store.py
def get_flavors_by_type(
    self, component_type: StackComponentType
) -> List[FlavorWrapper]:
    """Return every registered flavor whose type equals `component_type`.

    Args:
        component_type: The type of the stack component.

    Returns:
        The matching flavors, in registration order.
    """
    return list(
        filter(
            lambda flavor: flavor.type == component_type,
            self.__store.stack_component_flavors,
        )
    )
get_local_url(path)
staticmethod
Get a local URL for a given local path.
Source code in zenml/zen_stores/local_zen_store.py
@staticmethod
def get_local_url(path: str) -> str:
    """Build a `file://` URL pointing at a local path.

    Args:
        path: The local path to turn into a URL.

    Returns:
        `path` prefixed with the `file://` scheme.
    """
    return "file://{}".format(path)
get_path_from_url(url)
staticmethod
Get the path from a URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
The URL to get the path from. |
required |
Returns:
Type | Description |
---|---|
Optional[pathlib.Path] |
The path from the URL. |
Source code in zenml/zen_stores/local_zen_store.py
@staticmethod
def get_path_from_url(url: str) -> Optional[Path]:
    """Get the path from a URL.

    Only a leading `file://` scheme is stripped. Any other occurrence of
    that text inside the path is preserved; the previous `str.replace` call
    removed every occurrence.

    Args:
        url: The URL to get the path from.

    Returns:
        The path from the URL.

    Raises:
        ValueError: If `url` is not a valid local store URL.
    """
    if not LocalZenStore.is_valid_url(url):
        raise ValueError(f"Invalid URL for local store: {url}")
    scheme = "file://"
    if url.startswith(scheme):
        url = url[len(scheme):]
    return Path(url)
get_pipeline_run(self, pipeline_name, run_name, project_name=None)
Gets a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to get the run. |
required |
run_name |
str |
Name of the pipeline run to get. |
required |
project_name |
Optional[str] |
Optional name of the project from which to get the pipeline run. |
None |
Exceptions:
Type | Description |
---|---|
KeyError |
If no pipeline run (or project) with the given name exists. |
Source code in zenml/zen_stores/local_zen_store.py
def get_pipeline_run(
    self,
    pipeline_name: str,
    run_name: str,
    project_name: Optional[str] = None,
) -> PipelineRunWrapper:
    """Gets a pipeline run.

    Args:
        pipeline_name: Name of the pipeline for which to get the run.
        run_name: Name of the pipeline run to get.
        project_name: Optional name of the project from which to get the
            pipeline run.

    Returns:
        The first stored run of the pipeline with the given run name (and
        project, if one was given).

    Raises:
        KeyError: If no pipeline run (or project) with the given name
            exists.
    """
    # `.get` instead of indexing: `pipeline_runs` is presumably a
    # defaultdict (register_pipeline_run appends to keys it has never seen),
    # and indexing it would insert an empty entry for an unknown pipeline.
    runs = self.__pipeline_store.pipeline_runs.get(pipeline_name, [])
    for run in runs:
        if run.name != run_name:
            continue
        if project_name and run.project_name != project_name:
            continue
        return run

    project_message = (
        f" in project {project_name}." if project_name else "."
    )
    raise KeyError(
        f"No pipeline run '{run_name}' found for pipeline "
        f"'{pipeline_name}'{project_message}"
    )
get_pipeline_runs(self, pipeline_name, project_name=None)
Gets pipeline runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to get runs. |
required |
project_name |
Optional[str] |
Optional name of the project from which to get the pipeline runs. |
None |
Source code in zenml/zen_stores/local_zen_store.py
def get_pipeline_runs(
    self, pipeline_name: str, project_name: Optional[str] = None
) -> List[PipelineRunWrapper]:
    """Gets pipeline runs.

    Args:
        pipeline_name: Name of the pipeline for which to get runs.
        project_name: Optional name of the project from which to get the
            pipeline runs.

    Returns:
        The runs of the given pipeline, filtered by project if one was
        given. Empty if no runs were registered for the pipeline.
    """
    # `.get` instead of indexing so that querying an unknown pipeline does
    # not insert an empty entry (presumably `pipeline_runs` is a defaultdict,
    # since register_pipeline_run appends to keys it has never seen).
    runs = self.__pipeline_store.pipeline_runs.get(pipeline_name, [])
    if project_name:
        runs = [run for run in runs if run.project_name == project_name]
    return runs
get_role_assignments_for_team(self, team_name, project_name=None)
Fetches all role assignments for a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
project_name |
Optional[str] |
Optional filter to only return roles assigned for this project. |
None |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
List of role assignments for this team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team or project with the given names exists. |
Source code in zenml/zen_stores/local_zen_store.py
def get_role_assignments_for_team(
    self,
    team_name: str,
    project_name: Optional[str] = None,
) -> List[RoleAssignment]:
    """Fetches all role assignments for a team.

    Args:
        team_name: Name of the team.
        project_name: Optional filter to only return roles assigned for
            this project.

    Returns:
        List of role assignments for this team.

    Raises:
        KeyError: If no team or project with the given names exists.
    """
    team = _get_unique_entity(team_name, collection=self.__store.teams)
    # Only resolve the project (and fail if it does not exist) when a
    # project filter was requested.
    project_id = (
        _get_unique_entity(
            project_name, collection=self.__store.projects
        ).id
        if project_name
        else None
    )
    return self._get_role_assignments(
        team_id=team.id, project_id=project_id
    )
get_role_assignments_for_user(self, user_name, project_name=None, include_team_roles=True)
Fetches all role assignments for a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user. |
required |
project_name |
Optional[str] |
Optional filter to only return roles assigned for this project. |
None |
include_team_roles |
bool |
If `True`, includes roles for all teams that the user is part of. |
True |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
List of role assignments for this user. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user or project with the given names exists. |
Source code in zenml/zen_stores/local_zen_store.py
def get_role_assignments_for_user(
    self,
    user_name: str,
    project_name: Optional[str] = None,
    include_team_roles: bool = True,
) -> List[RoleAssignment]:
    """Fetches all role assignments for a user.

    Args:
        user_name: Name of the user.
        project_name: Optional filter to only return roles assigned for
            this project.
        include_team_roles: If `True`, the assignments of every team the
            user belongs to are appended to the user's own assignments.

    Returns:
        List of role assignments for this user.

    Raises:
        KeyError: If no user or project with the given names exists.
    """
    user = _get_unique_entity(user_name, collection=self.__store.users)

    project_id = None
    if project_name:
        project = _get_unique_entity(
            project_name, collection=self.__store.projects
        )
        project_id = project.id

    assignments = self._get_role_assignments(
        user_id=user.id, project_id=project_id
    )
    if not include_team_roles:
        return assignments

    for team in self.get_teams_for_user(user_name):
        assignments += self.get_role_assignments_for_team(
            team.name, project_name=project_name
        )
    return assignments
get_stack_configuration(self, name)
Fetches a stack configuration by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the stack to fetch. |
required |
Returns:
Type | Description |
---|---|
Dict[zenml.enums.StackComponentType, str] |
Dict[StackComponentType, str] for the requested stack name. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no stack exists for the given name. |
Source code in zenml/zen_stores/local_zen_store.py
def get_stack_configuration(
    self, name: str
) -> Dict[StackComponentType, str]:
    """Fetches a stack configuration by name.

    Args:
        name: The name of the stack to fetch.

    Returns:
        Mapping from component type to component name for the stack.

    Raises:
        KeyError: If no stack exists for the given name. The message lists
            the names of all stored stacks.
    """
    logger.debug("Fetching stack with name '%s'.", name)
    stacks = self.__store.stacks
    if name in stacks:
        return stacks[name]
    raise KeyError(
        f"Unable to find stack with name '{name}'. Available names: "
        f"{set(stacks)}."
    )
get_teams_for_user(self, user_name)
Fetches all teams for a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Team] |
List of teams that the user is part of. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user with the given name exists. |
Source code in zenml/zen_stores/local_zen_store.py
def get_teams_for_user(self, user_name: str) -> List[Team]:
    """Fetches all teams for a user.

    Args:
        user_name: Name of the user.

    Returns:
        The teams the user is assigned to, in the order of `self.teams`.

    Raises:
        KeyError: If no user with the given name exists.
    """
    user = _get_unique_entity(user_name, collection=self.__store.users)
    member_of = {
        team_name
        for team_name, members in self.__store.team_assignments.items()
        if user.name in members
    }
    return [team for team in self.teams if team.name in member_of]
get_users_for_team(self, team_name)
Fetches all users of a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.User] |
List of users that are part of the team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team with the given name exists. |
Source code in zenml/zen_stores/local_zen_store.py
def get_users_for_team(self, team_name: str) -> List[User]:
    """Fetches all users of a team.

    Args:
        team_name: Name of the team.

    Returns:
        List of users that are part of the team (empty if the team has no
        members).

    Raises:
        KeyError: If no team with the given name exists.
    """
    team = _get_unique_entity(team_name, collection=self.__store.teams)
    # `.get` so that an existing team without members yields an empty list
    # instead of raising (plain dict) or inserting an entry (defaultdict).
    member_names = self.__store.team_assignments.get(team.name, set())
    return [user for user in self.users if user.name in member_names]
initialize(self, url, *args, store_data=None, **kwargs)
Initializes a local ZenStore instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL of local directory of the repository to use for storage. |
required |
store_data |
Optional[zenml.zen_stores.models.zen_store_model.ZenStoreModel] |
optional store data object to pre-populate the zen store with. |
None |
args |
Any |
additional positional arguments, forwarded to the base class `initialize`. |
() |
kwargs |
Any |
additional keyword arguments, forwarded to the base class `initialize`. |
{} |
Returns:
Type | Description |
---|---|
LocalZenStore |
The initialized ZenStore instance. |
Source code in zenml/zen_stores/local_zen_store.py
def initialize(
    self,
    url: str,
    *args: Any,
    store_data: Optional[ZenStoreModel] = None,
    **kwargs: Any,
) -> "LocalZenStore":
    """Initializes a local ZenStore instance.

    Creates the root directory if needed, loads (or adopts) the stack store
    and the pipeline-run store, then runs the base class initialization.

    Args:
        url: URL of local directory of the repository to use for
            storage.
        store_data: optional store data object to pre-populate the
            zen store with. If `None`, the store is backed by
            `stacks.yaml` inside the root directory.
        args: additional positional arguments, forwarded to the base
            class `initialize`.
        kwargs: additional keyword arguments, forwarded to the base
            class `initialize`.

    Returns:
        The initialized ZenStore instance.

    Raises:
        ValueError: If `url` is not a valid local store URL.
    """
    if not self.is_valid_url(url):
        raise ValueError(f"Invalid URL for local store: {url}")

    root = self.get_path_from_url(url)
    self._root = root
    self._url = f"file://{root}"
    utils.create_dir_recursive_if_not_exists(str(root))

    if store_data is None:
        store_data = ZenStoreModel(str(self.root / "stacks.yaml"))
    self.__store = store_data
    self.__pipeline_store = ZenStorePipelineModel(
        str(self.root / "pipeline_runs.yaml")
    )

    # The stores must exist before the base class runs default
    # registrations / migrations against them.
    super().initialize(url, *args, **kwargs)
    return self
is_valid_url(url)
staticmethod
Check if the given url is a valid local path.
Source code in zenml/zen_stores/local_zen_store.py
@staticmethod
def is_valid_url(url: str) -> bool:
    """Check if the given url is a valid local path.

    A URL is considered local when it has no ``scheme://`` prefix at all or
    when its scheme is exactly ``file``.

    Args:
        url: The URL or plain path to check.

    Returns:
        `True` if the URL refers to the local filesystem, `False` otherwise.
    """
    # Scheme characters follow RFC 3986 (letters, digits, "+", "-", "."),
    # so URLs such as "git+ssh://..." are recognized as remote instead of
    # being mistaken for local paths.
    scheme = re.search(r"^([a-z0-9][a-z0-9+.\-]*://)", url)
    return not scheme or scheme.group() == "file://"
register_pipeline_run(self, pipeline_run)
Registers a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run |
PipelineRunWrapper |
The pipeline run to register. |
required |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a pipeline run with the same name already exists. |
Source code in zenml/zen_stores/local_zen_store.py
def register_pipeline_run(
    self,
    pipeline_run: PipelineRunWrapper,
) -> None:
    """Registers a pipeline run and persists the pipeline-run store.

    Run names must be unique across all pipelines, not just within the
    run's own pipeline.

    Args:
        pipeline_run: The pipeline run to register.

    Raises:
        EntityExistsError: If a pipeline run with the same name already
            exists.
    """
    existing_runs = [
        run
        for runs_of_pipeline in self.__pipeline_store.pipeline_runs.values()
        for run in runs_of_pipeline
    ]
    duplicate = _get_unique_entity(
        entity_name=pipeline_run.name,
        collection=existing_runs,
        ensure_exists=False,
    )
    if duplicate:
        raise EntityExistsError(
            f"Pipeline run with name '{pipeline_run.name}' already exists. "
            "Please make sure your pipeline run names are unique."
        )

    runs_of_pipeline = self.__pipeline_store.pipeline_runs[
        pipeline_run.pipeline.name
    ]
    runs_of_pipeline.append(pipeline_run)
    self.__pipeline_store.write_config()
remove_user_from_team(self, team_name, user_name)
Removes a user from a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
user_name |
str |
Name of the user. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user and team with the given names exists. |
Source code in zenml/zen_stores/local_zen_store.py
def remove_user_from_team(self, team_name: str, user_name: str) -> None:
    """Removes a user from a team and persists the store.

    Args:
        team_name: Name of the team.
        user_name: Name of the user.

    Raises:
        KeyError: If no user and team with the given names exists.
    """
    team = _get_unique_entity(team_name, self.__store.teams)
    user = _get_unique_entity(user_name, self.__store.users)
    members = self.__store.team_assignments[team.name]
    members.remove(user.name)
    self.__store.write_config()
revoke_role(self, role_name, entity_name, project_name=None, is_user=True)
Revokes a role from a user or team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to revoke. |
required |
entity_name |
str |
User or team name. |
required |
project_name |
Optional[str] |
Optional project name. |
None |
is_user |
bool |
Boolean indicating whether the given `entity_name` refers to a user. |
True |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role, entity or project with the given names exists. |
Source code in zenml/zen_stores/local_zen_store.py
def revoke_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Revokes a role from a user or team.

    Does nothing if no matching role assignment exists.

    Args:
        role_name: Name of the role to revoke.
        entity_name: User or team name.
        project_name: Optional project name.
        is_user: Boolean indicating whether the given `entity_name` refers
            to a user (`True`) or a team (`False`).

    Raises:
        KeyError: If no role, entity or project with the given names exists.
    """
    role = _get_unique_entity(role_name, collection=self.__store.roles)

    user_id: Optional[UUID] = None
    team_id: Optional[UUID] = None
    if is_user:
        user_id = _get_unique_entity(entity_name, self.__store.users).id
    else:
        team_id = _get_unique_entity(entity_name, self.__store.teams).id

    project_id: Optional[UUID] = (
        _get_unique_entity(project_name, collection=self.__store.projects).id
        if project_name
        else None
    )

    matching = self._get_role_assignments(
        role_id=role.id,
        user_id=user_id,
        team_id=team_id,
        project_id=project_id,
    )
    if not matching:
        return
    # At most one assignment is expected to match, so only the first one is
    # removed.
    self.__store.role_assignments.remove(matching[0])
    self.__store.write_config()
models
special
component_wrapper
ComponentWrapper (BaseModel)
pydantic-model
Serializable Configuration of a StackComponent
Source code in zenml/zen_stores/models/component_wrapper.py
class ComponentWrapper(BaseModel):
    """Serializable Configuration of a StackComponent

    Attributes:
        type: Type of the wrapped stack component.
        flavor: Flavor name of the wrapped stack component.
        name: Name of the wrapped stack component.
        uuid: UUID of the wrapped stack component.
        config: base64-encoded YAML dump of the component's JSON
            representation.
    """

    type: StackComponentType
    flavor: str
    name: str
    uuid: UUID
    config: bytes  # b64 encoded yaml config

    @classmethod
    def from_component(cls, component: "StackComponent") -> "ComponentWrapper":
        """Creates a ComponentWrapper from an actual instance of a Stack
        Component.

        Args:
            component: the instance of a StackComponent

        Returns:
            A wrapper whose `config` holds the base64-encoded YAML form of the
            component's JSON serialization.
        """
        component_dict = json.loads(component.json())
        encoded_config = base64.b64encode(yaml.dump(component_dict).encode())
        return cls(
            type=component.TYPE,
            flavor=component.FLAVOR,
            name=component.name,
            uuid=component.uuid,
            config=encoded_config,
        )

    def to_component(self) -> "StackComponent":
        """Converts the ComponentWrapper into an actual instance of a Stack
        Component.

        Returns:
            An instance of the flavor class registered for this wrapper's
            flavor name and component type, parsed from the stored config.
        """
        from zenml.repository import Repository

        repository = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        flavor_class = repository.get_flavor(
            name=self.flavor, component_type=self.type
        )
        raw_yaml = base64.b64decode(self.config).decode()
        return flavor_class.parse_obj(yaml.safe_load(raw_yaml))
from_component(component)
classmethod
Creates a ComponentWrapper from an actual instance of a Stack Component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component |
StackComponent |
the instance of a StackComponent |
required |
Source code in zenml/zen_stores/models/component_wrapper.py
@classmethod
def from_component(cls, component: "StackComponent") -> "ComponentWrapper":
    """Creates a ComponentWrapper from an actual instance of a Stack
    Component.

    Args:
        component: the instance of a StackComponent

    Returns:
        A wrapper whose `config` holds the base64-encoded YAML form of the
        component's JSON serialization.
    """
    component_dict = json.loads(component.json())
    encoded_config = base64.b64encode(yaml.dump(component_dict).encode())
    return cls(
        type=component.TYPE,
        flavor=component.FLAVOR,
        name=component.name,
        uuid=component.uuid,
        config=encoded_config,
    )
to_component(self)
Converts the ComponentWrapper into an actual instance of a Stack Component.
Source code in zenml/zen_stores/models/component_wrapper.py
def to_component(self) -> "StackComponent":
    """Converts the ComponentWrapper into an actual instance of a Stack
    Component.

    Returns:
        An instance of the flavor class registered for this wrapper's flavor
        name and component type, parsed from the stored config.
    """
    from zenml.repository import Repository

    repository = Repository(skip_repository_check=True)  # type: ignore[call-arg]
    flavor_class = repository.get_flavor(
        name=self.flavor, component_type=self.type
    )
    raw_yaml = base64.b64decode(self.config).decode()
    return flavor_class.parse_obj(yaml.safe_load(raw_yaml))
flavor_wrapper
FlavorWrapper (BaseModel)
pydantic-model
Network serializable wrapper representing the custom implementation of a stack component flavor.
Source code in zenml/zen_stores/models/flavor_wrapper.py
class FlavorWrapper(BaseModel):
    """Network serializable wrapper representing the custom implementation of
    a stack component flavor.

    Attributes:
        name: Name of the flavor.
        type: Type of stack component the flavor implements.
        source: Import path of the flavor class, in the form
            ``module.ClassName``.
        integration: Name of the integration providing the flavor,
            ``"built-in"`` for flavors that are always available, or `None`
            for custom flavors.
    """

    name: str
    type: StackComponentType
    source: str
    integration: Optional[str]

    @property
    def reachable(self) -> bool:
        """Property which indicates whether ZenML can import the module
        within the source.

        Returns:
            `True` for built-in flavors, whether the integration is installed
            for integration flavors, and whether the source validates for
            custom flavors.
        """
        from zenml.integrations.registry import integration_registry

        if self.integration:
            if self.integration == "built-in":
                return True
            return integration_registry.is_installed(self.integration)

        try:
            validate_flavor_source(
                source=self.source, component_type=self.type
            )
        except (AssertionError, ModuleNotFoundError, ImportError):
            return False
        return True

    @classmethod
    def from_flavor(cls, flavor: Type[StackComponent]) -> "FlavorWrapper":
        """Creates a FlavorWrapper from a flavor class.

        Args:
            flavor: the class which defines the flavor

        Returns:
            A wrapper (an instance of `cls`, so subclasses get their own type)
            whose `source` is the flavor's ``module.ClassName`` path.
        """
        return cls(
            name=flavor.FLAVOR,
            type=flavor.TYPE,
            source=flavor.__module__ + "." + flavor.__name__,
        )

    def to_flavor(self) -> Type[StackComponent]:
        """Imports and returns the class of the flavor.

        Returns:
            The flavor class loaded from `source`.

        Raises:
            ImportError: If the flavor class cannot be imported; the message
                explains how to fix it for integration and custom flavors.
        """
        try:
            return load_source_path_class(source=self.source)  # noqa
        except (ModuleNotFoundError, ImportError, NotImplementedError) as err:
            if self.integration:
                raise ImportError(
                    f"The {self.type} flavor '{self.name}' is "
                    f"a part of ZenML's '{self.integration}' "
                    f"integration, which is currently not installed on your "
                    f"system. You can install it by executing: 'zenml "
                    f"integration install {self.integration}'."
                ) from err
            raise ImportError(
                f"The {self.type} that you are trying to register has "
                f"a custom flavor '{self.name}'. In order to "
                f"register it, ZenML needs to be able to import the flavor "
                f"through its source which is defined as: "
                f"{self.source}. Unfortunately, this is not "
                f"possible due to the current set of available modules/"
                f"working directory. Please make sure that this execution "
                f"is carried out in an environment where this source "
                f"is reachable as a module."
            ) from err
reachable: bool
property
readonly
Property which indicates whether ZenML can import the module within the source.
from_flavor(flavor)
classmethod
Creates a FlavorWrapper from a flavor class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor |
Type[zenml.stack.stack_component.StackComponent] |
the class which defines the flavor |
required |
Source code in zenml/zen_stores/models/flavor_wrapper.py
@classmethod
def from_flavor(cls, flavor: Type[StackComponent]) -> "FlavorWrapper":
    """Creates a FlavorWrapper from a flavor class.

    Args:
        flavor: the class which defines the flavor

    Returns:
        A wrapper (an instance of `cls`, so subclasses get their own type)
        whose `source` is the flavor's ``module.ClassName`` path.
    """
    return cls(
        name=flavor.FLAVOR,
        type=flavor.TYPE,
        source=flavor.__module__ + "." + flavor.__name__,
    )
to_flavor(self)
Imports and returns the class of the flavor.
Source code in zenml/zen_stores/models/flavor_wrapper.py
def to_flavor(self) -> Type[StackComponent]:
    """Imports and returns the class of the flavor.

    Returns:
        The flavor class loaded from `source`.

    Raises:
        ImportError: If the flavor class cannot be imported; the message
            explains how to fix it for integration and custom flavors.
    """
    try:
        return load_source_path_class(source=self.source)  # noqa
    except (ModuleNotFoundError, ImportError, NotImplementedError) as err:
        if self.integration:
            raise ImportError(
                f"The {self.type} flavor '{self.name}' is "
                f"a part of ZenML's '{self.integration}' "
                f"integration, which is currently not installed on your "
                f"system. You can install it by executing: 'zenml "
                f"integration install {self.integration}'."
            ) from err
        raise ImportError(
            f"The {self.type} that you are trying to register has "
            f"a custom flavor '{self.name}'. In order to "
            f"register it, ZenML needs to be able to import the flavor "
            f"through its source which is defined as: "
            f"{self.source}. Unfortunately, this is not "
            f"possible due to the current set of available modules/"
            f"working directory. Please make sure that this execution "
            f"is carried out in an environment where this source "
            f"is reachable as a module."
        ) from err
pipeline_models
PipelineRunWrapper (BaseModel)
pydantic-model
Pydantic object representing a pipeline run.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
Pipeline run name. |
zenml_version |
str |
Version of ZenML that this pipeline run was performed with. |
git_sha |
Optional[str] |
Git commit SHA that this pipeline run was performed on. This will only be set if the pipeline code is in a git repository and there are no uncommitted files when running the pipeline. |
pipeline |
PipelineWrapper |
Pipeline that this run is referring to. |
stack |
StackWrapper |
Stack that this run was performed on. |
runtime_configuration |
Dict[str, Any] |
Runtime configuration that was used for this run. |
user_id |
UUID |
Id of the user that ran this pipeline. |
project_name |
Optional[str] |
Name of the project that this pipeline was run in. |
Source code in zenml/zen_stores/models/pipeline_models.py
class PipelineRunWrapper(BaseModel):
    """Pydantic object representing a pipeline run.

    Attributes:
        name: Pipeline run name.
        zenml_version: Version of ZenML that this pipeline run was performed
            with.
        git_sha: Git commit SHA that this pipeline run was performed on. This
            will only be set if the pipeline code is in a git repository and
            there are no uncommitted files when running the pipeline.
        pipeline: Pipeline that this run is referring to.
        stack: Stack that this run was performed on.
        runtime_configuration: Runtime configuration that was used for this run.
        user_id: Id of the user that ran this pipeline.
        project_name: Name of the project that this pipeline was run in.
    """

    name: str
    # Default is the installed ZenML version, read once when this module is
    # imported.
    zenml_version: str = zenml.__version__
    # Computed by calling `get_git_sha()` (with its default `clean=True`)
    # whenever no explicit value is passed.
    git_sha: Optional[str] = Field(default_factory=get_git_sha)
    pipeline: PipelineWrapper
    stack: StackWrapper
    runtime_configuration: Dict[str, Any]
    user_id: UUID
    project_name: Optional[str]
PipelineWrapper (BaseModel)
pydantic-model
Pydantic object representing a pipeline.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
Pipeline name |
docstring |
Optional[str] |
Docstring of the pipeline |
steps |
List[zenml.zen_stores.models.pipeline_models.StepWrapper] |
List of steps in this pipeline |
Source code in zenml/zen_stores/models/pipeline_models.py
class PipelineWrapper(BaseModel):
    """Pydantic object representing a pipeline.

    Attributes:
        name: Pipeline name
        docstring: Docstring of the pipeline
        steps: List of steps in this pipeline
    """

    name: str
    docstring: Optional[str]
    steps: List[StepWrapper]

    @classmethod
    def from_pipeline(cls, pipeline: "BasePipeline") -> "PipelineWrapper":
        """Creates a PipelineWrapper from a pipeline instance.

        Args:
            pipeline: The pipeline to wrap.

        Returns:
            A wrapper holding the pipeline's name, docstring and one
            StepWrapper per step, in the order of `pipeline.steps`.
        """
        wrapped_steps = list(
            map(StepWrapper.from_step, pipeline.steps.values())
        )
        return cls(
            name=pipeline.name,
            docstring=pipeline.__doc__,
            steps=wrapped_steps,
        )
from_pipeline(pipeline)
classmethod
Creates a PipelineWrapper from a pipeline instance.
Source code in zenml/zen_stores/models/pipeline_models.py
@classmethod
def from_pipeline(cls, pipeline: "BasePipeline") -> "PipelineWrapper":
    """Creates a PipelineWrapper from a pipeline instance.

    Args:
        pipeline: The pipeline to wrap.

    Returns:
        A wrapper holding the pipeline's name, docstring and one StepWrapper
        per step, in the order of `pipeline.steps`.
    """
    wrapped_steps = list(map(StepWrapper.from_step, pipeline.steps.values()))
    return cls(
        name=pipeline.name,
        docstring=pipeline.__doc__,
        steps=wrapped_steps,
    )
StepWrapper (BaseModel)
pydantic-model
Pydantic object representing a step.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
Step name |
docstring |
Optional[str] |
Docstring of the step |
Source code in zenml/zen_stores/models/pipeline_models.py
class StepWrapper(BaseModel):
    """Pydantic object representing a step.

    Attributes:
        name: Step name
        docstring: Docstring of the step
    """

    name: str
    docstring: Optional[str]

    @classmethod
    def from_step(cls, step: "BaseStep") -> "StepWrapper":
        """Creates a StepWrapper from a step instance.

        Args:
            step: The step to wrap.

        Returns:
            A wrapper holding the step's name and docstring.
        """
        step_name = step.name
        step_doc = step.__doc__
        return cls(name=step_name, docstring=step_doc)
from_step(step)
classmethod
Creates a StepWrapper from a step instance.
Source code in zenml/zen_stores/models/pipeline_models.py
@classmethod
def from_step(cls, step: "BaseStep") -> "StepWrapper":
    """Creates a StepWrapper from a step instance.

    Args:
        step: The step to wrap.

    Returns:
        A wrapper holding the step's name and docstring.
    """
    step_name = step.name
    step_doc = step.__doc__
    return cls(name=step_name, docstring=step_doc)
get_git_sha(clean=True)
Returns the current git HEAD SHA.
If the current working directory is not inside a git repo, this will return `None`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
clean |
bool |
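If `True` and there are any untracked files or files in the index or working tree, this function will return `None`. |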
True |
Source code in zenml/zen_stores/models/pipeline_models.py
def get_git_sha(clean: bool = True) -> Optional[str]:
    """Returns the current git HEAD SHA.

    If GitPython is not installed, or the current working directory is not
    inside a git repo, this will return `None`.

    Args:
        clean: If `True` and there are any untracked files or files in the
            index or working tree, this function will return `None`.

    Returns:
        The hex SHA of the commit HEAD points to, or `None` if it cannot be
        determined (see above), including when the repository has no commits
        yet.
    """
    try:
        from git.exc import InvalidGitRepositoryError
        from git.repo.base import Repo
    except ImportError:
        return None

    try:
        repo = Repo(search_parent_directories=True)
    except InvalidGitRepositoryError:
        return None

    if clean and repo.is_dirty(untracked_files=True):
        return None

    try:
        return cast(str, repo.head.object.hexsha)
    except ValueError:
        # GitPython raises ValueError when HEAD does not resolve to a commit,
        # e.g. in a freshly initialized repository without any commits.
        return None
stack_wrapper
StackWrapper (BaseModel)
pydantic-model
Network Serializable Wrapper describing a Stack.
Source code in zenml/zen_stores/models/stack_wrapper.py
class StackWrapper(BaseModel):
    """Network Serializable Wrapper describing a Stack.

    Attributes:
        name: Name of the stack.
        components: Serialized wrappers of the stack's components.
    """

    name: str
    components: List[ComponentWrapper]

    @classmethod
    def from_stack(cls, stack: Stack) -> "StackWrapper":
        """Creates a StackWrapper from an actual Stack instance.

        Args:
            stack: the instance of a Stack

        Returns:
            A wrapper holding the stack name and one ComponentWrapper per
            component of the stack.
        """
        return cls(
            name=stack.name,
            components=[
                ComponentWrapper.from_component(component)
                for component in stack.components.values()
            ],
        )

    def to_stack(self) -> Stack:
        """Creates the corresponding Stack instance from the wrapper.

        Returns:
            A Stack built from the deserialized components. If several
            wrappers share a component type, the last one wins.
        """
        stack_components = {
            wrapper.type: wrapper.to_component() for wrapper in self.components
        }
        return Stack.from_components(
            name=self.name, components=stack_components
        )

    def get_component_wrapper(
        self, component_type: StackComponentType
    ) -> Optional[ComponentWrapper]:
        """Returns the component of the given type.

        Args:
            component_type: The type of component to look up.

        Returns:
            The first wrapper with the given type, or `None` if there is none.
        """
        return next(
            (
                wrapper
                for wrapper in self.components
                if wrapper.type == component_type
            ),
            None,
        )
from_stack(stack)
classmethod
Creates a StackWrapper from an actual Stack instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack |
Stack |
the instance of a Stack |
required |
Source code in zenml/zen_stores/models/stack_wrapper.py
@classmethod
def from_stack(cls, stack: Stack) -> "StackWrapper":
    """Creates a StackWrapper from an actual Stack instance.

    Args:
        stack: the instance of a Stack

    Returns:
        A wrapper holding the stack name and one ComponentWrapper per
        component of the stack.
    """
    return cls(
        name=stack.name,
        components=[
            ComponentWrapper.from_component(component)
            for component in stack.components.values()
        ],
    )
get_component_wrapper(self, component_type)
Returns the component of the given type.
Source code in zenml/zen_stores/models/stack_wrapper.py
def get_component_wrapper(
    self, component_type: StackComponentType
) -> Optional[ComponentWrapper]:
    """Returns the component of the given type.

    Args:
        component_type: The type of component to look up.

    Returns:
        The first wrapper with the given type, or `None` if there is none.
    """
    return next(
        (
            wrapper
            for wrapper in self.components
            if wrapper.type == component_type
        ),
        None,
    )
to_stack(self)
Creates the corresponding Stack instance from the wrapper.
Source code in zenml/zen_stores/models/stack_wrapper.py
def to_stack(self) -> Stack:
    """Creates the corresponding Stack instance from the wrapper.

    Returns:
        A Stack built from the deserialized components. If several wrappers
        share a component type, the last one wins.
    """
    stack_components = {
        wrapper.type: wrapper.to_component() for wrapper in self.components
    }
    return Stack.from_components(name=self.name, components=stack_components)
user_management_models
Operation (BaseModel)
pydantic-model
Pydantic object representing an operation that requires permission.
Attributes:
Name | Type | Description |
---|---|---|
id |
int |
Operation id. |
name |
str |
Operation name. |
Source code in zenml/zen_stores/models/user_management_models.py
class Operation(BaseModel):
    """Pydantic object representing an operation that requires permission.

    Attributes:
        id: Operation id.
        name: Operation name.
    """

    # Unlike the other user-management models, the id is a plain integer
    # with no default, so it must always be supplied explicitly.
    id: int
    name: str
Permission (BaseModel)
pydantic-model
Pydantic object representing permissions on a specific resource.
Attributes:
Name | Type | Description |
---|---|---|
operation |
Operation |
The operation for which the permissions are. |
types |
Set[zenml.zen_stores.models.user_management_models.PermissionType] |
Types of permissions. |
Source code in zenml/zen_stores/models/user_management_models.py
class Permission(BaseModel):
    """Pydantic object representing permissions on a specific resource.

    Attributes:
        operation: The operation for which the permissions are.
        types: Types of permissions.
    """

    operation: Operation
    types: Set[PermissionType]

    class Config:
        """Pydantic configuration for `Permission`."""

        # similar to non-mutable but also makes the object hashable, which
        # lets permissions be stored in sets (e.g. `Role.permissions`).
        frozen = True
PermissionType (Enum)
All permission types.
Source code in zenml/zen_stores/models/user_management_models.py
class PermissionType(Enum):
    """All permission types.

    The member values are the lowercase CRUD verbs.
    """

    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"
Project (BaseModel)
pydantic-model
Pydantic object representing a project.
Attributes:
Name | Type | Description |
---|---|---|
id |
UUID |
Id of the project. |
creation_date |
datetime |
Date when the project was created. |
name |
str |
Name of the project. |
description |
Optional[str] |
Optional project description. |
Source code in zenml/zen_stores/models/user_management_models.py
class Project(BaseModel):
    """Pydantic object representing a project.

    Attributes:
        id: Id of the project.
        creation_date: Date when the project was created.
        name: Name of the project.
        description: Optional project description.
    """

    # A random UUID4 is generated when no id is given.
    id: UUID = Field(default_factory=uuid4)
    # `datetime.now` yields a naive (timezone-less) local timestamp.
    creation_date: datetime = Field(default_factory=datetime.now)
    name: str
    description: Optional[str] = None
Role (BaseModel)
pydantic-model
Pydantic object representing a role.
Attributes:
Name | Type | Description |
---|---|---|
id |
UUID |
Id of the role. |
creation_date |
datetime |
Date when the role was created. |
name |
str |
Name of the role. |
permissions |
Set[zenml.zen_stores.models.user_management_models.Permission] |
Set of permissions allowed by this role. |
Source code in zenml/zen_stores/models/user_management_models.py
class Role(BaseModel):
    """Pydantic object representing a role.

    Attributes:
        id: Id of the role.
        creation_date: Date when the role was created.
        name: Name of the role.
        permissions: Set of permissions allowed by this role.
    """

    # A random UUID4 is generated when no id is given.
    id: UUID = Field(default_factory=uuid4)
    # `datetime.now` yields a naive (timezone-less) local timestamp.
    creation_date: datetime = Field(default_factory=datetime.now)
    name: str
    # Pydantic copies field defaults per instance, so this empty set is not
    # shared between roles.
    permissions: Set[Permission] = set()
RoleAssignment (BaseModel)
pydantic-model
Pydantic object representing a role assignment.
Attributes:
Name | Type | Description |
---|---|---|
id |
UUID |
Id of the role assignment. |
creation_date |
datetime |
Date when the role was assigned. |
role_id |
UUID |
Id of the role. |
project_id |
Optional[uuid.UUID] |
Optional ID of a project that the role is limited to. |
team_id |
Optional[uuid.UUID] |
Id of a team to which the role is assigned. |
user_id |
Optional[uuid.UUID] |
Id of a user to which the role is assigned. |
Source code in zenml/zen_stores/models/user_management_models.py
class RoleAssignment(BaseModel):
    """Pydantic object representing a role assignment.

    Exactly one of `user_id` and `team_id` must be set.

    Attributes:
        id: Id of the role assignment.
        creation_date: Date when the role was assigned.
        role_id: Id of the role.
        project_id: Optional ID of a project that the role is limited to.
        team_id: Id of a team to which the role is assigned.
        user_id: Id of a user to which the role is assigned.
    """

    id: UUID = Field(default_factory=uuid4)
    creation_date: datetime = Field(default_factory=datetime.now)
    role_id: UUID
    project_id: Optional[UUID] = None
    team_id: Optional[UUID] = None
    user_id: Optional[UUID] = None

    @root_validator
    def ensure_single_entity(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validates that exactly one of `user_id` and `team_id` is set.

        Args:
            values: The field values of the model being validated.

        Returns:
            The unchanged field values.

        Raises:
            ValueError: If both or neither of the ids are set.
        """
        has_user = bool(values.get("user_id", None))
        has_team = bool(values.get("team_id", None))
        if has_user and has_team:
            raise ValueError("Only `user_id` or `team_id` is allowed.")
        if not has_user and not has_team:
            raise ValueError(
                "Missing `user_id` or `team_id` for role assignment."
            )
        return values
ensure_single_entity(values)
classmethod
Validates that either `user_id` or `team_id` is set.
Source code in zenml/zen_stores/models/user_management_models.py
@root_validator
def ensure_single_entity(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Validates that exactly one of `user_id` and `team_id` is set.

    Args:
        values: The field values of the model being validated.

    Returns:
        The unchanged field values.

    Raises:
        ValueError: If both or neither of the ids are set.
    """
    has_user = bool(values.get("user_id", None))
    has_team = bool(values.get("team_id", None))
    if has_user and has_team:
        raise ValueError("Only `user_id` or `team_id` is allowed.")
    if not has_user and not has_team:
        raise ValueError(
            "Missing `user_id` or `team_id` for role assignment."
        )
    return values
Team (BaseModel)
pydantic-model
Pydantic object representing a team.
Attributes:
| Name | Type | Description |
|---|---|---|
| id | UUID | Id of the team. |
| creation_date | datetime | Date when the team was created. |
| name | str | Name of the team. |
Source code in zenml/zen_stores/models/user_management_models.py
class Team(BaseModel):
    """Pydantic object representing a team.

    Attributes:
        id: Id of the team.
        creation_date: Date when the team was created.
        name: Name of the team.
    """

    # A new random UUID is generated for each instance unless one is passed.
    id: UUID = Field(default_factory=uuid4)
    # Defaults to `datetime.now()` at instantiation time (a naive local
    # timestamp, since no timezone is passed).
    creation_date: datetime = Field(default_factory=datetime.now)
    # Required: there is no default.
    name: str
User (BaseModel)
pydantic-model
Pydantic object representing a user.
Attributes:
| Name | Type | Description |
|---|---|---|
| id | UUID | Id of the user. |
| creation_date | datetime | Date when the user was created. |
| name | str | Name of the user. |
Source code in zenml/zen_stores/models/user_management_models.py
class User(BaseModel):
    """Pydantic object representing a user.

    Attributes:
        id: Id of the user.
        creation_date: Date when the user was created.
        name: Name of the user.
    """

    # A new random UUID is generated for each instance unless one is passed.
    id: UUID = Field(default_factory=uuid4)
    # Defaults to `datetime.now()` at instantiation time (a naive local
    # timestamp, since no timezone is passed).
    creation_date: datetime = Field(default_factory=datetime.now)
    # Required: there is no default.
    name: str
    # NOTE(review): the fields below are commented out and are NOT part of
    # the model.
    # email: str
    # password: str
zen_store_model
ZenStoreModel (FileSyncModel)
pydantic-model
Pydantic object used for serializing a ZenStore.
Attributes:
| Name | Type | Description |
|---|---|---|
| version | | ZenML version number. |
| stacks | Dict[str, Dict[zenml.enums.StackComponentType, str]] | Maps stack names to a configuration object containing the names and flavors of all stack components. |
| stack_components | DefaultDict[zenml.enums.StackComponentType, Dict[str, str]] | Contains names and flavors of all registered stack components. |
| stack_component_flavors | List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] | Contains the flavor definitions of each stack component type. |
| users | List[zenml.zen_stores.models.user_management_models.User] | All registered users. |
| teams | List[zenml.zen_stores.models.user_management_models.Team] | All registered teams. |
| projects | List[zenml.zen_stores.models.user_management_models.Project] | All registered projects. |
| roles | List[zenml.zen_stores.models.user_management_models.Role] | All registered roles. |
| role_assignments | List[zenml.zen_stores.models.user_management_models.RoleAssignment] | All role assignments. |
| team_assignments | DefaultDict[str, Set[str]] | Maps team names to names of users that are part of the team. |
Source code in zenml/zen_stores/models/zen_store_model.py
class ZenStoreModel(FileSyncModel):
    """Pydantic object used for serializing a ZenStore.

    Attributes:
        version: zenml version number
        stacks: Maps stack names to a configuration object containing the
            names and flavors of all stack components.
        stack_components: Contains names and flavors of all registered stack
            components.
        stack_component_flavors: Contains the flavor definitions of each
            stack component type
        users: All registered users.
        teams: All registered teams.
        projects: All registered projects.
        roles: All registered roles.
        role_assignments: All role assignments.
        team_assignments: Maps team names to names of users that are part of
            the team.
    """

    stacks: Dict[str, Dict[StackComponentType, str]] = Field(
        default_factory=dict
    )
    stack_components: DefaultDict[StackComponentType, Dict[str, str]] = Field(
        default=defaultdict(dict)
    )
    stack_component_flavors: List[FlavorWrapper] = Field(default_factory=list)
    users: List[User] = Field(default_factory=list)
    teams: List[Team] = Field(default_factory=list)
    projects: List[Project] = Field(default_factory=list)
    roles: List[Role] = Field(default_factory=list)
    role_assignments: List[RoleAssignment] = Field(default_factory=list)
    team_assignments: DefaultDict[str, Set[str]] = Field(
        default=defaultdict(set)
    )

    @validator("stack_components")
    def _construct_stack_components_defaultdict(
        cls, stack_components: Dict[StackComponentType, Dict[str, str]]
    ) -> DefaultDict[StackComponentType, Dict[str, str]]:
        """Converts `stack_components` into a defaultdict.

        Components of a type that has no entries yet can then be added
        without first creating the inner dictionary.

        Args:
            stack_components: The parsed mapping of component type to
                component names/flavors.

        Returns:
            A `defaultdict(dict)` with the same entries.
        """
        components_by_type: DefaultDict[
            StackComponentType, Dict[str, str]
        ] = defaultdict(dict)
        components_by_type.update(stack_components)
        return components_by_type

    @validator("team_assignments")
    def _construct_team_assignments_defaultdict(
        cls, team_assignments: Dict[str, Set[str]]
    ) -> DefaultDict[str, Set[str]]:
        """Converts `team_assignments` into a defaultdict.

        Users can then be added to a team that has no members yet without
        first creating its set.

        Args:
            team_assignments: The parsed mapping of team name to user names.

        Returns:
            A `defaultdict(set)` with the same entries.
        """
        members_by_team: DefaultDict[str, Set[str]] = defaultdict(set)
        members_by_team.update(team_assignments)
        return members_by_team

    class Config:
        """Pydantic configuration class."""

        # Validate attributes when assigning them. We need to set this in order
        # to have a mix of mutable and immutable attributes
        validate_assignment = True
        # Ignore extra attributes from configs of previous ZenML versions
        extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/zen_stores/models/zen_store_model.py
class Config:
    """Pydantic configuration class.

    Enables validation on attribute assignment and silently drops unknown
    attributes when parsing.
    """

    # Validate attributes when assigning them. We need to set this in order
    # to have a mix of mutable and immutable attributes
    validate_assignment = True
    # Ignore extra attributes from configs of previous ZenML versions
    extra = "ignore"
ZenStorePipelineModel (FileSyncModel)
pydantic-model
Pydantic object used for serializing ZenStore pipelines and runs.
Attributes:
| Name | Type | Description |
|---|---|---|
| pipeline_runs | DefaultDict[str, List[zenml.zen_stores.models.pipeline_models.PipelineRunWrapper]] | Maps pipeline names to runs of that pipeline. |
Source code in zenml/zen_stores/models/zen_store_model.py
class ZenStorePipelineModel(FileSyncModel):
    """Pydantic object used for serializing ZenStore pipelines and runs.

    Attributes:
        pipeline_runs: Maps pipeline names to runs of that pipeline.
    """

    pipeline_runs: DefaultDict[str, List[PipelineRunWrapper]] = Field(
        default=defaultdict(list)
    )

    @validator("pipeline_runs")
    def _construct_pipeline_runs_defaultdict(
        cls, pipeline_runs: Dict[str, List[PipelineRunWrapper]]
    ) -> DefaultDict[str, List[PipelineRunWrapper]]:
        """Converts `pipeline_runs` into a defaultdict.

        Runs of a pipeline that has no runs yet can then be appended without
        first creating its list.

        Args:
            pipeline_runs: The parsed mapping of pipeline name to its runs.

        Returns:
            A `defaultdict(list)` with the same entries.
        """
        runs_by_pipeline: DefaultDict[
            str, List[PipelineRunWrapper]
        ] = defaultdict(list)
        runs_by_pipeline.update(pipeline_runs)
        return runs_by_pipeline

    class Config:
        """Pydantic configuration class."""

        # Validate attributes when assigning them. We need to set this in order
        # to have a mix of mutable and immutable attributes
        validate_assignment = True
        # Ignore extra attributes from configs of previous ZenML versions
        extra = "ignore"
Config
Pydantic configuration class.
Source code in zenml/zen_stores/models/zen_store_model.py
class Config:
    """Pydantic configuration class.

    Enables validation on attribute assignment and silently drops unknown
    attributes when parsing.
    """

    # Validate attributes when assigning them. We need to set this in order
    # to have a mix of mutable and immutable attributes
    validate_assignment = True
    # Ignore extra attributes from configs of previous ZenML versions
    extra = "ignore"
rest_zen_store
RestZenStore (BaseZenStore)
ZenStore implementation for accessing data from a REST api.
Source code in zenml/zen_stores/rest_zen_store.py
class RestZenStore(BaseZenStore):
"""ZenStore implementation for accessing data from a REST api."""
def initialize(
self,
url: str,
*args: Any,
**kwargs: Any,
) -> "RestZenStore":
"""Initializes a rest zen store instance.
Args:
url: Endpoint URL of the service for zen storage.
args: additional positional arguments (ignored).
kwargs: additional keyword arguments (ignored).
Returns:
The initialized zen store instance.
"""
if not self.is_valid_url(url.strip("/")):
raise ValueError("Invalid URL for REST store: {url}")
self._url = url.strip("/")
super().initialize(url, *args, **kwargs)
return self
def _migrate_store(self) -> None:
"""Migrates the store to the latest version."""
# Don't do anything here in the rest store, as the migration has to be
# done server-side.
# Static methods:
@staticmethod
def get_path_from_url(url: str) -> Optional[Path]:
"""Get the path from a URL, if it points or is backed by a local file.
Args:
url: The URL to get the path from.
Returns:
None, because there are no local paths from REST urls.
"""
return None
@staticmethod
def get_local_url(path: str) -> str:
"""Get a local URL for a given local path.
Args:
path: the path string to build a URL out of.
Returns:
Url pointing to the path for the store type.
Raises:
NotImplementedError: always
"""
raise NotImplementedError("Cannot build a REST url from a path.")
@staticmethod
def is_valid_url(url: str) -> bool:
"""Check if the given url is a valid local path."""
scheme = re.search("^([a-z0-9]+://)", url)
return (
scheme is not None
and scheme.group() in ("https://", "http://")
and url[-1] != "/"
)
# Public Interface:
@property
def type(self) -> StoreType:
"""The type of stack store."""
return StoreType.REST
@property
def url(self) -> str:
"""Get the stack store URL."""
return self._url
@property
def stacks_empty(self) -> bool:
"""Check if the store is empty (no stacks are configured).
The implementation of this method should check if the store is empty
without having to load all the stacks from the persistent storage.
"""
empty = self.get(STACKS_EMPTY)
if not isinstance(empty, bool):
raise ValueError(
f"Bad API Response. Expected boolean, got:\n{empty}"
)
return empty
def get_stack_configuration(
self, name: str
) -> Dict[StackComponentType, str]:
"""Fetches a stack configuration by name.
Args:
name: The name of the stack to fetch.
Returns:
Dict[StackComponentType, str] for the requested stack name.
Raises:
KeyError: If no stack exists for the given name.
"""
return self._parse_stack_configuration(
self.get(f"{STACK_CONFIGURATIONS}/{name}")
)
@property
def stack_configurations(self) -> Dict[str, Dict[StackComponentType, str]]:
"""Configurations for all stacks registered in this stack store.
Returns:
Dictionary mapping stack names to Dict[StackComponentType, str]'s
"""
body = self.get(STACK_CONFIGURATIONS)
if not isinstance(body, dict):
raise ValueError(
f"Bad API Response. Expected dict, got {type(body)}"
)
return {
key: self._parse_stack_configuration(value)
for key, value in body.items()
}
def _register_stack_component(
self,
component: ComponentWrapper,
) -> None:
"""Register a stack component.
Args:
component: The component to register.
Raises:
KeyError: If a stack component with the same type
and name already exists.
"""
self.post(STACK_COMPONENTS, body=component)
def _update_stack_component(
self,
name: str,
component_type: StackComponentType,
component: ComponentWrapper,
) -> Dict[str, str]:
"""Update a stack component.
Args:
name: The original name of the stack component.
component_type: The type of the stack component to update.
component: The new component to update with.
Raises:
KeyError: If no stack component exists with the given name.
"""
body = self.put(
f"{STACK_COMPONENTS}/{component_type}/{name}", body=component
)
if isinstance(body, dict):
return cast(Dict[str, str], body)
else:
raise ValueError(
f"Bad API Response. Expected dict, got {type(body)}"
)
def _deregister_stack(self, name: str) -> None:
"""Delete a stack from storage.
Args:
name: The name of the stack to be deleted.
Raises:
KeyError: If no stack exists for the given name.
"""
self.delete(f"{STACKS}/{name}")
def _save_stack(
self,
name: str,
stack_configuration: Dict[StackComponentType, str],
) -> None:
"""Add a stack to storage.
Args:
name: The name to save the stack as.
stack_configuration: Dict[StackComponentType, str] to persist.
"""
raise NotImplementedError
# Custom implementations:
@property
def stacks(self) -> List[StackWrapper]:
"""All stacks registered in this repository."""
body = self.get(STACKS)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [StackWrapper.parse_obj(s) for s in body]
def get_stack(self, name: str) -> StackWrapper:
"""Fetch a stack by name.
Args:
name: The name of the stack to retrieve.
Returns:
StackWrapper instance if the stack exists.
Raises:
KeyError: If no stack exists for the given name.
"""
return StackWrapper.parse_obj(self.get(f"{STACKS}/{name}"))
def _register_stack(self, stack: StackWrapper) -> None:
"""Register a stack and its components.
If any of the stacks' components aren't registered in the stack store
yet, this method will try to register them as well.
Args:
stack: The stack to register.
Raises:
StackExistsError: If a stack with the same name already exists.
StackComponentExistsError: If a component of the stack wasn't
registered and a different component with the same name
already exists.
"""
self.post(STACKS, stack)
def _update_stack(self, name: str, stack: StackWrapper) -> None:
"""Update a stack and its components.
If any of the stack's components aren't registered in the stack store
yet, this method will try to register them as well.
Args:
name: The original name of the stack.
stack: The new stack to use in the update.
"""
self.put(f"{STACKS}/{name}", body=stack)
if name != stack.name:
self.deregister_stack(name)
def get_stack_component(
self, component_type: StackComponentType, name: str
) -> ComponentWrapper:
"""Get a registered stack component.
Raises:
KeyError: If no component with the requested type and name exists.
"""
return ComponentWrapper.parse_obj(
self.get(f"{STACK_COMPONENTS}/{component_type}/{name}")
)
def get_stack_components(
self, component_type: StackComponentType
) -> List[ComponentWrapper]:
"""Fetches all registered stack components of the given type.
Args:
component_type: StackComponentType to list members of
Returns:
A list of StackComponentConfiguration instances.
"""
body = self.get(f"{STACK_COMPONENTS}/{component_type}")
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [ComponentWrapper.parse_obj(c) for c in body]
def deregister_stack_component(
self, component_type: StackComponentType, name: str
) -> None:
"""Deregisters a stack component.
Args:
component_type: The type of the component to deregister.
name: The name of the component to deregister.
Raises:
ValueError: if trying to deregister a component that's part
of a stack.
"""
self.delete(f"{STACK_COMPONENTS}/{component_type}/{name}")
# User, project and role management
@property
def users(self) -> List[User]:
"""All registered users.
Returns:
A list of all registered users.
Raises:
ValueError: In case of a bad API response.
"""
body = self.get(USERS)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [User.parse_obj(user_dict) for user_dict in body]
def _get_user(self, user_name: str) -> User:
"""Get a specific user by name.
Args:
user_name: Name of the user to get.
Returns:
The requested user, if it was found.
Raises:
KeyError: If no user with the given name exists.
"""
return User.parse_obj(self.get(f"{USERS}/{user_name}"))
def _create_user(self, user_name: str) -> User:
"""Creates a new user.
Args:
user_name: Unique username.
Returns:
The newly created user.
Raises:
EntityExistsError: If a user with the given name already exists.
"""
user = User(name=user_name)
return User.parse_obj(self.post(USERS, body=user))
def _delete_user(self, user_name: str) -> None:
"""Deletes a user.
Args:
user_name: Name of the user to delete.
Raises:
KeyError: If no user with the given name exists.
"""
self.delete(f"{USERS}/{user_name}")
@property
def teams(self) -> List[Team]:
"""All registered teams.
Returns:
A list of all registered teams.
Raises:
ValueError: In case of a bad API response.
"""
body = self.get(TEAMS)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [Team.parse_obj(team_dict) for team_dict in body]
def _get_team(self, team_name: str) -> Team:
"""Gets a specific team.
Args:
team_name: Name of the team to get.
Returns:
The requested team.
Raises:
KeyError: If no team with the given name exists.
"""
return Team.parse_obj(self.get(f"{TEAMS}/{team_name}"))
def _create_team(self, team_name: str) -> Team:
"""Creates a new team.
Args:
team_name: Unique team name.
Returns:
The newly created team.
Raises:
EntityExistsError: If a team with the given name already exists.
"""
team = Team(name=team_name)
return Team.parse_obj(self.post(TEAMS, body=team))
def _delete_team(self, team_name: str) -> None:
"""Deletes a team.
Args:
team_name: Name of the team to delete.
Raises:
KeyError: If no team with the given name exists.
"""
self.delete(f"{TEAMS}/{team_name}")
def add_user_to_team(self, team_name: str, user_name: str) -> None:
"""Adds a user to a team.
Args:
team_name: Name of the team.
user_name: Name of the user.
Raises:
KeyError: If no user and team with the given names exists.
"""
user = User(name=user_name)
self.post(f"{TEAMS}/{team_name}/users", user)
def remove_user_from_team(self, team_name: str, user_name: str) -> None:
"""Removes a user from a team.
Args:
team_name: Name of the team.
user_name: Name of the user.
Raises:
KeyError: If no user and team with the given names exists.
"""
self.delete(f"{TEAMS}/{team_name}/users/{user_name}")
@property
def projects(self) -> List[Project]:
"""All registered projects.
Returns:
A list of all registered projects.
Raises:
ValueError: In case of a bad API response.
"""
body = self.get(PROJECTS)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [Project.parse_obj(project_dict) for project_dict in body]
def _get_project(self, project_name: str) -> Project:
"""Get an existing project by name.
Args:
project_name: Name of the project to get.
Returns:
The requested project if one was found.
Raises:
KeyError: If there is no such project.
"""
return Project.parse_obj(self.get(f"{PROJECTS}/{project_name}"))
def _create_project(
self, project_name: str, description: Optional[str] = None
) -> Project:
"""Creates a new project.
Args:
project_name: Unique project name.
description: Optional project description.
Returns:
The newly created project.
Raises:
EntityExistsError: If a project with the given name already exists.
"""
project = Project(name=project_name, description=description)
return Project.parse_obj(self.post(PROJECTS, body=project))
def _delete_project(self, project_name: str) -> None:
"""Deletes a project.
Args:
project_name: Name of the project to delete.
Raises:
KeyError: If no project with the given name exists.
"""
self.delete(f"{PROJECTS}/{project_name}")
@property
def roles(self) -> List[Role]:
"""All registered roles.
Returns:
A list of all registered roles.
Raises:
ValueError: In case of a bad API response.
"""
body = self.get(ROLES)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [Role.parse_obj(role_dict) for role_dict in body]
@property
def role_assignments(self) -> List[RoleAssignment]:
"""All registered role assignments.
Returns:
A list of all registered role assignments.
Raises:
ValueError: In case of a bad API response.
"""
body = self.get(ROLE_ASSIGNMENTS)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [
RoleAssignment.parse_obj(assignment_dict)
for assignment_dict in body
]
def _get_role(self, role_name: str) -> Role:
"""Gets a specific role.
Args:
role_name: Name of the role to get.
Returns:
The requested role.
Raises:
KeyError: If no role with the given name exists.
"""
return Role.parse_obj(self.get(f"{ROLES}/{role_name}"))
def _create_role(self, role_name: str) -> Role:
"""Creates a new role.
Args:
role_name: Unique role name.
Returns:
The newly created role.
Raises:
EntityExistsError: If a role with the given name already exists.
"""
role = Role(name=role_name)
return Role.parse_obj(self.post(ROLES, body=role))
def _delete_role(self, role_name: str) -> None:
"""Deletes a role.
Args:
role_name: Name of the role to delete.
Raises:
KeyError: If no role with the given name exists.
"""
self.delete(f"{ROLES}/{role_name}")
def assign_role(
self,
role_name: str,
entity_name: str,
project_name: Optional[str] = None,
is_user: bool = True,
) -> None:
"""Assigns a role to a user or team.
Args:
role_name: Name of the role to assign.
entity_name: User or team name.
project_name: Optional project name.
is_user: Boolean indicating whether the given `entity_name` refers
to a user.
Raises:
KeyError: If no role, entity or project with the given names exists.
"""
data = {
"role_name": role_name,
"entity_name": entity_name,
"project_name": project_name,
"is_user": is_user,
}
self._handle_response(
requests.post(
self.url + ROLE_ASSIGNMENTS,
json=data,
auth=self._get_authentication(),
)
)
def revoke_role(
self,
role_name: str,
entity_name: str,
project_name: Optional[str] = None,
is_user: bool = True,
) -> None:
"""Revokes a role from a user or team.
Args:
role_name: Name of the role to revoke.
entity_name: User or team name.
project_name: Optional project name.
is_user: Boolean indicating whether the given `entity_name` refers
to a user.
Raises:
KeyError: If no role, entity or project with the given names exists.
"""
data = {
"role_name": role_name,
"entity_name": entity_name,
"project_name": project_name,
"is_user": is_user,
}
self._handle_response(
requests.delete(
self.url + ROLE_ASSIGNMENTS,
json=data,
auth=self._get_authentication(),
)
)
def get_users_for_team(self, team_name: str) -> List[User]:
"""Fetches all users of a team.
Args:
team_name: Name of the team.
Returns:
List of users that are part of the team.
Raises:
KeyError: If no team with the given name exists.
ValueError: In case of a bad API response.
"""
body = self.get(f"{TEAMS}/{team_name}/users")
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [User.parse_obj(user_dict) for user_dict in body]
def get_teams_for_user(self, user_name: str) -> List[Team]:
"""Fetches all teams for a user.
Args:
user_name: Name of the user.
Returns:
List of teams that the user is part of.
Raises:
KeyError: If no user with the given name exists.
ValueError: In case of a bad API response.
"""
body = self.get(f"{USERS}/{user_name}/teams")
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [Team.parse_obj(team_dict) for team_dict in body]
def get_role_assignments_for_user(
self,
user_name: str,
project_name: Optional[str] = None,
include_team_roles: bool = True,
) -> List[RoleAssignment]:
"""Fetches all role assignments for a user.
Args:
user_name: Name of the user.
project_name: Optional filter to only return roles assigned for
this project.
include_team_roles: If `True`, includes roles for all teams that
the user is part of.
Returns:
List of role assignments for this user.
Raises:
KeyError: If no user or project with the given names exists.
ValueError: In case of a bad API response.
"""
path = f"{USERS}/{user_name}/role_assignments"
if project_name:
path += f"?project_name={project_name}"
body = self.get(path)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
assignments = [
RoleAssignment.parse_obj(assignment_dict)
for assignment_dict in body
]
if include_team_roles:
for team in self.get_teams_for_user(user_name):
assignments += self.get_role_assignments_for_team(
team.name, project_name=project_name
)
return assignments
def get_role_assignments_for_team(
self,
team_name: str,
project_name: Optional[str] = None,
) -> List[RoleAssignment]:
"""Fetches all role assignments for a team.
Args:
team_name: Name of the user.
project_name: Optional filter to only return roles assigned for
this project.
Returns:
List of role assignments for this team.
Raises:
KeyError: If no user or project with the given names exists.
ValueError: In case of a bad API response.
"""
path = f"{TEAMS}/{team_name}/role_assignments"
if project_name:
path += f"?project_name={project_name}"
body = self.get(path)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [
RoleAssignment.parse_obj(assignment_dict)
for assignment_dict in body
]
# Pipelines and pipeline runs
def get_pipeline_run(
self,
pipeline_name: str,
run_name: str,
project_name: Optional[str] = None,
) -> PipelineRunWrapper:
"""Gets a pipeline run.
Args:
pipeline_name: Name of the pipeline for which to get the run.
run_name: Name of the pipeline run to get.
project_name: Optional name of the project from which to get the
pipeline run.
Raises:
KeyError: If no pipeline run (or project) with the given name
exists.
"""
path = f"{PIPELINE_RUNS}/{pipeline_name}/{run_name}"
if project_name:
path += f"?project_name={project_name}"
body = self.get(path)
return PipelineRunWrapper.parse_obj(body)
def get_pipeline_runs(
self, pipeline_name: str, project_name: Optional[str] = None
) -> List[PipelineRunWrapper]:
"""Gets pipeline runs.
Args:
pipeline_name: Name of the pipeline for which to get runs.
project_name: Optional name of the project from which to get the
pipeline runs.
"""
path = f"{PIPELINE_RUNS}/{pipeline_name}"
if project_name:
path += f"?project_name={project_name}"
body = self.get(path)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [PipelineRunWrapper.parse_obj(dict_) for dict_ in body]
def register_pipeline_run(
self,
pipeline_run: PipelineRunWrapper,
) -> None:
"""Registers a pipeline run.
Args:
pipeline_run: The pipeline run to register.
Raises:
EntityExistsError: If a pipeline run with the same name already
exists.
"""
self.post(PIPELINE_RUNS, body=pipeline_run)
# Private interface shall not be implemented for REST store, instead the
# API only provides all public methods, including the ones that would
# otherwise be inherited from the BaseZenStore in other implementations.
# Don't call these! ABC complains that they aren't implemented, but they
# aren't needed with the custom implementations of base methods.
def _create_stack(
self, name: str, stack_configuration: Dict[StackComponentType, str]
) -> None:
"""Add a stack to storage"""
raise NotImplementedError("Not to be accessed directly in client!")
def _get_component_flavor_and_config(
self, component_type: StackComponentType, name: str
) -> Tuple[str, bytes]:
"""Fetch the flavor and configuration for a stack component."""
raise NotImplementedError("Not to be accessed directly in client!")
def _get_stack_component_names(
self, component_type: StackComponentType
) -> List[str]:
"""Get names of all registered stack components of a given type."""
raise NotImplementedError("Not to be accessed directly in client!")
def _delete_stack_component(
self, component_type: StackComponentType, name: str
) -> None:
"""Remove a StackComponent from storage."""
raise NotImplementedError("Not to be accessed directly in client!")
# Handling stack component flavors
@property
def flavors(self) -> List[FlavorWrapper]:
"""All registered flavors.
Returns:
A list of all registered flavors.
"""
body = self.get(FLAVORS)
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [FlavorWrapper.parse_obj(flavor_dict) for flavor_dict in body]
def _create_flavor(
self,
source: str,
name: str,
stack_component_type: StackComponentType,
) -> FlavorWrapper:
"""Creates a new flavor.
Args:
source: the source path to the implemented flavor.
name: the name of the flavor.
stack_component_type: the corresponding StackComponentType.
Returns:
The newly created flavor.
Raises:
EntityExistsError: If a flavor with the given name and type
already exists.
"""
flavor = FlavorWrapper(
name=name,
source=source,
type=stack_component_type,
)
return FlavorWrapper.parse_obj(self.post(FLAVORS, body=flavor))
def get_flavors_by_type(
self, component_type: StackComponentType
) -> List[FlavorWrapper]:
"""Fetch all flavor defined for a specific stack component type.
Args:
component_type: The type of the stack component.
Returns:
List of all the flavors for the given stack component type.
"""
body = self.get(f"{FLAVORS}/{component_type}")
if not isinstance(body, list):
raise ValueError(
f"Bad API Response. Expected list, got {type(body)}"
)
return [FlavorWrapper.parse_obj(flavor_dict) for flavor_dict in body]
def get_flavor_by_name_and_type(
self,
flavor_name: str,
component_type: StackComponentType,
) -> FlavorWrapper:
"""Fetch a flavor by a given name and type.
Args:
flavor_name: The name of the flavor.
component_type: Optional, the type of the component.
Returns:
Flavor instance if it exists
Raises:
KeyError: If no flavor exists with the given name and type
or there are more than one instances
"""
return FlavorWrapper.parse_obj(
self.get(f"{FLAVORS}/{component_type}/{flavor_name}")
)
# Implementation specific methods:
def _parse_stack_configuration(
self, to_parse: Json
) -> Dict[StackComponentType, str]:
"""Parse an API response into `Dict[StackComponentType, str]`."""
if not isinstance(to_parse, dict):
raise ValueError(
f"Bad API Response. Expected dict, got {type(to_parse)}."
)
return {
StackComponentType(typ): component_name
for typ, component_name in to_parse.items()
}
def _handle_response(self, response: requests.Response) -> Json:
"""Handle API response, translating http status codes to Exception."""
if response.status_code >= 200 and response.status_code < 300:
try:
payload: Json = response.json()
return payload
except requests.exceptions.JSONDecodeError:
raise ValueError(
"Bad response from API. Expected json, got\n"
f"{response.text}"
)
elif response.status_code == 401:
raise requests.HTTPError(
f"{response.status_code} Client Error: Unauthorized request to URL {response.url}: {response.json().get('detail')}"
)
elif response.status_code == 404:
if "DoesNotExistException" not in response.text:
raise KeyError(*response.json().get("detail", (response.text,)))
message = ": ".join(response.json().get("detail", (response.text,)))
raise DoesNotExistException(message)
elif response.status_code == 409:
if "StackComponentExistsError" in response.text:
raise StackComponentExistsError(
*response.json().get("detail", (response.text,))
)
elif "StackExistsError" in response.text:
raise StackExistsError(
*response.json().get("detail", (response.text,))
)
elif "EntityExistsError" in response.text:
raise EntityExistsError(
*response.json().get("detail", (response.text,))
)
else:
raise ValueError(
*response.json().get("detail", (response.text,))
)
elif response.status_code == 422:
raise RuntimeError(*response.json().get("detail", (response.text,)))
elif response.status_code == 500:
raise KeyError(response.text)
else:
raise RuntimeError(
"Error retrieving from API. Got response "
f"{response.status_code} with body:\n{response.text}"
)
@staticmethod
def _get_authentication() -> Tuple[str, str]:
"""Gets HTTP basic auth credentials."""
from zenml.repository import Repository
return Repository().active_user_name, ""
def get(self, path: str) -> Json:
"""Make a GET request to the given endpoint path."""
return self._handle_response(
requests.get(self.url + path, auth=self._get_authentication())
)
def delete(self, path: str) -> Json:
"""Make a DELETE request to the given endpoint path."""
return self._handle_response(
requests.delete(self.url + path, auth=self._get_authentication())
)
def post(self, path: str, body: BaseModel) -> Json:
"""Make a POST request to the given endpoint path."""
endpoint = self.url + path
return self._handle_response(
requests.post(
endpoint, data=body.json(), auth=self._get_authentication()
)
)
def put(self, path: str, body: BaseModel) -> Json:
    """Send ``body`` as JSON in a PUT request to ``self.url + path``.

    Args:
        path: Endpoint path appended to the store URL.
        body: Pydantic model serialized with ``body.json()``.

    Returns:
        Whatever ``self._handle_response`` extracts from the response.
    """
    target = self.url + path
    response = requests.put(
        target,
        data=body.json(),
        auth=self._get_authentication(),
    )
    return self._handle_response(response)
flavors: List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]
property
readonly
All registered flavors.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
A list of all registered flavors. |
projects: List[zenml.zen_stores.models.user_management_models.Project]
property
readonly
All registered projects.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Project] |
A list of all registered projects. |
Exceptions:
Type | Description |
---|---|
ValueError |
In case of a bad API response. |
role_assignments: List[zenml.zen_stores.models.user_management_models.RoleAssignment]
property
readonly
All registered role assignments.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
A list of all registered role assignments. |
Exceptions:
Type | Description |
---|---|
ValueError |
In case of a bad API response. |
roles: List[zenml.zen_stores.models.user_management_models.Role]
property
readonly
All registered roles.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Role] |
A list of all registered roles. |
Exceptions:
Type | Description |
---|---|
ValueError |
In case of a bad API response. |
stack_configurations: Dict[str, Dict[zenml.enums.StackComponentType, str]]
property
readonly
Configurations for all stacks registered in this stack store.
Returns:
Type | Description |
---|---|
Dict[str, Dict[zenml.enums.StackComponentType, str]] |
Dictionary mapping stack names to Dict[StackComponentType, str]'s |
stacks: List[zenml.zen_stores.models.stack_wrapper.StackWrapper]
property
readonly
All stacks registered in this repository.
stacks_empty: bool
property
readonly
Check if the store is empty (no stacks are configured).
The implementation of this method should check if the store is empty without having to load all the stacks from the persistent storage.
teams: List[zenml.zen_stores.models.user_management_models.Team]
property
readonly
All registered teams.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Team] |
A list of all registered teams. |
Exceptions:
Type | Description |
---|---|
ValueError |
In case of a bad API response. |
type: StoreType
property
readonly
The type of stack store.
url: str
property
readonly
Get the stack store URL.
users: List[zenml.zen_stores.models.user_management_models.User]
property
readonly
All registered users.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.User] |
A list of all registered users. |
Exceptions:
Type | Description |
---|---|
ValueError |
In case of a bad API response. |
add_user_to_team(self, team_name, user_name)
Adds a user to a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
user_name |
str |
Name of the user. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If the given user or team does not exist. |
Source code in zenml/zen_stores/rest_zen_store.py
def add_user_to_team(self, team_name: str, user_name: str) -> None:
    """Add a user to a team.

    POSTs a ``User`` payload carrying ``user_name`` to the team's users
    endpoint.

    Args:
        team_name: Name of the team.
        user_name: Name of the user.

    Raises:
        KeyError: If the user or the team does not exist.
    """
    self.post(f"{TEAMS}/{team_name}/users", User(name=user_name))
assign_role(self, role_name, entity_name, project_name=None, is_user=True)
Assigns a role to a user or team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to assign. |
required |
entity_name |
str |
User or team name. |
required |
project_name |
Optional[str] |
Optional project name. |
None |
is_user |
bool |
Boolean indicating whether the given `entity_name` refers to a user. |
True |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role, entity or project with the given names exists. |
Source code in zenml/zen_stores/rest_zen_store.py
def assign_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Assign a role to a user or a team.

    Args:
        role_name: Name of the role to assign.
        entity_name: Name of the user or team receiving the role.
        project_name: Optional project name the assignment is scoped to.
        is_user: Whether ``entity_name`` refers to a user (``True``) or a
            team (``False``).

    Raises:
        KeyError: If no role, entity or project with the given names exists.
    """
    payload = dict(
        role_name=role_name,
        entity_name=entity_name,
        project_name=project_name,
        is_user=is_user,
    )
    response = requests.post(
        self.url + ROLE_ASSIGNMENTS,
        json=payload,
        auth=self._get_authentication(),
    )
    self._handle_response(response)
delete(self, path)
Make a DELETE request to the given endpoint path.
Source code in zenml/zen_stores/rest_zen_store.py
def delete(self, path: str) -> Json:
    """Issue a DELETE for ``path`` relative to the store URL.

    Args:
        path: Endpoint path appended to ``self.url``.

    Returns:
        The result of ``self._handle_response`` for the HTTP response.
    """
    credentials = self._get_authentication()
    response = requests.delete(self.url + path, auth=credentials)
    return self._handle_response(response)
deregister_stack_component(self, component_type, name)
Deregisters a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
The type of the component to deregister. |
required |
name |
str |
The name of the component to deregister. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
if trying to deregister a component that's part of a stack. |
Source code in zenml/zen_stores/rest_zen_store.py
def deregister_stack_component(
    self, component_type: StackComponentType, name: str
) -> None:
    """Remove a stack component from the store.

    Args:
        component_type: Type of the component to remove.
        name: Name of the component to remove.

    Raises:
        ValueError: If the component is still part of a stack.
    """
    component_path = f"{STACK_COMPONENTS}/{component_type}/{name}"
    self.delete(component_path)
get(self, path)
Make a GET request to the given endpoint path.
Source code in zenml/zen_stores/rest_zen_store.py
def get(self, path: str) -> Json:
    """Issue a GET for ``path`` relative to the store URL.

    Args:
        path: Endpoint path appended to ``self.url``.

    Returns:
        The result of ``self._handle_response`` for the HTTP response.
    """
    credentials = self._get_authentication()
    response = requests.get(self.url + path, auth=credentials)
    return self._handle_response(response)
get_flavor_by_name_and_type(self, flavor_name, component_type)
Fetch a flavor by a given name and type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_name |
str |
The name of the flavor. |
required |
component_type |
StackComponentType |
The type of the component. |
required |
Returns:
Type | Description |
---|---|
FlavorWrapper |
Flavor instance if it exists |
Exceptions:
Type | Description |
---|---|
KeyError |
If no flavor exists with the given name and type, or if more than one matching flavor exists. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_flavor_by_name_and_type(
    self,
    flavor_name: str,
    component_type: StackComponentType,
) -> FlavorWrapper:
    """Fetch the flavor registered under a given name and component type.

    Args:
        flavor_name: The name of the flavor.
        component_type: The type of stack component the flavor belongs to.

    Returns:
        The flavor parsed from the API response.

    Raises:
        KeyError: If no flavor exists with the given name and type, or if
            more than one matching flavor exists.
    """
    flavor_payload = self.get(f"{FLAVORS}/{component_type}/{flavor_name}")
    return FlavorWrapper.parse_obj(flavor_payload)
get_flavors_by_type(self, component_type)
Fetch all flavor defined for a specific stack component type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
The type of the stack component. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of all the flavors for the given stack component type. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_flavors_by_type(
    self, component_type: StackComponentType
) -> List[FlavorWrapper]:
    """Fetch every flavor defined for a stack component type.

    Args:
        component_type: The type of the stack component.

    Returns:
        All flavors registered for ``component_type``.

    Raises:
        ValueError: If the API does not respond with a list.
    """
    flavors = self.get(f"{FLAVORS}/{component_type}")
    if isinstance(flavors, list):
        return [FlavorWrapper.parse_obj(entry) for entry in flavors]
    raise ValueError(
        f"Bad API Response. Expected list, got {type(flavors)}"
    )
get_local_url(path)
staticmethod
Get a local URL for a given local path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
the path string to build a URL out of. |
required |
Returns:
Type | Description |
---|---|
str |
Url pointing to the path for the store type. |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
always |
Source code in zenml/zen_stores/rest_zen_store.py
@staticmethod
def get_local_url(path: str) -> str:
    """Refuse to build a REST URL from a local path.

    A REST store is always remote, so there is no local URL to build.

    Args:
        path: The path string (unused).

    Returns:
        Never returns.

    Raises:
        NotImplementedError: Always.
    """
    message = "Cannot build a REST url from a path."
    raise NotImplementedError(message)
get_path_from_url(url)
staticmethod
Get the path from a URL, if it points or is backed by a local file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
The URL to get the path from. |
required |
Returns:
Type | Description |
---|---|
Optional[pathlib.Path] |
None, because there are no local paths from REST urls. |
Source code in zenml/zen_stores/rest_zen_store.py
@staticmethod
def get_path_from_url(url: str) -> Optional[Path]:
    """Return the local path backing ``url``; REST URLs never have one.

    Args:
        url: The URL to get the path from (unused).

    Returns:
        Always ``None``.
    """
    del url  # REST URLs are never backed by a local file.
    return None
get_pipeline_run(self, pipeline_name, run_name, project_name=None)
Gets a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to get the run. |
required |
run_name |
str |
Name of the pipeline run to get. |
required |
project_name |
Optional[str] |
Optional name of the project from which to get the pipeline run. |
None |
Exceptions:
Type | Description |
---|---|
KeyError |
If no pipeline run (or project) with the given name exists. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_pipeline_run(
    self,
    pipeline_name: str,
    run_name: str,
    project_name: Optional[str] = None,
) -> PipelineRunWrapper:
    """Gets a pipeline run.

    Args:
        pipeline_name: Name of the pipeline for which to get the run.
        run_name: Name of the pipeline run to get.
        project_name: Optional name of the project from which to get the
            pipeline run.

    Returns:
        The pipeline run parsed from the API response.

    Raises:
        KeyError: If no pipeline run (or project) with the given name
            exists.
    """
    from urllib.parse import urlencode

    path = f"{PIPELINE_RUNS}/{pipeline_name}/{run_name}"
    if project_name:
        # Percent-encode the value so names containing reserved characters
        # such as `&`, `=` or `#` do not corrupt the query string.
        path += "?" + urlencode({"project_name": project_name})
    body = self.get(path)
    return PipelineRunWrapper.parse_obj(body)
get_pipeline_runs(self, pipeline_name, project_name=None)
Gets pipeline runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to get runs. |
required |
project_name |
Optional[str] |
Optional name of the project from which to get the pipeline runs. |
None |
Source code in zenml/zen_stores/rest_zen_store.py
def get_pipeline_runs(
    self, pipeline_name: str, project_name: Optional[str] = None
) -> List[PipelineRunWrapper]:
    """Gets pipeline runs.

    Args:
        pipeline_name: Name of the pipeline for which to get runs.
        project_name: Optional name of the project from which to get the
            pipeline runs.

    Returns:
        The pipeline runs parsed from the API response.

    Raises:
        ValueError: If the API does not respond with a list.
    """
    from urllib.parse import urlencode

    path = f"{PIPELINE_RUNS}/{pipeline_name}"
    if project_name:
        # Percent-encode the value so names containing reserved characters
        # such as `&`, `=` or `#` do not corrupt the query string.
        path += "?" + urlencode({"project_name": project_name})
    body = self.get(path)
    if not isinstance(body, list):
        raise ValueError(
            f"Bad API Response. Expected list, got {type(body)}"
        )
    return [PipelineRunWrapper.parse_obj(dict_) for dict_ in body]
get_role_assignments_for_team(self, team_name, project_name=None)
Fetches all role assignments for a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
project_name |
Optional[str] |
Optional filter to only return roles assigned for this project. |
None |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
List of role assignments for this team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team or project with the given names exists. |
ValueError |
In case of a bad API response. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_role_assignments_for_team(
    self,
    team_name: str,
    project_name: Optional[str] = None,
) -> List[RoleAssignment]:
    """Fetches all role assignments for a team.

    Args:
        team_name: Name of the team.
        project_name: Optional filter to only return roles assigned for
            this project.

    Returns:
        List of role assignments for this team.

    Raises:
        KeyError: If no team or project with the given names exists.
        ValueError: In case of a bad API response.
    """
    from urllib.parse import urlencode

    path = f"{TEAMS}/{team_name}/role_assignments"
    if project_name:
        # Percent-encode the value so names containing reserved characters
        # such as `&`, `=` or `#` do not corrupt the query string.
        path += "?" + urlencode({"project_name": project_name})
    body = self.get(path)
    if not isinstance(body, list):
        raise ValueError(
            f"Bad API Response. Expected list, got {type(body)}"
        )
    return [
        RoleAssignment.parse_obj(assignment_dict)
        for assignment_dict in body
    ]
get_role_assignments_for_user(self, user_name, project_name=None, include_team_roles=True)
Fetches all role assignments for a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user. |
required |
project_name |
Optional[str] |
Optional filter to only return roles assigned for this project. |
None |
include_team_roles |
bool |
If `True`, includes roles for all teams that the user is part of. |
True |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
List of role assignments for this user. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user or project with the given names exists. |
ValueError |
In case of a bad API response. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_role_assignments_for_user(
    self,
    user_name: str,
    project_name: Optional[str] = None,
    include_team_roles: bool = True,
) -> List[RoleAssignment]:
    """Fetches all role assignments for a user.

    Args:
        user_name: Name of the user.
        project_name: Optional filter to only return roles assigned for
            this project.
        include_team_roles: If `True`, includes roles for all teams that
            the user is part of.

    Returns:
        List of role assignments for this user (followed by those of the
        user's teams when ``include_team_roles`` is set).

    Raises:
        KeyError: If no user or project with the given names exists.
        ValueError: In case of a bad API response.
    """
    from urllib.parse import urlencode

    path = f"{USERS}/{user_name}/role_assignments"
    if project_name:
        # Percent-encode the value so names containing reserved characters
        # such as `&`, `=` or `#` do not corrupt the query string.
        path += "?" + urlencode({"project_name": project_name})
    body = self.get(path)
    if not isinstance(body, list):
        raise ValueError(
            f"Bad API Response. Expected list, got {type(body)}"
        )
    assignments = [
        RoleAssignment.parse_obj(assignment_dict)
        for assignment_dict in body
    ]
    if include_team_roles:
        # Team assignments use the same project filter as the user's own.
        for team in self.get_teams_for_user(user_name):
            assignments += self.get_role_assignments_for_team(
                team.name, project_name=project_name
            )
    return assignments
get_stack(self, name)
Fetch a stack by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the stack to retrieve. |
required |
Returns:
Type | Description |
---|---|
StackWrapper |
StackWrapper instance if the stack exists. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no stack exists for the given name. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_stack(self, name: str) -> StackWrapper:
    """Look up a single stack by its name.

    Args:
        name: Name of the stack to retrieve.

    Returns:
        The stack parsed from the API response.

    Raises:
        KeyError: If no stack with this name exists.
    """
    stack_payload = self.get(f"{STACKS}/{name}")
    return StackWrapper.parse_obj(stack_payload)
get_stack_component(self, component_type, name)
Get a registered stack component.
Exceptions:
Type | Description |
---|---|
KeyError |
If no component with the requested type and name exists. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_stack_component(
    self, component_type: StackComponentType, name: str
) -> ComponentWrapper:
    """Get a registered stack component.

    Args:
        component_type: Type of the component to fetch.
        name: Name of the component to fetch.

    Returns:
        The component parsed from the API response.

    Raises:
        KeyError: If no component with the requested type and name exists.
    """
    component_payload = self.get(
        f"{STACK_COMPONENTS}/{component_type}/{name}"
    )
    return ComponentWrapper.parse_obj(component_payload)
get_stack_components(self, component_type)
Fetches all registered stack components of the given type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
StackComponentType to list members of |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.component_wrapper.ComponentWrapper] |
A list of ComponentWrapper instances. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_stack_components(
    self, component_type: StackComponentType
) -> List[ComponentWrapper]:
    """List every registered stack component of one type.

    Args:
        component_type: StackComponentType to list members of.

    Returns:
        The components parsed from the API response.

    Raises:
        ValueError: If the API does not respond with a list.
    """
    components = self.get(f"{STACK_COMPONENTS}/{component_type}")
    if isinstance(components, list):
        return list(map(ComponentWrapper.parse_obj, components))
    raise ValueError(
        f"Bad API Response. Expected list, got {type(components)}"
    )
get_stack_configuration(self, name)
Fetches a stack configuration by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the stack to fetch. |
required |
Returns:
Type | Description |
---|---|
Dict[zenml.enums.StackComponentType, str] |
Dict[StackComponentType, str] for the requested stack name. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no stack exists for the given name. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_stack_configuration(
    self, name: str
) -> Dict[StackComponentType, str]:
    """Fetch the component configuration of a stack by name.

    Args:
        name: Name of the stack to fetch.

    Returns:
        Mapping from component type to component name for that stack.

    Raises:
        KeyError: If no stack with this name exists.
    """
    raw_configuration = self.get(f"{STACK_CONFIGURATIONS}/{name}")
    return self._parse_stack_configuration(raw_configuration)
get_teams_for_user(self, user_name)
Fetches all teams for a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Team] |
List of teams that the user is part of. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user with the given name exists. |
ValueError |
In case of a bad API response. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_teams_for_user(self, user_name: str) -> List[Team]:
    """List the teams a user belongs to.

    Args:
        user_name: Name of the user.

    Returns:
        Every team the user is a member of.

    Raises:
        KeyError: If no user with the given name exists.
        ValueError: In case of a bad API response.
    """
    teams_payload = self.get(f"{USERS}/{user_name}/teams")
    if isinstance(teams_payload, list):
        return list(map(Team.parse_obj, teams_payload))
    raise ValueError(
        f"Bad API Response. Expected list, got {type(teams_payload)}"
    )
get_users_for_team(self, team_name)
Fetches all users of a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.User] |
List of users that are part of the team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team with the given name exists. |
ValueError |
In case of a bad API response. |
Source code in zenml/zen_stores/rest_zen_store.py
def get_users_for_team(self, team_name: str) -> List[User]:
    """List the members of a team.

    Args:
        team_name: Name of the team.

    Returns:
        Every user that is part of the team.

    Raises:
        KeyError: If no team with the given name exists.
        ValueError: In case of a bad API response.
    """
    members_payload = self.get(f"{TEAMS}/{team_name}/users")
    if isinstance(members_payload, list):
        return list(map(User.parse_obj, members_payload))
    raise ValueError(
        f"Bad API Response. Expected list, got {type(members_payload)}"
    )
initialize(self, url, *args, **kwargs)
Initializes a rest zen store instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
Endpoint URL of the service for zen storage. |
required |
args |
Any |
additional positional arguments (ignored). |
() |
kwargs |
Any |
additional keyword arguments (ignored). |
{} |
Returns:
Type | Description |
---|---|
RestZenStore |
The initialized zen store instance. |
Source code in zenml/zen_stores/rest_zen_store.py
def initialize(
    self,
    url: str,
    *args: Any,
    **kwargs: Any,
) -> "RestZenStore":
    """Initializes a rest zen store instance.

    Leading and trailing slashes are stripped from ``url`` before it is
    validated and stored.

    Args:
        url: Endpoint URL of the service for zen storage.
        args: Additional positional arguments, forwarded to the base store
            initialization.
        kwargs: Additional keyword arguments, forwarded to the base store
            initialization.

    Returns:
        The initialized zen store instance.

    Raises:
        ValueError: If ``url`` is not a valid http(s) URL.
    """
    stripped_url = url.strip("/")
    if not self.is_valid_url(stripped_url):
        # Was missing the `f` prefix, so the message showed a literal "{url}".
        raise ValueError(f"Invalid URL for REST store: {url}")
    self._url = stripped_url
    super().initialize(url, *args, **kwargs)
    return self
is_valid_url(url)
staticmethod
Check if the given URL is a valid REST store URL (starts with http:// or https:// and has no trailing slash).
Source code in zenml/zen_stores/rest_zen_store.py
@staticmethod
def is_valid_url(url: str) -> bool:
    """Check if the given url is a valid REST store URL.

    A valid URL starts with the lowercase scheme ``http://`` or
    ``https://`` and does not end with a trailing slash.

    Args:
        url: The URL to check.

    Returns:
        ``True`` if the URL is usable for a REST store, ``False`` otherwise.
    """
    # Equivalent to matching `^[a-z0-9]+://` and checking the scheme is
    # http/https, but without a regex.
    return url.startswith(("https://", "http://")) and not url.endswith("/")
post(self, path, body)
Make a POST request to the given endpoint path.
Source code in zenml/zen_stores/rest_zen_store.py
def post(self, path: str, body: BaseModel) -> Json:
    """POST the JSON serialization of ``body`` to ``path`` on the store URL.

    Args:
        path: Endpoint path appended to ``self.url``.
        body: Pydantic model whose ``json()`` output is sent as the body.

    Returns:
        The result of ``self._handle_response`` for the HTTP response.
    """
    target = self.url + path
    serialized = body.json()
    credentials = self._get_authentication()
    response = requests.post(target, data=serialized, auth=credentials)
    return self._handle_response(response)
put(self, path, body)
Make a PUT request to the given endpoint path.
Source code in zenml/zen_stores/rest_zen_store.py
def put(self, path: str, body: BaseModel) -> Json:
    """PUT the JSON serialization of ``body`` to ``path`` on the store URL.

    Args:
        path: Endpoint path appended to ``self.url``.
        body: Pydantic model whose ``json()`` output is sent as the body.

    Returns:
        The result of ``self._handle_response`` for the HTTP response.
    """
    target = self.url + path
    serialized = body.json()
    credentials = self._get_authentication()
    response = requests.put(target, data=serialized, auth=credentials)
    return self._handle_response(response)
register_pipeline_run(self, pipeline_run)
Registers a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run |
PipelineRunWrapper |
The pipeline run to register. |
required |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a pipeline run with the same name already exists. |
Source code in zenml/zen_stores/rest_zen_store.py
def register_pipeline_run(
    self,
    pipeline_run: PipelineRunWrapper,
) -> None:
    """Store a new pipeline run via the pipeline-runs endpoint.

    Args:
        pipeline_run: The pipeline run to store.

    Raises:
        EntityExistsError: If a pipeline run with the same name is already
            registered.
    """
    self.post(PIPELINE_RUNS, body=pipeline_run)
remove_user_from_team(self, team_name, user_name)
Removes a user from a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
user_name |
str |
Name of the user. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If the given user or team does not exist. |
Source code in zenml/zen_stores/rest_zen_store.py
def remove_user_from_team(self, team_name: str, user_name: str) -> None:
    """Remove a user's membership from a team.

    Args:
        team_name: Name of the team.
        user_name: Name of the user.

    Raises:
        KeyError: If the user or the team does not exist.
    """
    membership_path = f"{TEAMS}/{team_name}/users/{user_name}"
    self.delete(membership_path)
revoke_role(self, role_name, entity_name, project_name=None, is_user=True)
Revokes a role from a user or team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to revoke. |
required |
entity_name |
str |
User or team name. |
required |
project_name |
Optional[str] |
Optional project name. |
None |
is_user |
bool |
Boolean indicating whether the given `entity_name` refers to a user. |
True |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role, entity or project with the given names exists. |
Source code in zenml/zen_stores/rest_zen_store.py
def revoke_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Revoke a role from a user or a team.

    Args:
        role_name: Name of the role to revoke.
        entity_name: Name of the user or team losing the role.
        project_name: Optional project name the assignment is scoped to.
        is_user: Whether ``entity_name`` refers to a user (``True``) or a
            team (``False``).

    Raises:
        KeyError: If no role, entity or project with the given names exists.
    """
    payload = dict(
        role_name=role_name,
        entity_name=entity_name,
        project_name=project_name,
        is_user=is_user,
    )
    response = requests.delete(
        self.url + ROLE_ASSIGNMENTS,
        json=payload,
        auth=self._get_authentication(),
    )
    self._handle_response(response)
sql_zen_store
SqlZenStore (BaseZenStore)
Repository Implementation that uses SQL database backend
Source code in zenml/zen_stores/sql_zen_store.py
class SqlZenStore(BaseZenStore):
"""Repository Implementation that uses SQL database backend"""
def initialize(
    self,
    url: str,
    *args: Any,
    **kwargs: Any,
) -> "SqlZenStore":
    """Initialize a new SqlZenStore.

    Validates ``url``, creates the parent directory of a local sqlite file
    when needed, builds the engine, creates all tables, and seeds a
    ``LocalZenUser`` row (id 1) if no user exists yet. It then delegates to
    the base class initialization.

    Args:
        url: odbc path to a database.
        *args: Additional positional arguments, passed to ``create_engine``
            and to the base class.
        **kwargs: Additional keyword arguments. The base-store options
            ``skip_default_registrations``, ``track_analytics`` and
            ``skip_migration`` are removed before calling ``create_engine``;
            the full set is passed to the base class.

    Returns:
        The initialized zen store instance.

    Raises:
        ValueError: If ``url`` is not a valid SQL URL.
    """
    if not self.is_valid_url(url):
        raise ValueError(f"Invalid URL for SQL store: {url}")
    logger.debug("Initializing SqlZenStore at %s", url)
    self._url = url

    local_path = self.get_path_from_url(url)
    if local_path:
        utils.create_dir_recursive_if_not_exists(str(local_path.parent))

    # These options belong to the base store; SQLModel's engine creation
    # raises an error if they are present.
    base_store_options = (
        "skip_default_registrations",
        "track_analytics",
        "skip_migration",
    )
    engine_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key not in base_store_options
    }
    self.engine = create_engine(url, *args, **engine_kwargs)
    SQLModel.metadata.create_all(self.engine)

    with Session(self.engine) as session:
        if not session.exec(select(ZenUser)).first():
            session.add(ZenUser(id=1, name="LocalZenUser"))
            session.commit()

    super().initialize(url, *args, **kwargs)
    return self
# Public interface implementations:
@property
def type(self) -> StoreType:
    """Report the kind of this zen store.

    Returns:
        Always ``StoreType.SQL``.
    """
    store_kind = StoreType.SQL
    return store_kind
@property
def url(self) -> str:
    """URL of the repository.

    Returns:
        The URL this store was initialized with.

    Raises:
        RuntimeError: If the store has not been initialized yet.
    """
    # `_url` is assigned in `initialize`; reading it with getattr makes an
    # uninitialized store raise the intended RuntimeError rather than an
    # AttributeError.
    store_url = getattr(self, "_url", None)
    if not store_url:
        raise RuntimeError(
            "SQL zen store has not been initialized. Call `initialize` "
            "before using the store."
        )
    return store_url
# Static methods:
@staticmethod
def get_path_from_url(url: str) -> Optional[Path]:
    """Get the local path from a URL, if it points to a local sqlite file.

    This method first checks that the URL is a valid SQLite URL, which is
    backed by a file in the local filesystem. All other types of supported
    SQLAlchemy connection URLs are considered non-local and won't return
    a valid local path.

    Args:
        url: The URL to get the path from.

    Returns:
        The path extracted from the URL, or None, if the URL does not
        point to a local sqlite file.

    Raises:
        ValueError: If ``url`` is not a valid SQL URL.
    """
    if not SqlZenStore.is_valid_url(url):
        raise ValueError(f"Invalid URL for SQL store: {url}")
    sqlite_prefix = "sqlite:///"
    if not url.startswith(sqlite_prefix):
        return None
    # Strip only the leading scheme; `str.replace` would also delete any
    # later occurrence of the prefix inside the path itself.
    return Path(url[len(sqlite_prefix):])
@staticmethod
def get_local_url(path: str) -> str:
    """Build the sqlite URL of the ``zenml.db`` file inside ``path``.

    Args:
        path: Directory that contains (or will contain) the database file.

    Returns:
        A URL of the form ``sqlite:///<path>/zenml.db``.
    """
    database_file = f"{path}/zenml.db"
    return "sqlite:///" + database_file
@staticmethod
def is_valid_url(url: str) -> bool:
    """Tell whether SQLAlchemy can parse ``url`` as a connection URL.

    Args:
        url: The URL to check.

    Returns:
        ``True`` if ``make_url`` accepts the URL, ``False`` if it raises
        ``ArgumentError`` (which is logged at debug level).
    """
    try:
        make_url(url)
    except ArgumentError:
        logger.debug("Invalid SQL URL: %s", url)
        return False
    else:
        return True
@property
def stacks_empty(self) -> bool:
    """Tell whether no stack has been stored yet.

    Only the first ``ZenStack`` row is queried, so the check does not load
    every stack.

    Returns:
        ``True`` if the stack table has no rows.
    """
    with Session(self.engine) as session:
        first_stack = session.exec(select(ZenStack)).first()
    return first_stack is None
def get_stack_configuration(
    self, name: str
) -> Dict[StackComponentType, str]:
    """Fetch the component configuration of a stack by name.

    Args:
        name: The name of the stack to fetch.

    Returns:
        Mapping from component type to component name for that stack.

    Raises:
        KeyError: If no stack exists for the given name.
    """
    logger.debug("Fetching stack with name '%s'.", name)

    # Fail early if the stack itself is unknown.
    with Session(self.engine) as session:
        stack_row = session.exec(
            select(ZenStack).where(ZenStack.name == name)
        ).first()
        if stack_row is None:
            raise KeyError(
                f"Unable to find stack with name '{name}'. Available names: "
                f"{set(self.stack_names)}."
            )

    # Join every definition of this stack with the component it points to.
    definition_query = (
        select(ZenStackDefinition, ZenStackComponent)
        .where(
            ZenStackDefinition.component_type
            == ZenStackComponent.component_type
        )
        .where(ZenStackDefinition.component_name == ZenStackComponent.name)
        .where(ZenStackDefinition.stack_name == name)
    )
    with Session(self.engine) as session:
        raw_configuration = {}
        for _definition, component in session.exec(definition_query):
            raw_configuration[component.component_type] = component.name

    return {
        StackComponentType(component_type): component_name
        for component_type, component_name in raw_configuration.items()
    }
@property
def stack_configurations(self) -> Dict[str, Dict[StackComponentType, str]]:
    """Collect the configuration of every stored stack.

    Returns:
        Mapping from stack name to that stack's
        ``Dict[StackComponentType, str]`` configuration.
    """
    configurations = {}
    for stack_name in self.stack_names:
        configurations[stack_name] = self.get_stack_configuration(stack_name)
    return configurations
def _register_stack_component(
    self,
    component: ComponentWrapper,
) -> None:
    """Persist a new stack component row.

    Args:
        component: The component to store.

    Raises:
        StackComponentExistsError: If a component with the same type and
            name is already stored.
    """
    with Session(self.engine) as session:
        duplicate = session.exec(
            select(ZenStackComponent)
            .where(ZenStackComponent.name == component.name)
            .where(ZenStackComponent.component_type == component.type)
        ).first()
        if duplicate is not None:
            raise StackComponentExistsError(
                f"Unable to register stack component (type: "
                f"{component.type}) with name '{component.name}': Found "
                f"existing stack component with this name."
            )

        session.add(
            ZenStackComponent(
                component_type=component.type,
                name=component.name,
                component_flavor=component.flavor,
                configuration=component.config,
            )
        )
        session.commit()
def _update_stack_component(
    self,
    name: str,
    component_type: StackComponentType,
    component: ComponentWrapper,
) -> Dict[str, str]:
    """Update a stack component.

    Overwrites the configuration of the component identified by
    ``component_type`` and ``name``, renames it to ``component.name``,
    and rewrites every stack definition that referenced the old name.

    Args:
        name: The original name of the stack component.
        component_type: The type of the stack component to update.
        component: The new component to update with.

    Returns:
        A single-entry dict mapping ``component.type.value`` to
        ``component.flavor``.

    Raises:
        KeyError: If no stack component exists with the given name.
        StackComponentExistsError: If the component is renamed to a name
            already used by another component of the same type.
    """
    with Session(self.engine) as session:
        updated_component = session.exec(
            select(ZenStackComponent)
            .where(ZenStackComponent.component_type == component_type)
            .where(ZenStackComponent.name == name)
        ).first()
        if not updated_component:
            # Report the name that was looked up, not the requested new name.
            raise KeyError(
                f"Unable to update stack component (type: "
                f"{component.type}) with name '{name}': No "
                f"existing stack component found with this name."
            )

        new_name_component = session.exec(
            select(ZenStackComponent)
            .where(ZenStackComponent.component_type == component_type)
            .where(ZenStackComponent.name == component.name)
        ).first()
        if (name != component.name) and new_name_component is not None:
            raise StackComponentExistsError(
                f"Unable to update stack component (type: "
                f"{component.type}) with name '{component.name}': Found "
                f"existing stack component with this name."
            )

        updated_component.configuration = component.config
        # handle any potential renamed component
        updated_component.name = component.name

        # rename components inside stacks
        updated_stack_definitions = session.exec(
            select(ZenStackDefinition)
            .where(ZenStackDefinition.component_type == component_type)
            .where(ZenStackDefinition.component_name == name)
        ).all()
        for stack_definition in updated_stack_definitions:
            stack_definition.component_name = component.name
            session.add(stack_definition)

        session.add(updated_component)
        session.commit()

    logger.info(
        "Updated stack component with type '%s' and name '%s'.",
        component_type,
        component.name,
    )
    return {component.type.value: component.flavor}
def _deregister_stack(self, name: str) -> None:
    """Remove a stack and all of its component definitions from storage.

    Args:
        name: Name of the stack to remove.

    Raises:
        KeyError: If no stack with this name is registered.
    """
    with Session(self.engine) as session:
        stack_query = select(ZenStack).where(ZenStack.name == name)
        try:
            session.delete(session.exec(stack_query).one())
        except NoResultFound as error:
            raise KeyError from error

        definition_query = select(ZenStackDefinition).where(
            ZenStackDefinition.stack_name == name
        )
        for stack_definition in session.exec(definition_query).all():
            session.delete(stack_definition)
        session.commit()
# Private interface implementations:
def _save_stack(
    self,
    name: str,
    stack_configuration: Dict[StackComponentType, str],
) -> None:
    """Save a stack, creating it or replacing its component definitions.

    If no stack with `name` exists, a new `ZenStack` row is created.
    Otherwise the stack's existing `ZenStackDefinition` rows are deleted
    first. Then one definition row per entry of `stack_configuration` is
    created (or updated, if a matching row is still found), and
    everything is committed in a single transaction.

    Args:
        name: The name to save the stack as.
        stack_configuration: Dict[StackComponentType, str] to persist,
            mapping each component type to the component name.
    """
    with Session(self.engine) as session:
        stack = session.exec(
            select(ZenStack).where(ZenStack.name == name)
        ).first()
        if stack is None:
            # NOTE(review): `created_by=1` is a hard-coded creator id;
            # presumably a placeholder until users are wired in - confirm.
            stack = ZenStack(name=name, created_by=1)
            session.add(stack)
        else:
            # clear the existing stack definitions for a stack
            # that is about to be updated
            query = select(ZenStackDefinition).where(
                ZenStackDefinition.stack_name == name
            )
            for result in session.exec(query).all():
                session.delete(result)
        for ctype, cname in stack_configuration.items():
            # Upsert the definition for this component type. With the
            # session's (presumably default) autoflush, the deletes above
            # are flushed before this query, so for an existing stack this
            # is expected to find nothing - confirm before relying on it.
            statement = (
                select(ZenStackDefinition)
                .where(ZenStackDefinition.stack_name == name)
                .where(ZenStackDefinition.component_type == ctype)
            )
            results = session.exec(statement)
            component = results.one_or_none()
            if component is None:
                session.add(
                    ZenStackDefinition(
                        stack_name=name,
                        component_type=ctype,
                        component_name=cname,
                    )
                )
            else:
                component.component_name = cname
                component.component_type = ctype
                session.add(component)
        session.commit()
def _get_component_flavor_and_config(
    self, component_type: StackComponentType, name: str
) -> Tuple[str, bytes]:
    """Look up the flavor and stored configuration of a stack component.

    Args:
        component_type: Type of the component to look up.
        name: Name of the component to look up.

    Returns:
        A `(flavor, configuration)` tuple: the flavor as a string and the
        configuration as a base64-encoded yaml document.

    Raises:
        KeyError: If no component of this type has the given name.
    """
    lookup = (
        select(ZenStackComponent)
        .where(ZenStackComponent.component_type == component_type)
        .where(ZenStackComponent.name == name)
    )
    with Session(self.engine) as session:
        match = session.exec(lookup).one_or_none()
        if match is None:
            raise KeyError(
                f"Unable to find stack component (type: {component_type}) "
                f"with name '{name}'."
            )
        return match.component_flavor, match.configuration
def _get_stack_component_names(
    self, component_type: StackComponentType
) -> List[str]:
    """List the names of every registered component of one type.

    Args:
        component_type: Component type whose names should be listed.

    Returns:
        The component names, as strings.
    """
    by_type = select(ZenStackComponent).where(
        ZenStackComponent.component_type == component_type
    )
    with Session(self.engine) as session:
        names = [row.name for row in session.exec(by_type)]
    return names
def _delete_stack_component(
    self, component_type: StackComponentType, name: str
) -> None:
    """Delete a stack component row from storage.

    Args:
        component_type: Type of the component to delete.
        name: Name of the component to delete.

    Raises:
        KeyError: If no component of this type has the given name.
    """
    with Session(self.engine) as session:
        target = session.exec(
            select(ZenStackComponent)
            .where(ZenStackComponent.component_type == component_type)
            .where(ZenStackComponent.name == name)
        ).first()
        if target is None:
            raise KeyError(
                "Unable to deregister stack component (type: "
                f"{component_type.value}) with name '{name}': No stack "
                "component exists with this name."
            )
        session.delete(target)
        session.commit()
# User, project and role management
@property
def users(self) -> List[User]:
    """Every user stored in the database.

    Returns:
        One `User` model per row of the user table.
    """
    with Session(self.engine) as session:
        rows = session.exec(select(UserTable)).all()
        return [User(**row.dict()) for row in rows]
def _get_user(self, user_name: str) -> User:
    """Fetch a single user by its name.

    Args:
        user_name: Name of the user to fetch.

    Returns:
        The matching user.

    Raises:
        KeyError: If no user has this name.
    """
    by_name = select(UserTable).where(UserTable.name == user_name)
    with Session(self.engine) as session:
        try:
            row = session.exec(by_name).one()
        except NoResultFound as error:
            raise KeyError from error
        return User(**row.dict())
def _create_user(self, user_name: str) -> User:
    """Insert a new user row.

    Args:
        user_name: Name for the new user; must not already be taken.

    Returns:
        The created user.

    Raises:
        EntityExistsError: If a user with this name is already stored.
    """
    with Session(self.engine) as session:
        name_taken = session.exec(
            select(UserTable).where(UserTable.name == user_name)
        ).first()
        if name_taken:
            raise EntityExistsError(
                f"User with name '{user_name}' already exists."
            )
        new_row = UserTable(name=user_name)
        created = User(**new_row.dict())
        session.add(new_row)
        session.commit()
        return created
def _delete_user(self, user_name: str) -> None:
    """Deletes a user together with its role and team assignments.

    The assignments and the user row are removed in a single transaction.
    A failure part way through therefore cannot leave assignments behind
    that reference a user which no longer exists.

    Args:
        user_name: Name of the user to delete.

    Raises:
        KeyError: If no user with the given name exists.
    """
    with Session(self.engine) as session:
        try:
            user = session.exec(
                select(UserTable).where(UserTable.name == user_name)
            ).one()
        except NoResultFound as error:
            raise KeyError from error

        user_id = user.id
        dependent_queries = (
            select(RoleAssignmentTable).where(
                RoleAssignmentTable.user_id == user_id
            ),
            select(TeamAssignmentTable).where(
                TeamAssignmentTable.user_id == user_id
            ),
        )
        # Remove rows that reference the user before the user itself.
        for query in dependent_queries:
            for assignment in session.exec(query).all():
                session.delete(assignment)
        session.delete(user)
        session.commit()
@property
def teams(self) -> List[Team]:
    """Every team stored in the database.

    Returns:
        One `Team` model per row of the team table.
    """
    with Session(self.engine) as session:
        rows = session.exec(select(TeamTable)).all()
        return [Team(**row.dict()) for row in rows]
def _get_team(self, team_name: str) -> Team:
    """Fetch a single team by its name.

    Args:
        team_name: Name of the team to fetch.

    Returns:
        The matching team.

    Raises:
        KeyError: If no team has this name.
    """
    by_name = select(TeamTable).where(TeamTable.name == team_name)
    with Session(self.engine) as session:
        try:
            row = session.exec(by_name).one()
        except NoResultFound as error:
            raise KeyError from error
        return Team(**row.dict())
def _create_team(self, team_name: str) -> Team:
    """Insert a new team row.

    Args:
        team_name: Name for the new team; must not already be taken.

    Returns:
        The created team.

    Raises:
        EntityExistsError: If a team with this name is already stored.
    """
    with Session(self.engine) as session:
        name_taken = session.exec(
            select(TeamTable).where(TeamTable.name == team_name)
        ).first()
        if name_taken:
            raise EntityExistsError(
                f"Team with name '{team_name}' already exists."
            )
        new_row = TeamTable(name=team_name)
        created = Team(**new_row.dict())
        session.add(new_row)
        session.commit()
        return created
def _delete_team(self, team_name: str) -> None:
    """Deletes a team together with its role and membership assignments.

    The assignments and the team row are removed in a single transaction.
    A failure part way through therefore cannot leave assignments behind
    that reference a team which no longer exists.

    Args:
        team_name: Name of the team to delete.

    Raises:
        KeyError: If no team with the given name exists.
    """
    with Session(self.engine) as session:
        try:
            team = session.exec(
                select(TeamTable).where(TeamTable.name == team_name)
            ).one()
        except NoResultFound as error:
            raise KeyError from error

        team_id = team.id
        dependent_queries = (
            select(RoleAssignmentTable).where(
                RoleAssignmentTable.team_id == team_id
            ),
            select(TeamAssignmentTable).where(
                TeamAssignmentTable.team_id == team_id
            ),
        )
        # Remove rows that reference the team before the team itself.
        for query in dependent_queries:
            for assignment in session.exec(query).all():
                session.delete(assignment)
        session.delete(team)
        session.commit()
def add_user_to_team(self, team_name: str, user_name: str) -> None:
    """Adds a user to a team.

    The call is idempotent: if the user is already a member of the team,
    nothing is written.

    Args:
        team_name: Name of the team.
        user_name: Name of the user.

    Raises:
        KeyError: If no user and team with the given names exists.
    """
    with Session(self.engine) as session:
        try:
            team = session.exec(
                select(TeamTable).where(TeamTable.name == team_name)
            ).one()
            user = session.exec(
                select(UserTable).where(UserTable.name == user_name)
            ).one()
        except NoResultFound as error:
            raise KeyError from error

        existing_assignment = session.exec(
            select(TeamAssignmentTable)
            .where(TeamAssignmentTable.user_id == user.id)
            .where(TeamAssignmentTable.team_id == team.id)
        ).first()
        if existing_assignment is not None:
            # Inserting again would either duplicate the membership row or
            # violate the table's key, depending on the schema. A duplicate
            # would also make `remove_user_from_team`'s `.one()` fail.
            return

        assignment = TeamAssignmentTable(user_id=user.id, team_id=team.id)
        session.add(assignment)
        session.commit()
def remove_user_from_team(self, team_name: str, user_name: str) -> None:
    """Delete the membership row linking a user to a team.

    Args:
        team_name: Name of the team to remove the user from.
        user_name: Name of the user to remove.

    Raises:
        KeyError: If no membership exists for this user/team pair.
    """
    membership_query = (
        select(TeamAssignmentTable)
        .where(TeamAssignmentTable.team_id == TeamTable.id)
        .where(TeamAssignmentTable.user_id == UserTable.id)
        .where(UserTable.name == user_name)
        .where(TeamTable.name == team_name)
    )
    with Session(self.engine) as session:
        try:
            membership = session.exec(membership_query).one()
        except NoResultFound as error:
            raise KeyError from error
        session.delete(membership)
        session.commit()
@property
def projects(self) -> List[Project]:
    """Every project stored in the database.

    Returns:
        One `Project` model per row of the project table.
    """
    with Session(self.engine) as session:
        rows = session.exec(select(ProjectTable)).all()
        return [Project(**row.dict()) for row in rows]
def _get_project(self, project_name: str) -> Project:
    """Fetch a single project by its name.

    Args:
        project_name: Name of the project to fetch.

    Returns:
        The matching project.

    Raises:
        KeyError: If no project has this name.
    """
    by_name = select(ProjectTable).where(ProjectTable.name == project_name)
    with Session(self.engine) as session:
        try:
            row = session.exec(by_name).one()
        except NoResultFound as error:
            raise KeyError from error
        return Project(**row.dict())
def _create_project(
    self, project_name: str, description: Optional[str] = None
) -> Project:
    """Creates a new project.

    Args:
        project_name: Unique project name.
        description: Optional project description, stored on the new
            project row.

    Returns:
        The newly created project.

    Raises:
        EntityExistsError: If a project with the given name already exists.
    """
    with Session(self.engine) as session:
        existing_project = session.exec(
            select(ProjectTable).where(ProjectTable.name == project_name)
        ).first()
        if existing_project:
            raise EntityExistsError(
                f"Project with name '{project_name}' already exists."
            )
        # Previously `description` was accepted but never persisted.
        sql_project = ProjectTable(name=project_name, description=description)
        project = Project(**sql_project.dict())
        session.add(sql_project)
        session.commit()
        return project
def _delete_project(self, project_name: str) -> None:
    """Deletes a project together with its role assignments.

    The role assignments and the project row are removed in a single
    transaction. A failure part way through therefore cannot leave
    assignments behind that reference a project which no longer exists.

    Args:
        project_name: Name of the project to delete.

    Raises:
        KeyError: If no project with the given name exists.
    """
    with Session(self.engine) as session:
        try:
            project = session.exec(
                select(ProjectTable).where(
                    ProjectTable.name == project_name
                )
            ).one()
        except NoResultFound as error:
            raise KeyError from error

        # Remove rows that reference the project before the project itself.
        for assignment in session.exec(
            select(RoleAssignmentTable).where(
                RoleAssignmentTable.project_id == project.id
            )
        ).all():
            session.delete(assignment)
        session.delete(project)
        session.commit()
@property
def roles(self) -> List[Role]:
    """Every role stored in the database.

    Returns:
        One `Role` model per row of the role table.
    """
    with Session(self.engine) as session:
        rows = session.exec(select(RoleTable)).all()
        return [Role(**row.dict()) for row in rows]
@property
def role_assignments(self) -> List[RoleAssignment]:
    """Every role assignment stored in the database.

    Returns:
        One `RoleAssignment` model per row of the role assignment table.
    """
    with Session(self.engine) as session:
        rows = session.exec(select(RoleAssignmentTable)).all()
        return [RoleAssignment(**row.dict()) for row in rows]
def _get_role(self, role_name: str) -> Role:
    """Fetch a single role by its name.

    Args:
        role_name: Name of the role to fetch.

    Returns:
        The matching role.

    Raises:
        KeyError: If no role has this name.
    """
    by_name = select(RoleTable).where(RoleTable.name == role_name)
    with Session(self.engine) as session:
        try:
            row = session.exec(by_name).one()
        except NoResultFound as error:
            raise KeyError from error
        return Role(**row.dict())
def _create_role(self, role_name: str) -> Role:
    """Insert a new role row.

    Args:
        role_name: Name for the new role; must not already be taken.

    Returns:
        The created role.

    Raises:
        EntityExistsError: If a role with this name is already stored.
    """
    with Session(self.engine) as session:
        name_taken = session.exec(
            select(RoleTable).where(RoleTable.name == role_name)
        ).first()
        if name_taken:
            raise EntityExistsError(
                f"Role with name '{role_name}' already exists."
            )
        new_row = RoleTable(name=role_name)
        created = Role(**new_row.dict())
        session.add(new_row)
        session.commit()
        return created
def _delete_role(self, role_name: str) -> None:
    """Deletes a role together with all of its assignments.

    The role assignments and the role row are removed in a single
    transaction. A failure part way through therefore cannot leave
    assignments behind that reference a role which no longer exists.

    Args:
        role_name: Name of the role to delete.

    Raises:
        KeyError: If no role with the given name exists.
    """
    with Session(self.engine) as session:
        try:
            role = session.exec(
                select(RoleTable).where(RoleTable.name == role_name)
            ).one()
        except NoResultFound as error:
            raise KeyError from error

        # Remove rows that reference the role before the role itself.
        for assignment in session.exec(
            select(RoleAssignmentTable).where(
                RoleAssignmentTable.role_id == role.id
            )
        ).all():
            session.delete(assignment)
        session.delete(role)
        session.commit()
def assign_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Create a role assignment for a user or a team.

    The assignment is global when `project_name` is empty or None, and
    scoped to that project otherwise.

    Args:
        role_name: Name of the role to grant.
        entity_name: Name of the user (or team) receiving the role.
        project_name: Name of the project to scope the assignment to.
        is_user: `True` if `entity_name` names a user, `False` if it
            names a team.

    Raises:
        KeyError: If the role, the user/team or the project is not found.
    """
    with Session(self.engine) as session:

        def fetch_id(table: Any, wanted_name: str) -> UUID:
            # Every table queried here exposes `id` and `name` columns.
            return session.exec(
                select(table.id).where(table.name == wanted_name)
            ).one()

        user_id: Optional[UUID] = None
        team_id: Optional[UUID] = None
        project_id: Optional[UUID] = None
        try:
            role_id = fetch_id(RoleTable, role_name)
            if project_name:
                project_id = fetch_id(ProjectTable, project_name)
            if is_user:
                user_id = fetch_id(UserTable, entity_name)
            else:
                team_id = fetch_id(TeamTable, entity_name)
        except NoResultFound as error:
            raise KeyError from error

        session.add(
            RoleAssignmentTable(
                role_id=role_id,
                project_id=project_id,
                user_id=user_id,
                team_id=team_id,
            )
        )
        session.commit()
def revoke_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Revokes a role from a user or team.

    This mirrors `assign_role`. Without a `project_name`, only the global
    assignment (one with no project) is revoked. With a `project_name`,
    only the assignment scoped to that project is revoked.

    Args:
        role_name: Name of the role to revoke.
        entity_name: User or team name.
        project_name: Optional project name.
        is_user: Boolean indicating whether the given `entity_name` refers
            to a user.

    Raises:
        KeyError: If no role, entity or project with the given names exists.
    """
    with Session(self.engine) as session:
        statement = (
            select(RoleAssignmentTable)
            .where(RoleAssignmentTable.role_id == RoleTable.id)
            .where(RoleTable.name == role_name)
        )
        if project_name:
            statement = statement.where(
                RoleAssignmentTable.project_id == ProjectTable.id
            ).where(ProjectTable.name == project_name)
        else:
            # Restrict to global assignments. Without this, a
            # project-scoped assignment of the same role could be revoked
            # (or `.one()` could see several rows).
            statement = statement.where(
                RoleAssignmentTable.project_id == None  # noqa: E711
            )
        if is_user:
            statement = statement.where(
                RoleAssignmentTable.user_id == UserTable.id
            ).where(UserTable.name == entity_name)
        else:
            statement = statement.where(
                RoleAssignmentTable.team_id == TeamTable.id
            ).where(TeamTable.name == entity_name)
        try:
            assignment = session.exec(statement).one()
        except NoResultFound as error:
            raise KeyError from error
        session.delete(assignment)
        session.commit()
def get_users_for_team(self, team_name: str) -> List[User]:
    """List the members of a team.

    Args:
        team_name: Name of the team whose members are wanted.

    Returns:
        The users assigned to the team.

    Raises:
        KeyError: If the team does not exist.
    """
    with Session(self.engine) as session:
        try:
            team_id = session.exec(
                select(TeamTable.id).where(TeamTable.name == team_name)
            ).one()
        except NoResultFound as error:
            raise KeyError from error
        member_query = (
            select(UserTable)
            .where(UserTable.id == TeamAssignmentTable.user_id)
            .where(TeamAssignmentTable.team_id == team_id)
        )
        members = session.exec(member_query).all()
        return [User(**member.dict()) for member in members]
def get_teams_for_user(self, user_name: str) -> List[Team]:
    """List the teams a user belongs to.

    Args:
        user_name: Name of the user whose teams are wanted.

    Returns:
        The teams the user is assigned to.

    Raises:
        KeyError: If the user does not exist.
    """
    with Session(self.engine) as session:
        try:
            user_id = session.exec(
                select(UserTable.id).where(UserTable.name == user_name)
            ).one()
        except NoResultFound as error:
            raise KeyError from error
        membership_query = (
            select(TeamTable)
            .where(TeamTable.id == TeamAssignmentTable.team_id)
            .where(TeamAssignmentTable.user_id == user_id)
        )
        memberships = session.exec(membership_query).all()
        return [Team(**membership.dict()) for membership in memberships]
def get_role_assignments_for_user(
    self,
    user_name: str,
    project_name: Optional[str] = None,
    include_team_roles: bool = True,
) -> List[RoleAssignment]:
    """Collect the role assignments of a user.

    Args:
        user_name: Name of the user.
        project_name: If given, only assignments for this project are
            returned.
        include_team_roles: If `True`, assignments of every team the user
            belongs to are appended as well.

    Returns:
        The user's direct assignments, followed by team assignments when
        requested.

    Raises:
        KeyError: If the user or the project does not exist.
    """
    with Session(self.engine) as session:
        try:
            user_id = session.exec(
                select(UserTable.id).where(UserTable.name == user_name)
            ).one()
            query = select(RoleAssignmentTable).where(
                RoleAssignmentTable.user_id == user_id
            )
            if project_name:
                project_id = session.exec(
                    select(ProjectTable.id).where(
                        ProjectTable.name == project_name
                    )
                ).one()
                query = query.where(
                    RoleAssignmentTable.project_id == project_id
                )
        except NoResultFound as error:
            raise KeyError from error
        collected = [
            RoleAssignment(**row.dict())
            for row in session.exec(query).all()
        ]

    if not include_team_roles:
        return collected
    for team in self.get_teams_for_user(user_name):
        collected.extend(
            self.get_role_assignments_for_team(
                team.name, project_name=project_name
            )
        )
    return collected
def get_role_assignments_for_team(
    self,
    team_name: str,
    project_name: Optional[str] = None,
) -> List[RoleAssignment]:
    """Collect the role assignments of a team.

    Args:
        team_name: Name of the team.
        project_name: If given, only assignments for this project are
            returned.

    Returns:
        The team's role assignments.

    Raises:
        KeyError: If the team or the project does not exist.
    """
    with Session(self.engine) as session:
        try:
            team_id = session.exec(
                select(TeamTable.id).where(TeamTable.name == team_name)
            ).one()
            query = select(RoleAssignmentTable).where(
                RoleAssignmentTable.team_id == team_id
            )
            if project_name:
                project_id = session.exec(
                    select(ProjectTable.id).where(
                        ProjectTable.name == project_name
                    )
                ).one()
                query = query.where(
                    RoleAssignmentTable.project_id == project_id
                )
        except NoResultFound as error:
            raise KeyError from error
        rows = session.exec(query).all()
        return [RoleAssignment(**row.dict()) for row in rows]
# Pipelines and pipeline runs
def get_pipeline_run(
    self,
    pipeline_name: str,
    run_name: str,
    project_name: Optional[str] = None,
) -> PipelineRunWrapper:
    """Look up one run of a pipeline.

    Args:
        pipeline_name: Name of the pipeline the run belongs to.
        run_name: Name of the run.
        project_name: If given, the run must also belong to this project.

    Returns:
        The matching run, converted to a `PipelineRunWrapper`.

    Raises:
        KeyError: If no matching pipeline run exists.
    """
    run_query = (
        select(PipelineRunTable)
        .where(PipelineRunTable.name == run_name)
        .where(PipelineRunTable.pipeline_name == pipeline_name)
    )
    if project_name:
        run_query = run_query.where(
            PipelineRunTable.project_name == project_name
        )
    with Session(self.engine) as session:
        try:
            return session.exec(run_query).one().to_pipeline_run_wrapper()
        except NoResultFound as error:
            raise KeyError from error
def get_pipeline_runs(
    self, pipeline_name: str, project_name: Optional[str] = None
) -> List[PipelineRunWrapper]:
    """Gets pipeline runs.

    Args:
        pipeline_name: Name of the pipeline for which to get runs.
        project_name: Optional name of the project from which to get the
            pipeline runs.

    Returns:
        All matching runs as `PipelineRunWrapper`s; an empty list if there
        are none.
    """
    statement = select(PipelineRunTable).where(
        PipelineRunTable.pipeline_name == pipeline_name
    )
    if project_name:
        statement = statement.where(
            PipelineRunTable.project_name == project_name
        )
    # `.all()` returns an empty list when nothing matches, so the former
    # `except NoResultFound` branch was unreachable and has been removed.
    with Session(self.engine) as session:
        return [
            run.to_pipeline_run_wrapper()
            for run in session.exec(statement).all()
        ]
def register_pipeline_run(
    self,
    pipeline_run: PipelineRunWrapper,
) -> None:
    """Registers a pipeline run.

    Run names are checked for uniqueness across all stored runs (not per
    pipeline) before the run is inserted.

    Args:
        pipeline_run: The pipeline run to register.

    Raises:
        EntityExistsError: If a pipeline run with the same name already
            exists.
    """
    with Session(self.engine) as session:
        existing_run = session.exec(
            select(PipelineRunTable).where(
                PipelineRunTable.name == pipeline_run.name
            )
        ).first()
        if existing_run:
            # The trailing space fixes the former "alreadyexists" message.
            raise EntityExistsError(
                f"Pipeline run with name '{pipeline_run.name}' already "
                "exists. Please make sure your pipeline run names are "
                "unique."
            )
        sql_run = PipelineRunTable.from_pipeline_run_wrapper(pipeline_run)
        session.add(sql_run)
        session.commit()
# Handling stack component flavors
@property
def flavors(self) -> List[FlavorWrapper]:
    """Every stack component flavor stored in the database.

    Returns:
        One `FlavorWrapper` per row of the flavor table.
    """
    with Session(self.engine) as session:
        rows = session.exec(select(ZenFlavor)).all()
        return [FlavorWrapper(**row.dict()) for row in rows]
def _create_flavor(
    self,
    source: str,
    name: str,
    stack_component_type: StackComponentType,
) -> FlavorWrapper:
    """Store a new flavor for a stack component type.

    Args:
        source: Source path of the flavor implementation.
        name: Name of the flavor.
        stack_component_type: Component type the flavor belongs to.

    Returns:
        The stored flavor.

    Raises:
        EntityExistsError: If this component type already has a flavor
            with the same name.
    """
    with Session(self.engine) as session:
        duplicate = session.exec(
            select(ZenFlavor).where(
                ZenFlavor.name == name,
                ZenFlavor.type == stack_component_type,
            )
        ).first()
        if duplicate:
            raise EntityExistsError(
                f"A {stack_component_type} with '{name}' flavor already "
                f"exists."
            )
        new_row = ZenFlavor(
            name=name,
            source=source,
            type=stack_component_type,
        )
        created = FlavorWrapper(**new_row.dict())
        session.add(new_row)
        session.commit()
        return created
def get_flavors_by_type(
    self, component_type: StackComponentType
) -> List[FlavorWrapper]:
    """List every flavor registered for one stack component type.

    Args:
        component_type: Component type whose flavors are wanted.

    Returns:
        The flavors of that component type.
    """
    by_type = select(ZenFlavor).where(ZenFlavor.type == component_type)
    with Session(self.engine) as session:
        rows = session.exec(by_type).all()
        return [
            FlavorWrapper(
                name=row.name,
                source=row.source,
                type=row.type,
                integration=row.integration,
            )
            for row in rows
        ]
def get_flavor_by_name_and_type(
    self,
    flavor_name: str,
    component_type: StackComponentType,
) -> FlavorWrapper:
    """Fetch a flavor by a given name and type.

    Args:
        flavor_name: The name of the flavor.
        component_type: The type of the component the flavor belongs to.

    Returns:
        Flavor instance if it exists.

    Raises:
        KeyError: If no flavor exists with the given name and type, or if
            more than one such flavor exists.
    """
    # Imported locally: the documented "more than one instance" case was
    # previously not caught and leaked a SQLAlchemy exception.
    from sqlalchemy.exc import MultipleResultsFound

    with Session(self.engine) as session:
        try:
            flavor = session.exec(
                select(ZenFlavor).where(
                    ZenFlavor.name == flavor_name,
                    ZenFlavor.type == component_type,
                )
            ).one()
        except (NoResultFound, MultipleResultsFound) as error:
            raise KeyError from error
        return FlavorWrapper(
            name=flavor.name,
            source=flavor.source,
            type=flavor.type,
            integration=flavor.integration,
        )
# Implementation-specific internal methods:
@property
def stack_names(self) -> List[str]:
    """Names of all stacks registered in this ZenStore.

    Returns:
        One name per row of the stack table.
    """
    with Session(self.engine) as session:
        stacks = session.exec(select(ZenStack))
        return [stack.name for stack in stacks]
def _delete_query_results(self, query: Any) -> None:
    """Delete every row selected by `query` in a fresh session and commit.

    Args:
        query: A select statement whose result rows should be deleted.
    """
    with Session(self.engine) as session:
        doomed_rows = session.exec(query).all()
        for row in doomed_rows:
            session.delete(row)
        session.commit()
flavors: List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]
property
readonly
All registered flavors.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
A list of all registered flavors. |
projects: List[zenml.zen_stores.models.user_management_models.Project]
property
readonly
All registered projects.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Project] |
A list of all registered projects. |
role_assignments: List[zenml.zen_stores.models.user_management_models.RoleAssignment]
property
readonly
All registered role assignments.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
A list of all registered role assignments. |
roles: List[zenml.zen_stores.models.user_management_models.Role]
property
readonly
All registered roles.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Role] |
A list of all registered roles. |
stack_configurations: Dict[str, Dict[zenml.enums.StackComponentType, str]]
property
readonly
Configuration for all stacks registered in this zen store.
Returns:
Type | Description |
---|---|
Dict[str, Dict[zenml.enums.StackComponentType, str]] |
Dictionary mapping stack names to Dict[StackComponentType, str] |
stack_names: List[str]
property
readonly
Names of all stacks registered in this ZenStore.
stacks_empty: bool
property
readonly
Check if the zen store is empty.
teams: List[zenml.zen_stores.models.user_management_models.Team]
property
readonly
All registered teams.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Team] |
A list of all registered teams. |
type: StoreType
property
readonly
The type of zen store.
url: str
property
readonly
URL of the repository.
users: List[zenml.zen_stores.models.user_management_models.User]
property
readonly
All registered users.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.User] |
A list of all registered users. |
add_user_to_team(self, team_name, user_name)
Adds a user to a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
user_name |
str |
Name of the user. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user and team with the given names exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def add_user_to_team(self, team_name: str, user_name: str) -> None:
    """Adds a user to a team.

    The call is idempotent: if the user is already a member of the team,
    nothing is written.

    Args:
        team_name: Name of the team.
        user_name: Name of the user.

    Raises:
        KeyError: If no user and team with the given names exists.
    """
    with Session(self.engine) as session:
        try:
            team = session.exec(
                select(TeamTable).where(TeamTable.name == team_name)
            ).one()
            user = session.exec(
                select(UserTable).where(UserTable.name == user_name)
            ).one()
        except NoResultFound as error:
            raise KeyError from error

        existing_assignment = session.exec(
            select(TeamAssignmentTable)
            .where(TeamAssignmentTable.user_id == user.id)
            .where(TeamAssignmentTable.team_id == team.id)
        ).first()
        if existing_assignment is not None:
            # Inserting again would either duplicate the membership row or
            # violate the table's key, depending on the schema. A duplicate
            # would also make `remove_user_from_team`'s `.one()` fail.
            return

        assignment = TeamAssignmentTable(user_id=user.id, team_id=team.id)
        session.add(assignment)
        session.commit()
assign_role(self, role_name, entity_name, project_name=None, is_user=True)
Assigns a role to a user or team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to assign. |
required |
entity_name |
str |
User or team name. |
required |
project_name |
Optional[str] |
Optional project name. |
None |
is_user |
bool |
Boolean indicating whether the given `entity_name` refers to a user. |
True |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role, entity or project with the given names exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def assign_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Create a role assignment for a user or a team.

    The assignment is global when `project_name` is empty or None, and
    scoped to that project otherwise.

    Args:
        role_name: Name of the role to grant.
        entity_name: Name of the user (or team) receiving the role.
        project_name: Name of the project to scope the assignment to.
        is_user: `True` if `entity_name` names a user, `False` if it
            names a team.

    Raises:
        KeyError: If the role, the user/team or the project is not found.
    """
    with Session(self.engine) as session:

        def fetch_id(table: Any, wanted_name: str) -> UUID:
            # Every table queried here exposes `id` and `name` columns.
            return session.exec(
                select(table.id).where(table.name == wanted_name)
            ).one()

        user_id: Optional[UUID] = None
        team_id: Optional[UUID] = None
        project_id: Optional[UUID] = None
        try:
            role_id = fetch_id(RoleTable, role_name)
            if project_name:
                project_id = fetch_id(ProjectTable, project_name)
            if is_user:
                user_id = fetch_id(UserTable, entity_name)
            else:
                team_id = fetch_id(TeamTable, entity_name)
        except NoResultFound as error:
            raise KeyError from error

        session.add(
            RoleAssignmentTable(
                role_id=role_id,
                project_id=project_id,
                user_id=user_id,
                team_id=team_id,
            )
        )
        session.commit()
get_flavor_by_name_and_type(self, flavor_name, component_type)
Fetch a flavor by a given name and type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor_name |
str |
The name of the flavor. |
required |
component_type |
StackComponentType |
The type of the component the flavor belongs to. |
required |
Returns:
Type | Description |
---|---|
FlavorWrapper |
Flavor instance if it exists |
Exceptions:
Type | Description |
---|---|
KeyError |
If no flavor exists with the given name and type, or if more than one such flavor exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def get_flavor_by_name_and_type(
    self,
    flavor_name: str,
    component_type: StackComponentType,
) -> FlavorWrapper:
    """Fetch a flavor by a given name and type.

    Args:
        flavor_name: The name of the flavor.
        component_type: The type of the component the flavor belongs to.

    Returns:
        Flavor instance if it exists.

    Raises:
        KeyError: If no flavor exists with the given name and type, or if
            more than one such flavor exists.
    """
    # Imported locally: the documented "more than one instance" case was
    # previously not caught and leaked a SQLAlchemy exception.
    from sqlalchemy.exc import MultipleResultsFound

    with Session(self.engine) as session:
        try:
            flavor = session.exec(
                select(ZenFlavor).where(
                    ZenFlavor.name == flavor_name,
                    ZenFlavor.type == component_type,
                )
            ).one()
        except (NoResultFound, MultipleResultsFound) as error:
            raise KeyError from error
        return FlavorWrapper(
            name=flavor.name,
            source=flavor.source,
            type=flavor.type,
            integration=flavor.integration,
        )
get_flavors_by_type(self, component_type)
Fetch all flavors defined for a specific stack component type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
The type of the stack component. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] |
List of all the flavors for the given stack component type. |
Source code in zenml/zen_stores/sql_zen_store.py
def get_flavors_by_type(
    self, component_type: StackComponentType
) -> List[FlavorWrapper]:
    """List every flavor registered for one stack component type.

    Args:
        component_type: Component type whose flavors are wanted.

    Returns:
        The flavors of that component type.
    """
    by_type = select(ZenFlavor).where(ZenFlavor.type == component_type)
    with Session(self.engine) as session:
        rows = session.exec(by_type).all()
        return [
            FlavorWrapper(
                name=row.name,
                source=row.source,
                type=row.type,
                integration=row.integration,
            )
            for row in rows
        ]
get_local_url(path)
staticmethod
Get a local SQL url for a given local path.
Source code in zenml/zen_stores/sql_zen_store.py
@staticmethod
def get_local_url(path: str) -> str:
    """Build the SQLite connection URL for a database stored under `path`.

    Args:
        path: Directory that holds (or will hold) the `zenml.db` file.

    Returns:
        A `sqlite:///` URL pointing at `<path>/zenml.db`.
    """
    database_file = f"{path}/zenml.db"
    return "sqlite:///" + database_file
get_path_from_url(url)
staticmethod
Get the local path from a URL, if it points to a local sqlite file.
This method first checks that the URL is a valid SQLite URL, which is backed by a file in the local filesystem. All other types of supported SQLAlchemy connection URLs are considered non-local and won't return a valid local path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
The URL to get the path from. |
required |
Returns:
Type | Description |
---|---|
Optional[pathlib.Path] |
The path extracted from the URL, or None, if the URL does not point to a local sqlite file. |
Source code in zenml/zen_stores/sql_zen_store.py
@staticmethod
def get_path_from_url(url: str) -> Optional[Path]:
    """Get the local path from a URL, if it points to a local sqlite file.

    This method first checks that the URL is a valid SQLite URL, which is
    backed by a file in the local filesystem. All other types of supported
    SQLAlchemy connection URLs are considered non-local and won't return
    a valid local path.

    Args:
        url: The URL to get the path from.

    Returns:
        The path extracted from the URL, or None, if the URL does not
        point to a local sqlite file.

    Raises:
        ValueError: If `url` is not a valid URL for a SQL store.
    """
    if not SqlZenStore.is_valid_url(url):
        raise ValueError(f"Invalid URL for SQL store: {url}")
    sqlite_prefix = "sqlite:///"
    if not url.startswith(sqlite_prefix):
        return None
    # Strip only the leading scheme. `str.replace` would also remove any
    # later occurrence of the prefix inside the path itself.
    return Path(url[len(sqlite_prefix):])
get_pipeline_run(self, pipeline_name, run_name, project_name=None)
Gets a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to get the run. |
required |
run_name |
str |
Name of the pipeline run to get. |
required |
project_name |
Optional[str] |
Optional name of the project from which to get the pipeline run. |
None |
Exceptions:
Type | Description |
---|---|
KeyError |
If no pipeline run (or project) with the given name exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def get_pipeline_run(
    self,
    pipeline_name: str,
    run_name: str,
    project_name: Optional[str] = None,
) -> PipelineRunWrapper:
    """Gets a pipeline run.

    Args:
        pipeline_name: Name of the pipeline for which to get the run.
        run_name: Name of the pipeline run to get.
        project_name: Optional name of the project from which to get the
            pipeline run. If not given, runs of all projects are
            considered.

    Returns:
        The matching pipeline run.

    Raises:
        KeyError: If no pipeline run (or project) with the given name
            exists.
    """
    statement = (
        select(PipelineRunTable)
        .where(PipelineRunTable.name == run_name)
        .where(PipelineRunTable.pipeline_name == pipeline_name)
    )
    if project_name:
        statement = statement.where(
            PipelineRunTable.project_name == project_name
        )
    with Session(self.engine) as session:
        # Only the lookup can raise `NoResultFound`, so keep the `try` tight.
        # NOTE(review): if several runs match (e.g. same run name in two
        # projects and no `project_name`), `.one()` raises
        # `MultipleResultsFound`, which is not translated — confirm intended.
        try:
            run = session.exec(statement).one()
        except NoResultFound as error:
            raise KeyError from error
        return run.to_pipeline_run_wrapper()
get_pipeline_runs(self, pipeline_name, project_name=None)
Gets pipeline runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline for which to get runs. |
required |
project_name |
Optional[str] |
Optional name of the project from which to get the pipeline runs. |
None |
Source code in zenml/zen_stores/sql_zen_store.py
def get_pipeline_runs(
    self, pipeline_name: str, project_name: Optional[str] = None
) -> List[PipelineRunWrapper]:
    """Gets pipeline runs.

    Args:
        pipeline_name: Name of the pipeline for which to get runs.
        project_name: Optional name of the project from which to get the
            pipeline runs. If not given, runs of all projects are returned.

    Returns:
        All runs of the given pipeline (restricted to the given project if
        one is specified). The list is empty if there are no such runs.
    """
    statement = select(PipelineRunTable).where(
        PipelineRunTable.pipeline_name == pipeline_name
    )
    if project_name:
        statement = statement.where(
            PipelineRunTable.project_name == project_name
        )
    # `.all()` returns an empty list rather than raising `NoResultFound`,
    # so there is no error to translate into a `KeyError` here.
    with Session(self.engine) as session:
        return [
            run.to_pipeline_run_wrapper()
            for run in session.exec(statement).all()
        ]
get_role_assignments_for_team(self, team_name, project_name=None)
Fetches all role assignments for a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
project_name |
Optional[str] |
Optional filter to only return roles assigned for this project. |
None |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
List of role assignments for this team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team or project with the given names exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def get_role_assignments_for_team(
    self,
    team_name: str,
    project_name: Optional[str] = None,
) -> List[RoleAssignment]:
    """Fetches all role assignments for a team.

    Args:
        team_name: Name of the team.
        project_name: Optional filter to only return roles assigned for
            this project.

    Returns:
        List of role assignments for this team.

    Raises:
        KeyError: If no team or project with the given names exists.
    """
    with Session(self.engine) as session:
        try:
            team_id = session.exec(
                select(TeamTable.id).where(TeamTable.name == team_name)
            ).one()
            statement = select(RoleAssignmentTable).where(
                RoleAssignmentTable.team_id == team_id
            )
            if project_name:
                # Resolve the project name so a missing project surfaces
                # as a KeyError instead of an empty result.
                project_id = session.exec(
                    select(ProjectTable.id).where(
                        ProjectTable.name == project_name
                    )
                ).one()
                statement = statement.where(
                    RoleAssignmentTable.project_id == project_id
                )
        except NoResultFound as error:
            raise KeyError from error
        return [
            RoleAssignment(**assignment.dict())
            for assignment in session.exec(statement).all()
        ]
get_role_assignments_for_user(self, user_name, project_name=None, include_team_roles=True)
Fetches all role assignments for a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user. |
required |
project_name |
Optional[str] |
Optional filter to only return roles assigned for this project. |
None |
include_team_roles |
bool |
If `True`, includes roles for all teams that the user is part of. |
True |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.RoleAssignment] |
List of role assignments for this user. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user or project with the given names exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def get_role_assignments_for_user(
    self,
    user_name: str,
    project_name: Optional[str] = None,
    include_team_roles: bool = True,
) -> List[RoleAssignment]:
    """Collect the role assignments of a user, optionally via their teams.

    Args:
        user_name: Name of the user.
        project_name: Optional filter to only return roles assigned for
            this project.
        include_team_roles: If `True`, also include the role assignments
            of every team the user belongs to.

    Returns:
        The user's direct role assignments, followed by the team role
        assignments when `include_team_roles` is set.

    Raises:
        KeyError: If no user or project with the given names exists.
    """
    with Session(self.engine) as session:
        try:
            user_id = session.exec(
                select(UserTable.id).where(UserTable.name == user_name)
            ).one()
            query = select(RoleAssignmentTable).where(
                RoleAssignmentTable.user_id == user_id
            )
            if project_name:
                project_id = session.exec(
                    select(ProjectTable.id).where(
                        ProjectTable.name == project_name
                    )
                ).one()
                query = query.where(
                    RoleAssignmentTable.project_id == project_id
                )
        except NoResultFound as error:
            raise KeyError from error

        direct_assignments = [
            RoleAssignment(**row.dict())
            for row in session.exec(query).all()
        ]
        if not include_team_roles:
            return direct_assignments

        team_assignments = [
            assignment
            for team in self.get_teams_for_user(user_name)
            for assignment in self.get_role_assignments_for_team(
                team.name, project_name=project_name
            )
        ]
        return direct_assignments + team_assignments
get_stack_configuration(self, name)
Fetches a stack configuration by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the stack to fetch. |
required |
Returns:
Type | Description |
---|---|
Dict[zenml.enums.StackComponentType, str] |
Dict[StackComponentType, str] for the requested stack name. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no stack exists for the given name. |
Source code in zenml/zen_stores/sql_zen_store.py
def get_stack_configuration(
    self, name: str
) -> Dict[StackComponentType, str]:
    """Fetches a stack configuration by name.

    Args:
        name: The name of the stack to fetch.

    Returns:
        Dict[StackComponentType, str] for the requested stack name, mapping
        each component type to the name of the component assigned to the
        stack.

    Raises:
        KeyError: If no stack exists for the given name.
    """
    logger.debug("Fetching stack with name '%s'.", name)
    with Session(self.engine) as session:
        # First check that the stack exists, so callers get a helpful error
        # instead of an empty configuration.
        maybe_stack = session.exec(
            select(ZenStack).where(ZenStack.name == name)
        ).first()
        if maybe_stack is None:
            raise KeyError(
                f"Unable to find stack with name '{name}'. Available names: "
                f"{set(self.stack_names)}."
            )

        # Then get all components assigned to that stack via the join table.
        definitions_and_components = session.exec(
            select(ZenStackDefinition, ZenStackComponent)
            .where(
                ZenStackDefinition.component_type
                == ZenStackComponent.component_type
            )
            .where(
                ZenStackDefinition.component_name == ZenStackComponent.name
            )
            .where(ZenStackDefinition.stack_name == name)
        )
        # Use distinct loop names so the `name` parameter is not shadowed.
        return {
            StackComponentType(component.component_type): component.name
            for _, component in definitions_and_components
        }
get_teams_for_user(self, user_name)
Fetches all teams for a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name |
str |
Name of the user. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.Team] |
List of teams that the user is part of. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user with the given name exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def get_teams_for_user(self, user_name: str) -> List[Team]:
    """Look up every team the given user is a member of.

    Args:
        user_name: Name of the user.

    Returns:
        The teams the user is assigned to (possibly empty).

    Raises:
        KeyError: If no user with the given name exists.
    """
    user_query = select(UserTable.id).where(UserTable.name == user_name)
    with Session(self.engine) as session:
        try:
            user_id = session.exec(user_query).one()
        except NoResultFound as error:
            raise KeyError from error

        membership_query = (
            select(TeamTable)
            .where(TeamTable.id == TeamAssignmentTable.team_id)
            .where(TeamAssignmentTable.user_id == user_id)
        )
        rows = session.exec(membership_query).all()
        return [Team(**row.dict()) for row in rows]
get_users_for_team(self, team_name)
Fetches all users of a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.user_management_models.User] |
List of users that are part of the team. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no team with the given name exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def get_users_for_team(self, team_name: str) -> List[User]:
    """Look up every user that is a member of the given team.

    Args:
        team_name: Name of the team.

    Returns:
        The users assigned to the team (possibly empty).

    Raises:
        KeyError: If no team with the given name exists.
    """
    team_query = select(TeamTable.id).where(TeamTable.name == team_name)
    with Session(self.engine) as session:
        try:
            team_id = session.exec(team_query).one()
        except NoResultFound as error:
            raise KeyError from error

        member_query = (
            select(UserTable)
            .where(UserTable.id == TeamAssignmentTable.user_id)
            .where(TeamAssignmentTable.team_id == team_id)
        )
        rows = session.exec(member_query).all()
        return [User(**row.dict()) for row in rows]
initialize(self, url, *args, **kwargs)
Initialize a new SqlZenStore.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
SQLAlchemy connection URL of the database. |
required |
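*args, **kwargs |
Any |
Additional arguments forwarded to SQLAlchemy's `create_engine` (after removing the base-store options `skip_default_registrations`, `track_analytics` and `skip_migration`) and to the base store's `initialize`. |
optional |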
Returns:
Type | Description |
---|---|
SqlZenStore |
The initialized zen store instance. |
Source code in zenml/zen_stores/sql_zen_store.py
def initialize(
    self,
    url: str,
    *args: Any,
    **kwargs: Any,
) -> "SqlZenStore":
    """Initialize a new SqlZenStore.

    Args:
        url: SQLAlchemy connection URL of the database.
        *args: Additional positional arguments, forwarded to both
            `create_engine` and the base store's `initialize`.
        **kwargs: Additional keyword arguments. The base-store options
            `skip_default_registrations`, `track_analytics` and
            `skip_migration` are removed before calling `create_engine`;
            all keyword arguments are forwarded to the base store's
            `initialize`.

    Returns:
        The initialized zen store instance.

    Raises:
        ValueError: If `url` is not a valid SQL URL.
    """
    if not self.is_valid_url(url):
        raise ValueError(f"Invalid URL for SQL store: {url}")

    logger.debug("Initializing SqlZenStore at %s", url)
    self._url = url

    # For file-backed SQLite URLs, make sure the database directory exists.
    local_path = self.get_path_from_url(url)
    if local_path:
        utils.create_dir_recursive_if_not_exists(str(local_path.parent))

    # These options are meant for the base store only; passing them to the
    # engine factory raises an error, so strip them from its kwargs.
    base_store_options = {
        "skip_default_registrations",
        "track_analytics",
        "skip_migration",
    }
    sql_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key not in base_store_options
    }
    self.engine = create_engine(url, *args, **sql_kwargs)
    SQLModel.metadata.create_all(self.engine)

    # Ensure at least one `ZenUser` row exists.
    with Session(self.engine) as session:
        if not session.exec(select(ZenUser)).first():
            session.add(ZenUser(id=1, name="LocalZenUser"))
        session.commit()

    # NOTE(review): `*args` is forwarded positionally to both `create_engine`
    # and the base `initialize`, whose leading positional parameters are
    # `skip_default_registrations`, `track_analytics` and `skip_migration`.
    # Any positional extras would therefore also be read as those flags —
    # confirm this is intended.
    super().initialize(url, *args, **kwargs)
    return self
is_valid_url(url)
staticmethod
Check if the given url is a valid SQL url.
Source code in zenml/zen_stores/sql_zen_store.py
@staticmethod
def is_valid_url(url: str) -> bool:
    """Report whether SQLAlchemy can parse `url` as a connection URL.

    Args:
        url: The URL to validate.

    Returns:
        `True` if `make_url` accepts the URL, `False` if it raises an
        `ArgumentError`.
    """
    try:
        make_url(url)
    except ArgumentError:
        logger.debug("Invalid SQL URL: %s", url)
        return False
    else:
        return True
register_pipeline_run(self, pipeline_run)
Registers a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run |
PipelineRunWrapper |
The pipeline run to register. |
required |
Exceptions:
Type | Description |
---|---|
EntityExistsError |
If a pipeline run with the same name already exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def register_pipeline_run(
    self,
    pipeline_run: PipelineRunWrapper,
) -> None:
    """Registers a pipeline run.

    Args:
        pipeline_run: The pipeline run to register.

    Raises:
        EntityExistsError: If a pipeline run with the same name already
            exists.
    """
    with Session(self.engine) as session:
        # Run names are checked for uniqueness across all pipelines and
        # projects: the lookup filters on the run name only.
        existing_run = session.exec(
            select(PipelineRunTable).where(
                PipelineRunTable.name == pipeline_run.name
            )
        ).first()
        if existing_run:
            # The trailing space keeps "already" and "exists" separated
            # when the adjacent string literals are concatenated.
            raise EntityExistsError(
                f"Pipeline run with name '{pipeline_run.name}' already "
                "exists. Please make sure your pipeline run names are "
                "unique."
            )
        sql_run = PipelineRunTable.from_pipeline_run_wrapper(pipeline_run)
        session.add(sql_run)
        session.commit()
remove_user_from_team(self, team_name, user_name)
Removes a user from a team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
team_name |
str |
Name of the team. |
required |
user_name |
str |
Name of the user. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If no user or team with the given names exists, or the user is not a member of the team. |
Source code in zenml/zen_stores/sql_zen_store.py
def remove_user_from_team(self, team_name: str, user_name: str) -> None:
    """Remove a user from a team by deleting their team assignment.

    Args:
        team_name: Name of the team.
        user_name: Name of the user.

    Raises:
        KeyError: If no user or team with the given names exists, or the
            user is not a member of the team.
    """
    membership_query = (
        select(TeamAssignmentTable)
        .where(TeamAssignmentTable.team_id == TeamTable.id)
        .where(TeamAssignmentTable.user_id == UserTable.id)
        .where(UserTable.name == user_name)
        .where(TeamTable.name == team_name)
    )
    with Session(self.engine) as session:
        try:
            membership = session.exec(membership_query).one()
        except NoResultFound as error:
            raise KeyError from error
        session.delete(membership)
        session.commit()
revoke_role(self, role_name, entity_name, project_name=None, is_user=True)
Revokes a role from a user or team.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
role_name |
str |
Name of the role to revoke. |
required |
entity_name |
str |
User or team name. |
required |
project_name |
Optional[str] |
Optional project name. |
None |
is_user |
bool |
Boolean indicating whether the given `entity_name` refers to a user (`True`) or a team (`False`). |
True |
Exceptions:
Type | Description |
---|---|
KeyError |
If no role, entity or project with the given names exists. |
Source code in zenml/zen_stores/sql_zen_store.py
def revoke_role(
    self,
    role_name: str,
    entity_name: str,
    project_name: Optional[str] = None,
    is_user: bool = True,
) -> None:
    """Revoke a role from a user or team by deleting the role assignment.

    Args:
        role_name: Name of the role to revoke.
        entity_name: User or team name.
        project_name: Optional project name.
        is_user: Boolean indicating whether the given `entity_name` refers
            to a user (`True`) or a team (`False`).

    Raises:
        KeyError: If no role, entity or project with the given names
            exists, or no matching role assignment exists.
    """
    # Choose which entity table and assignment column to join against.
    if is_user:
        entity_table, entity_column = UserTable, RoleAssignmentTable.user_id
    else:
        entity_table, entity_column = TeamTable, RoleAssignmentTable.team_id

    with Session(self.engine) as session:
        query = (
            select(RoleAssignmentTable)
            .where(RoleAssignmentTable.role_id == RoleTable.id)
            .where(RoleTable.name == role_name)
        )
        if project_name:
            query = query.where(
                RoleAssignmentTable.project_id == ProjectTable.id
            ).where(ProjectTable.name == project_name)
        query = query.where(entity_column == entity_table.id).where(
            entity_table.name == entity_name
        )
        # NOTE(review): without `project_name`, assignments in every project
        # match; if several exist, `.one()` raises `MultipleResultsFound`,
        # which is not translated into a KeyError — confirm intended.
        try:
            assignment = session.exec(query).one()
        except NoResultFound as error:
            raise KeyError from error
        session.delete(assignment)
        session.commit()
ZenStackDefinition (SQLModel)
pydantic-model
Join table between Stacks and StackComponents
Source code in zenml/zen_stores/sql_zen_store.py
class ZenStackDefinition(SQLModel, table=True):
    """Join table between Stacks and StackComponents.

    Each row assigns one stack component, identified by its type and name,
    to one stack. All three columns are marked as primary keys, so together
    they form a composite primary key.
    """
    # Name of the stack; foreign key to `zenstack.name`.
    stack_name: str = Field(primary_key=True, foreign_key="zenstack.name")
    # Type of the assigned component; foreign key to
    # `zenstackcomponent.component_type`.
    component_type: StackComponentType = Field(
        primary_key=True, foreign_key="zenstackcomponent.component_type"
    )
    # Name of the assigned component; foreign key to `zenstackcomponent.name`.
    component_name: str = Field(
        primary_key=True, foreign_key="zenstackcomponent.name"
    )