Core
zenml.core
special
The `core` module is where all the base ZenML functionality is defined, including a Pydantic base class for components, a git wrapper and a class for ZenML's own repository methods.
This module is also where the local service functionality (which keeps track of all the ZenML components) is defined. Every ZenML project has its own ZenML repository, and the `repo` module is where associated methods are defined. The `repo.init_repo` method is where all our functionality is kickstarted when you first initialize everything through the `zenml init` CLI command.
base_component
BaseComponent (BaseSettings)
pydantic-model
Class definition for the base config.
The base component class defines the basic serialization / deserialization of various components used in ZenML. The logic of the serialization / deserialization is as follows:
- If a `uuid` is passed in, then the object is read from a file, so the constructor becomes a query for an object that is assumed to have already been serialized.
- If a `uuid` is NOT passed, then a new object is created with the default args (and any other args that are passed), and therefore a fresh serialization takes place.
Source code in zenml/core/base_component.py
class BaseComponent(BaseSettings):
    """Class definition for the base config.

    The base component class defines the basic serialization / deserialization
    of various components used in ZenML. The logic of the serialization /
    deserialization is as follows:

    * If a `uuid` is passed in, then the object is read from a file, so
      the constructor becomes a query for an object that is assumed to have
      already been serialized.
    * If a `uuid` is NOT passed, then a new object is created with the default
      args (and any other args that are passed), and therefore a fresh
      serialization takes place.
    """

    # Unique id of the component; it also names the serialization file
    # (see `get_serialization_file_name`).
    uuid: Optional[UUID] = Field(default_factory=uuid4)
    # Extension appended to the uuid to form the serialization file name.
    _file_suffix = ".json"
    # Config values that match no field. NOTE(review): the per-instance value
    # is written by `check_superfluous_options` under
    # SUPERFLUOUS_OPTIONS_ATTRIBUTE_NAME, presumably "_superfluous_options";
    # if not, this shared class-level dict is what gets read. Confirm.
    _superfluous_options: Dict[str, Any] = {}
    # Directory of the serialization file; assigned in `__init__`.
    _serialization_dir: str

    def __init__(self, serialization_dir: str, **values: Any):
        """Create the component, loading its values from the settings sources.

        Args:
            serialization_dir: Directory that holds the serialization file.
            **values: Field values forwarded to the pydantic constructor.
        """
        # Here, we monkey patch the `customise_sources` function
        # because we want to dynamically generate the serialization
        # file path and name.
        # NOTE(review): `__config__` belongs to the class, so this patch is
        # presumably shared by all instances; confirm components are never
        # constructed concurrently.
        if hasattr(self, "uuid"):
            self.__config__.customise_sources = generate_customise_sources(  # type: ignore[assignment] # noqa
                serialization_dir,
                self.get_serialization_file_name(),
            )
        elif "uuid" in values:
            # An explicit uuid selects the file of an existing object.
            self.__config__.customise_sources = generate_customise_sources(  # type: ignore[assignment] # noqa
                serialization_dir,
                f"{str(values['uuid'])}{self._file_suffix}",
            )
        else:
            # No uuid available: `get_serialization_file_name` falls back to
            # its DEFAULT file name here.
            self.__config__.customise_sources = generate_customise_sources(  # type: ignore[assignment] # noqa
                serialization_dir,
                self.get_serialization_file_name(),
            )
        # Initialize values from the above sources.
        super().__init__(**values)
        self._serialization_dir = serialization_dir
        self._save_backup_file_if_required()

    def _save_backup_file_if_required(self) -> None:
        """Saves a backup of the config file if the schema changed.

        A schema change is detected through superfluous options. In that case
        the existing file (if any) is copied to `<file>.backup`, and the file
        is then rewritten via `update()`.
        """
        if self._superfluous_options:
            logger.warning(
                "Found superfluous configuration values for class `%s`: %s",
                self.__class__.__name__,
                set(self._superfluous_options),
            )
            config_path = self.get_serialization_full_path()
            if fileio.file_exists(config_path):
                backup_path = config_path + ".backup"
                fileio.copy(config_path, backup_path, overwrite=True)
                logger.warning(
                    "Saving backup configuration to '%s'.", backup_path
                )

            # save the updated file without the extra options
            self.update()

    def _dump(self) -> None:
        """Dumps all current values to the serialization file."""
        self._create_serialization_file_if_not_exists()
        file_path = self.get_serialization_full_path()
        # The superfluous options are excluded so they are not written back.
        file_content = self.json(
            indent=2,
            sort_keys=True,
            exclude={SUPERFLUOUS_OPTIONS_ATTRIBUTE_NAME},
        )
        zenml.io.utils.write_file_contents_as_string(file_path, file_content)

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Removes private attributes from pydantic dict so they don't get
        stored in our config files.

        Args:
            **kwargs: Forwarded unchanged to the pydantic `dict()`.

        Returns:
            The pydantic dict without keys that start with an underscore.
        """
        return {
            key: value
            for key, value in super().dict(**kwargs).items()
            if not key.startswith("_")
        }

    def _create_serialization_file_if_not_exists(self) -> None:
        """Creates the serialization file if it does not exist."""
        f = self.get_serialization_full_path()
        if not fileio.file_exists(str(f)):
            fileio.create_file_if_not_exists(str(f))

    def get_serialization_dir(self) -> str:
        """Return the dir where object is serialized."""
        return self._serialization_dir

    def get_serialization_file_name(self) -> str:
        """Return the name of the file where the object is serialized.

        This has a sane default for cases where `uuid` is not passed
        externally, and therefore reading from a serialized file is not an
        option. However, we still want this function to go through without
        an exception, hence the sane default.
        """
        if hasattr(self, "uuid"):
            return f"{str(self.uuid)}{self._file_suffix}"
        else:
            return f"DEFAULT{self._file_suffix}"

    def get_serialization_full_path(self) -> str:
        """Returns the full path of the serialization file."""
        return os.path.join(
            self._serialization_dir, self.get_serialization_file_name()
        )

    def update(self) -> None:
        """Persist the current state of the component.

        Calling this will result in a persistent, stateful change in the
        system.
        """
        self._dump()

    def delete(self) -> None:
        """Deletes the persisted state of this object."""
        fileio.remove(self.get_serialization_full_path())

    @root_validator(pre=True)
    def check_superfluous_options(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Detects superfluous config values (usually read from an existing
        config file after the schema changed) and saves them in the class's
        `_superfluous_options` attribute.

        Args:
            values: Raw input values before field validation.

        Returns:
            `values` with unknown keys removed and collected under
            SUPERFLUOUS_OPTIONS_ATTRIBUTE_NAME.
        """
        field_names = {field.alias for field in cls.__fields__.values()}

        superfluous_options: Dict[str, Any] = {}
        # Iterate over a snapshot of the keys because `values` is mutated.
        for key in set(values):
            if key not in field_names:
                superfluous_options[key] = values.pop(key)

        values[SUPERFLUOUS_OPTIONS_ATTRIBUTE_NAME] = superfluous_options
        return values

    class Config:
        """Configuration of settings."""

        arbitrary_types_allowed = True
        # Prefix of environment variables that pydantic settings read.
        env_prefix = "zenml_"
        # allow extra options so we can detect legacy configuration files
        extra = "allow"
Config
Configuration of settings.
Source code in zenml/core/base_component.py
class Config:
    """Pydantic settings configuration used by `BaseComponent`."""

    # Environment variables carrying this prefix are read as settings.
    env_prefix = "zenml_"
    # Unknown keys are accepted rather than rejected, so that configuration
    # files written for an older schema can still be detected and migrated.
    extra = "allow"
    arbitrary_types_allowed = True
__init__(self, serialization_dir, **values)
special
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
Source code in zenml/core/base_component.py
def __init__(self, serialization_dir: str, **values: Any):
    """Point the settings sources at this component's file, then build it.

    Args:
        serialization_dir: Directory that holds the serialization file.
        **values: Field values forwarded to the pydantic constructor.
    """
    # `customise_sources` is monkey patched because the path and name of the
    # serialization file are only known at construction time.
    if not hasattr(self, "uuid") and "uuid" in values:
        # An explicit uuid selects the file of an already serialized object.
        file_name = f"{str(values['uuid'])}{self._file_suffix}"
    else:
        file_name = self.get_serialization_file_name()
    self.__config__.customise_sources = generate_customise_sources(  # type: ignore[assignment] # noqa
        serialization_dir,
        file_name,
    )

    # Let pydantic populate the fields from the sources configured above.
    super().__init__(**values)
    self._serialization_dir = serialization_dir
    self._save_backup_file_if_required()
check_superfluous_options(values)
classmethod
Detects superfluous config values (usually read from an existing config file after the schema changed) and saves them in the class's `_superfluous_options` attribute.
Source code in zenml/core/base_component.py
@root_validator(pre=True)
def check_superfluous_options(
    cls, values: Dict[str, Any]
) -> Dict[str, Any]:
    """Move config keys that match no field into a separate dict.

    Such keys usually come from an existing config file written before the
    schema changed. They are popped from `values` and stored under
    `SUPERFLUOUS_OPTIONS_ATTRIBUTE_NAME`.
    """
    known_aliases = {
        model_field.alias for model_field in cls.__fields__.values()
    }
    extras: Dict[str, Any] = {}
    # Walk a snapshot of the keys because `values` is mutated in the loop.
    for option_name in set(values):
        if option_name in known_aliases:
            continue
        extras[option_name] = values.pop(option_name)
    values[SUPERFLUOUS_OPTIONS_ATTRIBUTE_NAME] = extras
    return values
delete(self)
Deletes the persisted state of this object.
Source code in zenml/core/base_component.py
def delete(self) -> None:
    """Remove the file that holds this object's persisted state."""
    full_path = self.get_serialization_full_path()
    fileio.remove(full_path)
dict(self, **kwargs)
Removes private attributes from pydantic dict so they don't get stored in our config files.
Source code in zenml/core/base_component.py
def dict(self, **kwargs: Any) -> Dict[str, Any]:
    """Return the pydantic dict with every underscore-prefixed key dropped.

    This keeps private attributes out of the stored config files.
    """
    raw_items = super().dict(**kwargs).items()
    return {
        name: field_value
        for name, field_value in raw_items
        if not name.startswith("_")
    }
get_serialization_dir(self)
Return the dir where object is serialized.
Source code in zenml/core/base_component.py
def get_serialization_dir(self) -> str:
    """Directory in which this object's serialization file is stored."""
    directory = self._serialization_dir
    return directory
get_serialization_file_name(self)
Return the name of the file where the object is serialized. This has a sane default for cases where `uuid` is not passed externally, and therefore reading from a serialized file is not an option for the table. However, we still want this function to go through without an exception, hence the sane default.
Source code in zenml/core/base_component.py
def get_serialization_file_name(self) -> str:
    """Return the file name under which this object is serialized.

    The name is the object's `uuid` plus `_file_suffix`. When the object
    has no `uuid` attribute yet (e.g. while `__init__` is still choosing
    the settings source), a `DEFAULT` placeholder name is returned, so the
    call never fails.
    """
    stem = str(self.uuid) if hasattr(self, "uuid") else "DEFAULT"
    return f"{stem}{self._file_suffix}"
get_serialization_full_path(self)
Returns the full path of the serialization file.
Source code in zenml/core/base_component.py
def get_serialization_full_path(self) -> str:
    """Join the serialization directory and file name into a single path."""
    directory = self._serialization_dir
    return os.path.join(directory, self.get_serialization_file_name())
update(self)
Persist the current state of the component.
Calling this will result in a persistent, stateful change in the system.
Source code in zenml/core/base_component.py
def update(self) -> None:
    """Write the component's current state to its serialization file.

    This is a persistent, stateful change to the system.
    """
    dump = self._dump
    dump()
component_factory
Factory to register all components.
ComponentFactory
Definition of ComponentFactory to track all BaseComponent subclasses.
All BaseComponents (including custom ones) are to be registered here.
Source code in zenml/core/component_factory.py
class ComponentFactory:
    """Registry that maps keys to `BaseComponent` subclasses.

    All BaseComponents (including custom ones) are to be registered here.
    """

    def __init__(self, name: str):
        """Create an empty factory.

        Args:
            name: Unique name for the factory.
        """
        self.name = name
        self.components: Dict[str, BaseComponentType] = {}

    def get_components(self) -> Dict[str, BaseComponentType]:
        """Return the mapping of every registered key to its component."""
        registry = self.components
        return registry

    def get_single_component(self, key: str) -> BaseComponentType:
        """Look up a registered component by its key.

        Raises:
            KeyError: If nothing is registered under `key`.
        """
        if key not in self.components:
            raise KeyError(
                f"Type '{key}' does not exist! Available options: "
                f"{[str(k) for k in self.components.keys()]}"
            )
        return self.components[key]

    def get_component_key(self, component: BaseComponentType) -> str:
        """Return the key under which `component` is registered.

        Raises:
            KeyError: If `component` is not registered.
        """
        not_found = object()
        found = next(
            (
                registered_key
                for registered_key, registered in self.components.items()
                if registered == component
            ),
            not_found,
        )
        if found is not not_found:
            return found
        raise KeyError(
            f"Type '{component}' does not exist! Available options: "
            f"{[str(v) for v in self.components.values()]}"
        )

    def register_component(
        self, key: str, component: BaseComponentType
    ) -> None:
        """Store `component` under `key` (converted with `str()`)."""
        normalized_key = str(key)
        self.components[normalized_key] = component

    def register(
        self, name: str
    ) -> Callable[[BaseComponentType], BaseComponentType]:
        """Class decorator that registers a component class under `name`.

        Args:
            name: The name of the component.

        Returns:
            A function which registers the class at this ComponentFactory.
        """
        factory = self

        def _register_class(
            wrapped_class: BaseComponentType,
        ) -> BaseComponentType:
            """Register `wrapped_class` and hand it back unchanged."""
            if name in factory.components:
                logger.debug(
                    "Executor %s already exists for factory %s, replacing it..",
                    name,
                    factory.name,
                )
            factory.register_component(name, wrapped_class)
            return wrapped_class

        return _register_class
__init__(self, name)
special
Constructor for the factory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Unique name for the factory. |
required |
Source code in zenml/core/component_factory.py
def __init__(self, name: str):
    """Create an empty factory.

    Args:
        name: Unique name for the factory.
    """
    self.components: Dict[str, BaseComponentType] = {}
    self.name = name
get_component_key(self, component)
Gets the key of a registered component.
Source code in zenml/core/component_factory.py
def get_component_key(self, component: BaseComponentType) -> str:
    """Return the key under which `component` is registered.

    Raises:
        KeyError: If `component` is not registered.
    """
    not_found = object()
    found = next(
        (
            registered_key
            for registered_key, registered in self.components.items()
            if registered == component
        ),
        not_found,
    )
    if found is not not_found:
        return found
    raise KeyError(
        f"Type '{component}' does not exist! Available options: "
        f"{[str(v) for v in self.components.values()]}"
    )
get_components(self)
Return all components
Source code in zenml/core/component_factory.py
def get_components(self) -> Dict[str, BaseComponentType]:
    """Return the mapping of every registered key to its component."""
    registry = self.components
    return registry
get_single_component(self, key)
Get a registered component from a key.
Source code in zenml/core/component_factory.py
def get_single_component(self, key: str) -> BaseComponentType:
    """Look up a registered component by its key.

    Raises:
        KeyError: If nothing is registered under `key`.
    """
    if key not in self.components:
        raise KeyError(
            f"Type '{key}' does not exist! Available options: "
            f"{[str(k) for k in self.components.keys()]}"
        )
    return self.components[key]
register(self, name)
Class decorator to register component classes to the internal registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the component. |
required |
Returns:
Type | Description |
---|---|
Callable[[Type[zenml.core.base_component.BaseComponent]], Type[zenml.core.base_component.BaseComponent]] |
A function which registers the class at this ComponentFactory. |
Source code in zenml/core/component_factory.py
def register(
    self, name: str
) -> Callable[[BaseComponentType], BaseComponentType]:
    """Class decorator that registers a component class under `name`.

    Args:
        name: The name of the component.

    Returns:
        A function which registers the class at this ComponentFactory.
    """
    factory = self

    def _register_class(
        wrapped_class: BaseComponentType,
    ) -> BaseComponentType:
        """Register `wrapped_class` and hand it back unchanged."""
        if name in factory.components:
            logger.debug(
                "Executor %s already exists for factory %s, replacing it..",
                name,
                factory.name,
            )
        factory.register_component(name, wrapped_class)
        return wrapped_class

    return _register_class
register_component(self, key, component)
Registers a single component class for a given key.
Source code in zenml/core/component_factory.py
def register_component(
    self, key: str, component: BaseComponentType
) -> None:
    """Store `component` under `key` (converted with `str()`)."""
    normalized_key = str(key)
    self.components[normalized_key] = component
git_wrapper
Wrapper class to handle Git integration
GitWrapper
Wrapper class for Git.
This class is responsible for handling git interactions, primarily handling versioning of different steps in pipelines.
Source code in zenml/core/git_wrapper.py
class GitWrapper:
    """Wrapper class for Git.

    This class is responsible for handling git interactions, primarily
    handling versioning of different steps in pipelines.
    """

    def __init__(self, repo_path: str):
        """Initialize GitWrapper. Should be initialized by ZenML Repository.

        Args:
            repo_path: Path to the root of the git repository.

        Raises:
            InvalidGitRepositoryError: If repository is not a git repository.
            NoSuchPathError: If the repo_path does not exist.
        """
        # TODO [ENG-163]: Raise ZenML exceptions here instead.
        self.repo_path: str = repo_path
        self.git_root_path: str = os.path.join(repo_path, GIT_FOLDER_NAME)
        self.git_repo = GitRepo(self.repo_path)

    def check_file_committed(self, file_path: str) -> bool:
        """Checks whether a file has no pending changes.

        A file counts as not committed if it has unstaged modifications,
        staged but uncommitted changes, or is untracked (and not ignored).

        Args:
            file_path: Path to any file within the ZenML repo, in the same
                repository-relative form git reports.

        Returns:
            True if the file is committed, else False.
        """
        # Unstaged modifications (working tree vs. index).
        changed = {d.a_path for d in self.git_repo.index.diff(None)}
        try:
            # Staged modifications (index vs. HEAD).
            staged = {d.a_path for d in self.git_repo.index.diff("HEAD")}
        except BadName:
            # for Ref 'HEAD' did not resolve to an object
            logger.debug("No committed files in the repo. No staged files.")
            staged = set()
        # source: https://stackoverflow.com/questions/3801321/
        untracked = self.git_repo.git.ls_files(
            others=True, exclude_standard=True
        ).split("\n")
        return file_path not in changed.union(staged, untracked)

    def get_current_sha(self) -> str:
        """Returns the hex SHA of the commit that HEAD currently points to."""
        return cast(str, self.git_repo.head.object.hexsha)

    def check_module_clean(self, source: str) -> bool:
        """Returns `True` if all files within source's module are committed.

        Args:
            source: relative module path pointing to a Class.
        """
        # Get the module path
        module_path = source_utils.get_module_source_from_source(source)

        # Get relative path of module because check_file_committed needs that
        module_dir = source_utils.get_relative_path_from_module_source(
            module_path
        )

        # Get absolute path of module because fileio.list_dir needs that
        mod_abs_dir = source_utils.get_absolute_path_from_module_source(
            module_path
        )
        module_file_names = fileio.list_dir(mod_abs_dir, only_file_names=True)

        # Go through each file in module and see if there are uncommitted ones
        for file_path in module_file_names:
            path = os.path.join(module_dir, file_path)

            # if its .gitignored then continue and don't do anything
            if len(self.git_repo.ignored(path)) > 0:
                continue

            if fileio.is_dir(os.path.join(mod_abs_dir, file_path)):
                logger.warning(
                    f"The step {source} is contained inside a module "
                    f"that "
                    f"has sub-directories (the sub-directory {file_path} at "
                    f"{mod_abs_dir}). For now, ZenML supports only a flat "
                    f"directory structure in which to place Steps. Please make"
                    f" sure that the Step does not utilize the sub-directory."
                )
            if not self.check_file_committed(path):
                return False
        return True

    def stash(self) -> None:
        """Wrapper for `git stash`."""
        git = self.git_repo.git
        git.stash()

    def stash_pop(self) -> None:
        """Wrapper for `git stash pop`. Only pops if there's something to pop."""
        git = self.git_repo.git
        if git.stash("list") != "":
            git.stash("pop")

    def checkout(
        self,
        sha_or_branch: Optional[str] = None,
        directory: Optional[str] = None,
    ) -> None:
        """Wrapper for `git checkout`.

        Args:
            sha_or_branch: hex string of len 40 representing git sha OR
                name of branch
            directory: relative path to directory to scope checkout
        """
        # TODO [ENG-164]: Implement exception handling
        git = self.git_repo.git
        if sha_or_branch is None:
            # `git checkout -- <directory>`: restores the directory from the
            # index, discarding its unstaged changes.
            assert directory is not None
            git.checkout("--", directory)
        elif directory is not None:
            # `git checkout <sha> -- <directory>`: the directory as it was
            # at sha_or_branch.
            git.checkout(sha_or_branch, "--", directory)
        else:
            # The whole repo is checked out at sha_or_branch.
            git.checkout(sha_or_branch)

    def reset(self, directory: Optional[str] = None) -> None:
        """Wrapper for `git reset HEAD <directory>`.

        Args:
            directory: Relative path to directory to scope checkout
        """
        git = self.git_repo.git
        git.reset("HEAD", directory)

    def resolve_class_source(self, class_source: str) -> str:
        """Resolves class_source with an optional pin.

        Takes source (e.g. this.module.ClassName), and appends relevant
        sha to it if the files within `module` are all committed. If even one
        file is not committed, then returns `source` unchanged.

        Args:
            class_source (str): class_source e.g. this.module.Class
        """
        if "@" in class_source:
            # already pinned
            return class_source

        if is_standard_source(class_source):
            # that means use standard version
            return resolve_standard_source(class_source)

        # otherwise use Git resolution
        if not self.check_module_clean(class_source):
            # Return the source path if not clean
            logger.warning(
                "Found uncommitted file. Pipelines run with this "
                "configuration may not be reproducible. Please commit "
                "all files in this module and then run the pipeline to "
                "ensure reproducibility."
            )
            return class_source
        return class_source + "@" + self.get_current_sha()

    def is_valid_source(self, source: str) -> bool:
        """Checks whether the source_path is valid or not.

        Args:
            source (str): class_source e.g. this.module.Class[@pin].
        """
        try:
            self.load_source_path_class(source)
        except GitException:
            return False
        return True

    def load_source_path_class(self, source: str) -> Type[Any]:
        """Loads a Python class from the source.

        Args:
            source: class_source e.g. this.module.Class[@sha]

        Returns:
            The imported class.

        Raises:
            GitException: If the source is pinned to a git sha and its module
                has uncommitted files, or if importing the class from the
                checked-out git history fails.
        """
        # Parse the pin from the *original* string. Overwriting `source` with
        # its unpinned part first (as before) made `pin` equal the whole
        # class path and `"@" in source` always False, so both pinned
        # branches were unreachable.
        parts = source.split("@")
        class_path, pin = parts[0], parts[-1]
        is_pinned = len(parts) > 1
        is_standard = is_pinned and is_standard_pin(pin)

        if is_pinned and not is_standard:
            logger.debug(
                "Pinned step found with git sha. "
                "Loading class from git history."
            )

            module_source = get_module_source_from_source(class_path)
            relative_module_path = get_relative_path_from_module_source(
                module_source
            )

            logger.warning(
                "Found source with a pinned sha. Will now checkout "
                f"module: {module_source}"
            )

            # critical step: refuse to check out history over local changes
            if not self.check_module_clean(class_path):
                raise GitException(
                    f"One of the files at {relative_module_path} "
                    f"is not committed and we "
                    f"are trying to load that directory from git "
                    f"history due to a pinned step in the pipeline. "
                    f"Please commit the file and then run the "
                    f"pipeline."
                )

            # Check out the directory at that sha
            self.checkout(sha_or_branch=pin, directory=relative_module_path)

            # After this point, all exceptions will first undo the above
            try:
                class_ = source_utils.import_class_by_path(class_path)
                self.reset(relative_module_path)
                self.checkout(directory=relative_module_path)
            except Exception as e:
                self.reset(relative_module_path)
                self.checkout(directory=relative_module_path)
                raise GitException(
                    f"A git exception occurred when checking out repository "
                    f"from git history. Resetting repository to original "
                    f"state. Original exception: {e}"
                ) from e

        elif is_pinned:
            logger.debug(f"Default {APP_NAME} class used. Loading directly.")
            # TODO [ENG-165]: Check if ZenML version is installed before loading.
            class_ = source_utils.import_class_by_path(class_path)
        else:
            logger.debug(
                "Unpinned step found with no git sha. Attempting to "
                "load class from current repository state."
            )
            class_ = source_utils.import_class_by_path(class_path)

        return class_

    def resolve_class(self, class_: Type[Any]) -> str:
        """Resolves a class into a serializable source string.

        Args:
            class_: A Python Class reference.

        Returns: source_path e.g. this.module.Class[@pin].
        """
        class_source = source_utils.resolve_class(class_)
        return self.resolve_class_source(class_source)
__init__(self, repo_path)
special
Initialize GitWrapper. Should be initialized by ZenML Repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
repo_path |
str |
Path to the root of the git repository. |
required |
Exceptions:
Type | Description |
---|---|
InvalidGitRepositoryError |
If repository is not a git repository. |
NoSuchPathError |
If the repo_path does not exist. |
Source code in zenml/core/git_wrapper.py
def __init__(self, repo_path: str):
    """Initialize GitWrapper. Should be initialized by ZenML Repository.

    Args:
        repo_path: Path to the root of the git repository.

    Raises:
        InvalidGitRepositoryError: If repository is not a git repository.
        NoSuchPathError: If the repo_path does not exist.
    """
    # TODO [ENG-163]: Raise ZenML exceptions here instead.
    self.repo_path: str = repo_path
    git_folder = os.path.join(repo_path, GIT_FOLDER_NAME)
    self.git_root_path: str = git_folder
    self.git_repo = GitRepo(repo_path)
check_file_committed(self, file_path)
Checks file is committed. If yes, return True, else False.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to any file within the ZenML repo. |
required |
Source code in zenml/core/git_wrapper.py
def check_file_committed(self, file_path: str) -> bool:
    """Report whether a file has no pending changes.

    A file counts as not committed if it has unstaged modifications, staged
    but uncommitted changes, or is untracked (and not ignored).

    Args:
        file_path (str): Path to any file within the ZenML repo.

    Returns:
        True if the file is committed, else False.
    """
    # Unstaged modifications (working tree vs. index).
    changed = {d.a_path for d in self.git_repo.index.diff(None)}
    try:
        # Staged modifications (index vs. HEAD).
        staged = {d.a_path for d in self.git_repo.index.diff("HEAD")}
    except BadName:
        # for Ref 'HEAD' did not resolve to an object
        logger.debug("No committed files in the repo. No staged files.")
        staged = set()
    # source: https://stackoverflow.com/questions/3801321/
    untracked = self.git_repo.git.ls_files(
        others=True, exclude_standard=True
    ).split("\n")
    return file_path not in changed.union(staged, untracked)
check_module_clean(self, source)
Returns `True` if all files within source's module are committed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
relative module path pointing to a Class. |
required |
Source code in zenml/core/git_wrapper.py
def check_module_clean(self, source: str) -> bool:
    """Return `True` when every file in source's module is committed.

    Git-ignored entries are skipped. Sub-directories trigger a warning,
    because only a flat module layout is supported.

    Args:
        source: relative module path pointing to a Class.
    """
    module_source = source_utils.get_module_source_from_source(source)
    # check_file_committed compares repository-relative paths.
    relative_dir = source_utils.get_relative_path_from_module_source(
        module_source
    )
    # fileio.list_dir needs the absolute location of the module.
    absolute_dir = source_utils.get_absolute_path_from_module_source(
        module_source
    )

    for entry in fileio.list_dir(absolute_dir, only_file_names=True):
        relative_path = os.path.join(relative_dir, entry)
        if len(self.git_repo.ignored(relative_path)) > 0:
            # .gitignored entries do not affect reproducibility.
            continue
        if fileio.is_dir(os.path.join(absolute_dir, entry)):
            logger.warning(
                f"The step {source} is contained inside a module that "
                f"has sub-directories (the sub-directory {entry} at "
                f"{absolute_dir}). For now, ZenML supports only a flat "
                f"directory structure in which to place Steps. Please make"
                f" sure that the Step does not utilize the sub-directory."
            )
        if not self.check_file_committed(relative_path):
            return False
    return True
checkout(self, sha_or_branch=None, directory=None)
Wrapper for git checkout
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sha_or_branch |
Optional[str] |
hex string of len 40 representing git sha OR name of branch |
None |
directory |
Optional[str] |
relative path to directory to scope checkout |
None |
Source code in zenml/core/git_wrapper.py
def checkout(
    self,
    sha_or_branch: Optional[str] = None,
    directory: Optional[str] = None,
) -> None:
    """Wrapper for `git checkout`.

    * only `directory`: `git checkout -- <directory>`, restoring the
      directory from the index (discarding its unstaged changes);
    * both: `git checkout <sha_or_branch> -- <directory>`, the directory as
      it was at `sha_or_branch`;
    * only `sha_or_branch`: `git checkout <sha_or_branch>` for the whole repo.

    Args:
        sha_or_branch: hex string of len 40 representing git sha OR
            name of branch
        directory: relative path to directory to scope checkout
    """
    # TODO [ENG-164]: Implement exception handling
    git = self.git_repo.git
    if sha_or_branch is not None:
        if directory is None:
            git.checkout(sha_or_branch)
        else:
            git.checkout(sha_or_branch, "--", directory)
        return
    # Without a revision, a directory to restore is required.
    assert directory is not None
    git.checkout("--", directory)
get_current_sha(self)
Returns the hex SHA of the commit that HEAD currently points to.
Source code in zenml/core/git_wrapper.py
def get_current_sha(self) -> str:
    """Return the hex SHA of the commit that HEAD currently points to."""
    head_commit = self.git_repo.head.object
    return cast(str, head_commit.hexsha)
is_valid_source(self, source)
Checks whether the source_path is valid or not.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
class_source e.g. this.module.Class[@pin]. |
required |
Source code in zenml/core/git_wrapper.py
def is_valid_source(self, source: str) -> bool:
    """Return whether `source` can be loaded.

    Only a `GitException` from loading counts as "invalid"; any other
    exception propagates to the caller.

    Args:
        source (str): class_source e.g. this.module.Class[@pin].
    """
    try:
        self.load_source_path_class(source)
    except GitException:
        return False
    else:
        return True
load_source_path_class(self, source)
Loads a Python class from the source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
class_source e.g. this.module.Class[@sha] |
required |
Source code in zenml/core/git_wrapper.py
def load_source_path_class(self, source: str) -> Type[Any]:
    """Loads a Python class from the source.

    Args:
        source: class_source e.g. this.module.Class[@sha]

    Returns:
        The imported class.

    Raises:
        GitException: If the source is pinned to a git sha and its module
            has uncommitted files, or if importing the class from the
            checked-out git history fails.
    """
    # Parse the pin from the *original* string. Overwriting `source` with
    # its unpinned part first (as before) made `pin` equal the whole class
    # path and `"@" in source` always False, so both pinned branches were
    # unreachable.
    parts = source.split("@")
    class_path, pin = parts[0], parts[-1]
    is_pinned = len(parts) > 1
    is_standard = is_pinned and is_standard_pin(pin)

    if is_pinned and not is_standard:
        logger.debug(
            "Pinned step found with git sha. "
            "Loading class from git history."
        )

        module_source = get_module_source_from_source(class_path)
        relative_module_path = get_relative_path_from_module_source(
            module_source
        )

        logger.warning(
            "Found source with a pinned sha. Will now checkout "
            f"module: {module_source}"
        )

        # critical step: refuse to check out history over local changes
        if not self.check_module_clean(class_path):
            raise GitException(
                f"One of the files at {relative_module_path} "
                f"is not committed and we "
                f"are trying to load that directory from git "
                f"history due to a pinned step in the pipeline. "
                f"Please commit the file and then run the "
                f"pipeline."
            )

        # Check out the directory at that sha
        self.checkout(sha_or_branch=pin, directory=relative_module_path)

        # After this point, all exceptions will first undo the above.
        # NOTE(review): if the module is already in sys.modules the import
        # may return the cached version; confirm import_class_by_path
        # reloads it.
        try:
            class_ = source_utils.import_class_by_path(class_path)
            self.reset(relative_module_path)
            self.checkout(directory=relative_module_path)
        except Exception as e:
            self.reset(relative_module_path)
            self.checkout(directory=relative_module_path)
            raise GitException(
                f"A git exception occurred when checking out repository "
                f"from git history. Resetting repository to original "
                f"state. Original exception: {e}"
            ) from e

    elif is_pinned:
        logger.debug(f"Default {APP_NAME} class used. Loading directly.")
        # TODO [ENG-165]: Check if ZenML version is installed before loading.
        class_ = source_utils.import_class_by_path(class_path)
    else:
        logger.debug(
            "Unpinned step found with no git sha. Attempting to "
            "load class from current repository state."
        )
        class_ = source_utils.import_class_by_path(class_path)

    return class_
reset(self, directory=None)
Wrapper for `git reset HEAD <directory>`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
directory |
Optional[str] |
Relative path to directory to scope checkout |
None |
Source code in zenml/core/git_wrapper.py
def reset(self, directory: Optional[str] = None) -> None:
    """Wrapper for `git reset HEAD <directory>`.

    Args:
        directory: Relative path to directory to scope checkout
    """
    self.git_repo.git.reset("HEAD", directory)
resolve_class(self, class_)
Resolves a class into a serializable source string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_ |
Type[Any] |
A Python Class reference. |
required |
Returns: source_path e.g. this.module.Class[@pin].
Source code in zenml/core/git_wrapper.py
def resolve_class(self, class_: Type[Any]) -> str:
    """Turn a class object into a serializable (optionally pinned) source.

    The class is first converted to its plain source path by
    `source_utils.resolve_class`; that path is then pinned (or not) by
    `resolve_class_source`.

    Args:
        class_: A Python Class reference.

    Returns:
        source_path e.g. this.module.Class[@pin].
    """
    return self.resolve_class_source(source_utils.resolve_class(class_))
resolve_class_source(self, class_source)
Resolves class_source with an optional pin.
Takes source (e.g. `this.module.ClassName`) and appends the relevant
sha to it if the files within `module` are all committed. If even one
file is not committed, the source is returned unchanged.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_source |
str |
class_source e.g. this.module.Class |
required |
Source code in zenml/core/git_wrapper.py
def resolve_class_source(self, class_source: str) -> str:
    """Resolves class_source with an optional pin.

    Takes source (e.g. this.module.ClassName), and appends relevant
    sha to it if the files within `module` are all committed. If even one
    file is not committed, then returns `source` unchanged.

    Args:
        class_source (str): class_source e.g. this.module.Class

    Returns:
        `class_source` unchanged if it is already pinned or its module has
        uncommitted files; the output of `resolve_standard_source` for
        standard sources; otherwise `class_source@<current sha>`.
    """
    # A source that already carries an `@` pin is left untouched.
    if "@" in class_source:
        return class_source

    # Standard sources are resolved without consulting git.
    if is_standard_source(class_source):
        return resolve_standard_source(class_source)

    # Only pin to the current commit when the module is fully committed.
    if self.check_module_clean(class_source):
        return class_source + "@" + self.get_current_sha()

    logger.warning(
        "Found uncommitted file. Pipelines run with this "
        "configuration may not be reproducible. Please commit "
        "all files in this module and then run the pipeline to "
        "ensure reproducibility."
    )
    return class_source
stash(self)
Wrapper for `git stash`.
Source code in zenml/core/git_wrapper.py
def stash(self) -> None:
    """Run `git stash` on the wrapped repository."""
    self.git_repo.git.stash()
stash_pop(self)
Wrapper for `git stash pop`. Only pops if there's something to pop.
Source code in zenml/core/git_wrapper.py
def stash_pop(self) -> None:
    """Run `git stash pop`, but only when `git stash list` is non-empty."""
    git = self.git_repo.git
    stash_listing = git.stash("list")
    # An empty listing means there is nothing to pop.
    if stash_listing == "":
        return
    git.stash("pop")
local_service
LocalService (BaseComponent)
pydantic-model
Definition of a local service that keeps track of all ZenML components.
Source code in zenml/core/local_service.py
class LocalService(BaseComponent):
    """Definition of a local service that keeps track of all ZenML
    components.

    Stacks are stored inline, keyed by name. Metadata stores, artifact
    stores, orchestrators and container registries are stored as maps from
    key to `UUIDSourceTuple` and are instantiated on access through
    `mapping_utils`.
    """

    stacks: Dict[str, BaseStack] = {}
    active_stack_key: str = "local_stack"
    metadata_store_map: Dict[str, UUIDSourceTuple] = {}
    artifact_store_map: Dict[str, UUIDSourceTuple] = {}
    orchestrator_map: Dict[str, UUIDSourceTuple] = {}
    container_registry_map: Dict[str, UUIDSourceTuple] = {}

    _LOCAL_SERVICE_FILE_NAME = "zenservice.json"

    def __init__(self, repo_path: str, **kwargs: Any) -> None:
        """Initializes a LocalService instance.

        Args:
            repo_path: Path to the repository of this service.
            **kwargs: Forwarded to `BaseComponent.__init__`.
        """
        serialization_dir = zenml.io.utils.get_zenml_config_dir(repo_path)
        super().__init__(serialization_dir=serialization_dir, **kwargs)
        self._repo_path = repo_path
        # Propagate the repository path to every registered stack.
        for stack in self.stacks.values():
            stack._repo_path = repo_path

    def get_serialization_file_name(self) -> str:
        """Return the name of the file where object is serialized."""
        return self._LOCAL_SERVICE_FILE_NAME

    @property
    def metadata_stores(self) -> Dict[str, "BaseMetadataStore"]:
        """Returns all registered metadata stores."""
        from zenml.metadata_stores import BaseMetadataStore

        return mapping_utils.get_components_from_store(  # type: ignore[return-value] # noqa
            BaseMetadataStore._METADATA_STORE_DIR_NAME,
            self.metadata_store_map,
            self._repo_path,
        )

    @property
    def artifact_stores(self) -> Dict[str, "BaseArtifactStore"]:
        """Returns all registered artifact stores."""
        from zenml.artifact_stores import BaseArtifactStore

        return mapping_utils.get_components_from_store(  # type: ignore[return-value] # noqa
            BaseArtifactStore._ARTIFACT_STORE_DIR_NAME,
            self.artifact_store_map,
            self._repo_path,
        )

    @property
    def orchestrators(self) -> Dict[str, "BaseOrchestrator"]:
        """Returns all registered orchestrators."""
        from zenml.orchestrators import BaseOrchestrator

        return mapping_utils.get_components_from_store(  # type: ignore[return-value] # noqa
            BaseOrchestrator._ORCHESTRATOR_STORE_DIR_NAME,
            self.orchestrator_map,
            self._repo_path,
        )

    @property
    def container_registries(self) -> Dict[str, "BaseContainerRegistry"]:
        """Returns all registered container registries."""
        from zenml.container_registries import BaseContainerRegistry

        return mapping_utils.get_components_from_store(  # type: ignore[return-value] # noqa
            BaseContainerRegistry._CONTAINER_REGISTRY_DIR_NAME,
            self.container_registry_map,
            self._repo_path,
        )

    def get_active_stack_key(self) -> str:
        """Returns the active stack key."""
        return self.active_stack_key

    def set_active_stack_key(self, stack_key: str) -> None:
        """Sets the active stack key.

        Args:
            stack_key: Key of a registered stack.

        Raises:
            DoesNotExistException: If no stack is registered under the key.
        """
        if stack_key not in self.stacks:
            raise DoesNotExistException(
                f"Unable to set active stack for key `{stack_key}` because no "
                f"stack is registered for this key. Available keys: "
                f"{set(self.stacks)}"
            )
        self.active_stack_key = stack_key
        self.update()

    def get_stack(self, key: str) -> BaseStack:
        """Return a single stack based on key.

        Args:
            key: Unique key of stack.

        Returns:
            Stack specified by key.

        Raises:
            DoesNotExistException: If no stack is registered under the key.
        """
        logger.debug(f"Fetching stack with key {key}")
        if key not in self.stacks:
            raise DoesNotExistException(
                f"Stack of key `{key}` does not exist. "
                f"Available keys: {list(self.stacks.keys())}"
            )
        return self.stacks[key]

    @track(event=REGISTERED_STACK)
    def register_stack(self, key: str, stack: BaseStack) -> None:
        """Register a stack.

        Args:
            key: Unique key for the stack.
            stack: Stack to be registered.

        Raises:
            DoesNotExistException: If one of the stack's components is not
                registered.
            AlreadyExistsException: If a stack with this key exists.
        """
        logger.debug(
            f"Registering stack with key {key}, details: {stack.dict()}"
        )

        # Check if the individual components actually exist.
        # TODO [ENG-190]: Add tests to check cases of registering a stack with a
        # non-existing individual component. We can also improve the error
        # logging for the CLI while we're at it.
        self.get_orchestrator(stack.orchestrator_name)
        self.get_artifact_store(stack.artifact_store_name)
        self.get_metadata_store(stack.metadata_store_name)
        if stack.container_registry_name:
            self.get_container_registry(stack.container_registry_name)

        if key in self.stacks:
            raise AlreadyExistsException(
                message=f"Stack `{key}` already exists!"
            )

        # Add the mapping.
        self.stacks[key] = stack
        self.update()

    def delete_stack(self, key: str) -> None:
        """Delete a stack specified with a key.

        Only the reference held by this service is removed; the resources
        behind the stack's components are left untouched.

        Args:
            key: Unique key of stack.

        Raises:
            DoesNotExistException: If no stack is registered under the key.
        """
        _ = self.get_stack(key)  # check whether it exists
        del self.stacks[key]
        self.update()
        logger.debug(f"Deleted stack with key: {key}.")
        logger.info(
            "Deleting a stack currently does not delete the underlying "
            "architecture of the stack. It just deletes the reference to it. "
            "Therefore please make sure to delete these resources on your "
            "own. Also, if this stack was the active stack, please make sure "
            "to set a new active stack via `zenml stack set`."
        )

    def get_artifact_store(self, key: str) -> "BaseArtifactStore":
        """Return a single artifact store based on key.

        Args:
            key: Unique key of artifact store.

        Returns:
            Artifact store specified by key.

        Raises:
            DoesNotExistException: If no artifact store is registered under
                the key.
        """
        logger.debug(f"Fetching artifact_store with key {key}")
        if key not in self.artifact_store_map:
            # Previously mislabelled as "Stack of key ..."; this lookup is for
            # artifact stores, consistent with the other component getters.
            raise DoesNotExistException(
                f"Artifact store of key `{key}` does not exist. "
                f"Available keys: {list(self.artifact_store_map.keys())}"
            )
        return mapping_utils.get_component_from_key(  # type: ignore[return-value] # noqa
            key, self.artifact_store_map, self._repo_path
        )

    def register_artifact_store(
        self, key: str, artifact_store: "BaseArtifactStore"
    ) -> None:
        """Register an artifact store.

        Args:
            artifact_store: Artifact store to be registered.
            key: Unique key for the artifact store.

        Raises:
            AlreadyExistsException: If an artifact store with this key exists.
        """
        logger.debug(
            f"Registering artifact store with key {key}, details: "
            f"{artifact_store.dict()}"
        )
        if key in self.artifact_store_map:
            raise AlreadyExistsException(
                message=f"Artifact Store `{key}` already exists!"
            )

        # Add the mapping.
        artifact_store.update()
        source = source_utils.resolve_class(artifact_store.__class__)
        self.artifact_store_map[key] = UUIDSourceTuple(
            uuid=artifact_store.uuid, source=source
        )
        self.update()

        # Telemetry
        from zenml.core.component_factory import artifact_store_factory

        track_event(
            REGISTERED_ARTIFACT_STORE,
            {
                "type": artifact_store_factory.get_component_key(
                    artifact_store.__class__
                )
            },
        )

    def delete_artifact_store(self, key: str) -> None:
        """Delete an artifact store.

        Args:
            key: Unique key of the artifact store.

        Raises:
            DoesNotExistException: If no artifact store is registered under
                the key.
        """
        s = self.get_artifact_store(key)  # check whether it exists
        s.delete()
        del self.artifact_store_map[key]
        self.update()
        logger.debug(f"Deleted artifact_store with key: {key}.")

    def get_metadata_store(self, key: str) -> "BaseMetadataStore":
        """Return a single metadata store based on key.

        Args:
            key: Unique key of metadata store.

        Returns:
            Metadata store specified by key.

        Raises:
            DoesNotExistException: If no metadata store is registered under
                the key.
        """
        logger.debug(f"Fetching metadata store with key {key}")
        if key not in self.metadata_store_map:
            raise DoesNotExistException(
                f"Metadata store of key `{key}` does not exist. "
                f"Available keys: {list(self.metadata_store_map.keys())}"
            )
        return mapping_utils.get_component_from_key(  # type: ignore[return-value] # noqa
            key, self.metadata_store_map, self._repo_path
        )

    def register_metadata_store(
        self, key: str, metadata_store: "BaseMetadataStore"
    ) -> None:
        """Register a metadata store.

        Args:
            metadata_store: Metadata store to be registered.
            key: Unique key for the metadata store.

        Raises:
            AlreadyExistsException: If a metadata store with this key exists.
        """
        logger.debug(
            f"Registering metadata store with key {key}, details: "
            f"{metadata_store.dict()}"
        )
        if key in self.metadata_store_map:
            raise AlreadyExistsException(
                message=f"Metadata store `{key}` already exists!"
            )

        # Add the mapping.
        metadata_store.update()
        source = source_utils.resolve_class(metadata_store.__class__)
        self.metadata_store_map[key] = UUIDSourceTuple(
            uuid=metadata_store.uuid, source=source
        )
        self.update()

        # Telemetry
        from zenml.core.component_factory import metadata_store_factory

        track_event(
            REGISTERED_METADATA_STORE,
            {
                "type": metadata_store_factory.get_component_key(
                    metadata_store.__class__
                )
            },
        )

    def delete_metadata_store(self, key: str) -> None:
        """Delete a metadata store.

        Args:
            key: Unique key of metadata store.

        Raises:
            DoesNotExistException: If no metadata store is registered under
                the key.
        """
        s = self.get_metadata_store(key)  # check whether it exists
        s.delete()
        del self.metadata_store_map[key]
        self.update()
        logger.debug(f"Deleted metadata store with key: {key}.")

    def get_orchestrator(self, key: str) -> "BaseOrchestrator":
        """Return a single orchestrator based on key.

        Args:
            key: Unique key of orchestrator.

        Returns:
            Orchestrator specified by key.

        Raises:
            DoesNotExistException: If no orchestrator is registered under the
                key.
        """
        logger.debug(f"Fetching orchestrator with key {key}")
        if key not in self.orchestrator_map:
            raise DoesNotExistException(
                f"Orchestrator of key `{key}` does not exist. "
                f"Available keys: {list(self.orchestrator_map.keys())}"
            )
        return mapping_utils.get_component_from_key(  # type: ignore[return-value] # noqa
            key, self.orchestrator_map, self._repo_path
        )

    def register_orchestrator(
        self, key: str, orchestrator: "BaseOrchestrator"
    ) -> None:
        """Register an orchestrator.

        Args:
            orchestrator: Orchestrator to be registered.
            key: Unique key for the orchestrator.

        Raises:
            AlreadyExistsException: If an orchestrator with this key exists.
        """
        logger.debug(
            f"Registering orchestrator with key {key}, details: "
            f"{orchestrator.dict()}"
        )
        if key in self.orchestrator_map:
            raise AlreadyExistsException(
                message=f"Orchestrator `{key}` already exists!"
            )

        # Add the mapping.
        orchestrator.update()
        source = source_utils.resolve_class(orchestrator.__class__)
        self.orchestrator_map[key] = UUIDSourceTuple(
            uuid=orchestrator.uuid, source=source
        )
        self.update()

        # Telemetry
        from zenml.core.component_factory import orchestrator_store_factory

        track_event(
            REGISTERED_ORCHESTRATOR,
            {
                "type": orchestrator_store_factory.get_component_key(
                    orchestrator.__class__
                )
            },
        )

    def delete_orchestrator(self, key: str) -> None:
        """Delete an orchestrator.

        Args:
            key: Unique key of orchestrator.

        Raises:
            DoesNotExistException: If no orchestrator is registered under the
                key.
        """
        s = self.get_orchestrator(key)  # check whether it exists
        s.delete()
        del self.orchestrator_map[key]
        self.update()
        logger.debug(f"Deleted orchestrator with key: {key}.")

    def get_container_registry(self, key: str) -> "BaseContainerRegistry":
        """Return a single container registry based on key.

        Args:
            key: Unique key of a container registry.

        Returns:
            Container registry specified by key.

        Raises:
            DoesNotExistException: If no container registry is registered
                under the key.
        """
        logger.debug(f"Fetching container registry with key {key}")
        if key not in self.container_registry_map:
            raise DoesNotExistException(
                f"Container registry of key `{key}` does not exist. "
                f"Available keys: {list(self.container_registry_map.keys())}"
            )
        return mapping_utils.get_component_from_key(  # type: ignore[return-value] # noqa
            key, self.container_registry_map, self._repo_path
        )

    @track(event=REGISTERED_CONTAINER_REGISTRY)
    def register_container_registry(
        self, key: str, container_registry: "BaseContainerRegistry"
    ) -> None:
        """Register a container registry.

        Args:
            container_registry: Container registry to be registered.
            key: Unique key for the container registry.

        Raises:
            AlreadyExistsException: If a container registry with this key
                exists.
        """
        logger.debug(
            f"Registering container registry with key {key}, details: "
            f"{container_registry.dict()}"
        )
        if key in self.container_registry_map:
            raise AlreadyExistsException(
                message=f"Container registry `{key}` already exists!"
            )

        # Add the mapping.
        container_registry.update()
        source = source_utils.resolve_class(container_registry.__class__)
        self.container_registry_map[key] = UUIDSourceTuple(
            uuid=container_registry.uuid, source=source
        )
        self.update()

    def delete_container_registry(self, key: str) -> None:
        """Delete a container registry.

        Args:
            key: Unique key of the container registry.

        Raises:
            DoesNotExistException: If no container registry is registered
                under the key.
        """
        container_registry = self.get_container_registry(key)
        container_registry.delete()
        del self.container_registry_map[key]
        self.update()
        logger.debug(f"Deleted container registry with key: {key}.")

    def delete(self) -> None:
        """Deletes the entire service. Dangerous operation.

        Every registered metadata store, artifact store, orchestrator and
        container registry is deleted before the service itself.
        """
        for m in self.metadata_stores.values():
            m.delete()
        for a in self.artifact_stores.values():
            a.delete()
        for o in self.orchestrators.values():
            o.delete()
        for c in self.container_registries.values():
            c.delete()
        super().delete()
artifact_stores: Dict[str, BaseArtifactStore]
property
readonly
Returns all registered artifact stores.
container_registries: Dict[str, BaseContainerRegistry]
property
readonly
Returns all registered container registries.
metadata_stores: Dict[str, BaseMetadataStore]
property
readonly
Returns all registered metadata stores.
orchestrators: Dict[str, BaseOrchestrator]
property
readonly
Returns all registered orchestrators.
__init__(self, repo_path, **kwargs)
special
Initializes a LocalService instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
repo_path |
str |
Path to the repository of this service. |
required |
Source code in zenml/core/local_service.py
def __init__(self, repo_path: str, **kwargs: Any) -> None:
    """Create a LocalService bound to a repository.

    Args:
        repo_path: Path to the repository of this service.
        **kwargs: Forwarded to the base component constructor.
    """
    super().__init__(
        serialization_dir=zenml.io.utils.get_zenml_config_dir(repo_path),
        **kwargs,
    )
    self._repo_path = repo_path
    # Every registered stack gets the same repository path.
    for registered_stack in self.stacks.values():
        registered_stack._repo_path = repo_path
delete(self)
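Deletes the entire service, including every registered metadata store, artifact store, orchestrator and container registry. This is a dangerous operation.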
Source code in zenml/core/local_service.py
def delete(self) -> None:
    """Deletes the entire service. Dangerous operation.

    All registered components are deleted group by group (metadata stores,
    artifact stores, orchestrators, container registries) before the
    service itself is deleted.
    """
    # Each group is looked up only when its turn comes, matching the
    # one-loop-per-group order.
    for group_name in (
        "metadata_stores",
        "artifact_stores",
        "orchestrators",
        "container_registries",
    ):
        for component in getattr(self, group_name).values():
            component.delete()
    super().delete()
delete_artifact_store(self, key)
Delete an artifact store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of the artifact store. |
required |
Source code in zenml/core/local_service.py
def delete_artifact_store(self, key: str) -> None:
    """Delete an artifact store and drop it from the service's map.

    Args:
        key: Unique key of the artifact store.

    Raises:
        DoesNotExistException: Via `get_artifact_store`, for unknown keys.
    """
    artifact_store = self.get_artifact_store(key)
    artifact_store.delete()
    self.artifact_store_map.pop(key)
    self.update()
    logger.debug(f"Deleted artifact_store with key: {key}.")
delete_container_registry(self, key)
Delete a container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of the container registry. |
required |
Source code in zenml/core/local_service.py
def delete_container_registry(self, key: str) -> None:
    """Delete a container registry and drop it from the service's map.

    Args:
        key: Unique key of the container registry.

    Raises:
        DoesNotExistException: Via `get_container_registry`, for unknown keys.
    """
    registry = self.get_container_registry(key)
    registry.delete()
    self.container_registry_map.pop(key)
    self.update()
    logger.debug(f"Deleted container registry with key: {key}.")
delete_metadata_store(self, key)
Delete a metadata store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of metadata store. |
required |
Source code in zenml/core/local_service.py
def delete_metadata_store(self, key: str) -> None:
    """Delete a metadata store and drop it from the service's map.

    Args:
        key: Unique key of metadata store.

    Raises:
        DoesNotExistException: Via `get_metadata_store`, for unknown keys.
    """
    metadata_store = self.get_metadata_store(key)
    metadata_store.delete()
    self.metadata_store_map.pop(key)
    self.update()
    logger.debug(f"Deleted metadata store with key: {key}.")
delete_orchestrator(self, key)
Delete an orchestrator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of orchestrator. |
required |
Source code in zenml/core/local_service.py
def delete_orchestrator(self, key: str) -> None:
    """Delete an orchestrator and drop it from the service's map.

    Args:
        key: Unique key of orchestrator.

    Raises:
        DoesNotExistException: Via `get_orchestrator`, for unknown keys.
    """
    orchestrator = self.get_orchestrator(key)
    orchestrator.delete()
    self.orchestrator_map.pop(key)
    self.update()
    logger.debug(f"Deleted orchestrator with key: {key}.")
delete_stack(self, key)
Delete a stack specified with a key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of stack. |
required |
Source code in zenml/core/local_service.py
def delete_stack(self, key: str) -> None:
    """Delete a stack specified with a key.

    Only the reference held by this service is removed; the resources
    behind the stack's components are left untouched.

    Args:
        key: Unique key of stack.

    Raises:
        DoesNotExistException: Via `get_stack`, if no stack is registered
            under the key.
    """
    _ = self.get_stack(key)  # check whether it exists
    del self.stacks[key]
    self.update()
    logger.debug(f"Deleted stack with key: {key}.")
    # The closing instruction previously read "set a not active stack";
    # the user actually needs to pick a new active stack.
    logger.info(
        "Deleting a stack currently does not delete the underlying "
        "architecture of the stack. It just deletes the reference to it. "
        "Therefore please make sure to delete these resources on your "
        "own. Also, if this stack was the active stack, please make sure "
        "to set a new active stack via `zenml stack set`."
    )
get_active_stack_key(self)
Returns the active stack key.
Source code in zenml/core/local_service.py
def get_active_stack_key(self) -> str:
    """Return the key of the stack currently marked as active."""
    active_key = self.active_stack_key
    return active_key
get_artifact_store(self, key)
Return a single artifact store based on key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of artifact store. |
required |
Returns:
Type | Description |
---|---|
BaseArtifactStore |
Artifact store specified by key. |
Source code in zenml/core/local_service.py
def get_artifact_store(self, key: str) -> "BaseArtifactStore":
    """Return a single artifact store based on key.

    Args:
        key: Unique key of artifact store.

    Returns:
        Artifact store specified by key.

    Raises:
        DoesNotExistException: If no artifact store is registered under the
            key.
    """
    logger.debug(f"Fetching artifact_store with key {key}")
    if key not in self.artifact_store_map:
        # Previously mislabelled as "Stack of key ..."; this lookup is for
        # artifact stores, consistent with the other component getters.
        raise DoesNotExistException(
            f"Artifact store of key `{key}` does not exist. "
            f"Available keys: {list(self.artifact_store_map.keys())}"
        )
    return mapping_utils.get_component_from_key(  # type: ignore[return-value] # noqa
        key, self.artifact_store_map, self._repo_path
    )
get_container_registry(self, key)
Return a single container registry based on key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of a container registry. |
required |
Returns:
Type | Description |
---|---|
BaseContainerRegistry |
Container registry specified by key. |
Source code in zenml/core/local_service.py
def get_container_registry(self, key: str) -> "BaseContainerRegistry":
    """Look up and instantiate the container registry registered under `key`.

    Args:
        key: Unique key of a container registry.

    Returns:
        Container registry specified by key.

    Raises:
        DoesNotExistException: If the key is not registered.
    """
    logger.debug(f"Fetching container registry with key {key}")
    registry_map = self.container_registry_map
    if key in registry_map:
        return mapping_utils.get_component_from_key(  # type: ignore[return-value] # noqa
            key, registry_map, self._repo_path
        )
    raise DoesNotExistException(
        f"Container registry of key `{key}` does not exist. "
        f"Available keys: {list(registry_map)}"
    )
get_metadata_store(self, key)
Return a single metadata store based on key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of metadata store. |
required |
Returns:
Type | Description |
---|---|
BaseMetadataStore |
Metadata store specified by key. |
Source code in zenml/core/local_service.py
def get_metadata_store(self, key: str) -> "BaseMetadataStore":
    """Look up and instantiate the metadata store registered under `key`.

    Args:
        key: Unique key of metadata store.

    Returns:
        Metadata store specified by key.

    Raises:
        DoesNotExistException: If the key is not registered.
    """
    logger.debug(f"Fetching metadata store with key {key}")
    store_map = self.metadata_store_map
    if key in store_map:
        return mapping_utils.get_component_from_key(  # type: ignore[return-value] # noqa
            key, store_map, self._repo_path
        )
    raise DoesNotExistException(
        f"Metadata store of key `{key}` does not exist. "
        f"Available keys: {list(store_map)}"
    )
get_orchestrator(self, key)
Return a single orchestrator based on key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of orchestrator. |
required |
Returns:
Type | Description |
---|---|
BaseOrchestrator |
Orchestrator specified by key. |
Source code in zenml/core/local_service.py
def get_orchestrator(self, key: str) -> "BaseOrchestrator":
    """Look up and instantiate the orchestrator registered under `key`.

    Args:
        key: Unique key of orchestrator.

    Returns:
        Orchestrator specified by key.

    Raises:
        DoesNotExistException: If the key is not registered.
    """
    logger.debug(f"Fetching orchestrator with key {key}")
    orchestrator_map = self.orchestrator_map
    if key in orchestrator_map:
        return mapping_utils.get_component_from_key(  # type: ignore[return-value] # noqa
            key, orchestrator_map, self._repo_path
        )
    raise DoesNotExistException(
        f"Orchestrator of key `{key}` does not exist. "
        f"Available keys: {list(orchestrator_map)}"
    )
get_serialization_file_name(self)
Return the name of the file where object is serialized.
Source code in zenml/core/local_service.py
def get_serialization_file_name(self) -> str:
    """Name of the file this service is serialized into."""
    file_name = self._LOCAL_SERVICE_FILE_NAME
    return file_name
get_stack(self, key)
Return a single stack based on key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key of stack. |
required |
Returns:
Type | Description |
---|---|
BaseStack |
Stack specified by key. |
Source code in zenml/core/local_service.py
def get_stack(self, key: str) -> BaseStack:
    """Look up a registered stack by its key.

    Args:
        key: Unique key of stack.

    Returns:
        Stack specified by key.

    Raises:
        DoesNotExistException: If no stack is registered under `key`.
    """
    logger.debug(f"Fetching stack with key {key}")
    registered = self.stacks
    if key in registered:
        return registered[key]
    raise DoesNotExistException(
        f"Stack of key `{key}` does not exist. "
        f"Available keys: {list(registered)}"
    )
register_artifact_store(self, key, artifact_store)
Register an artifact store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_store |
BaseArtifactStore |
Artifact store to be registered. |
required |
key |
str |
Unique key for the artifact store. |
required |
Source code in zenml/core/local_service.py
def register_artifact_store(
    self, key: str, artifact_store: "BaseArtifactStore"
) -> None:
    """Add an artifact store to the service under a new key.

    Args:
        artifact_store: Artifact store to be registered.
        key: Unique key for the artifact store.

    Raises:
        AlreadyExistsException: If the key is already taken.
    """
    logger.debug(
        f"Registering artifact store with key {key}, details: {artifact_store.dict()}"
    )
    if key in self.artifact_store_map:
        raise AlreadyExistsException(
            message=f"Artifact Store `{key}` already exists!"
        )

    # The component's own update() runs before it is recorded in the map.
    artifact_store.update()
    resolved_source = source_utils.resolve_class(artifact_store.__class__)
    entry = UUIDSourceTuple(uuid=artifact_store.uuid, source=resolved_source)
    self.artifact_store_map[key] = entry
    self.update()

    # Telemetry: report which artifact store flavour was registered.
    from zenml.core.component_factory import artifact_store_factory

    component_type = artifact_store_factory.get_component_key(
        artifact_store.__class__
    )
    track_event(REGISTERED_ARTIFACT_STORE, {"type": component_type})
register_container_registry(*args, **kwargs)
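Register a container registry. The signature and source shown here are those of the `track` decorator's inner wrapper function (`inner_func`), not of the decorated method itself.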
Source code in zenml/core/local_service.py
def inner_func(*args: Any, **kwargs: Any) -> Any:
    """Inner decorator function.

    Emits the `event_name` telemetry event (with `metadata`) and then
    delegates to the wrapped `func`, returning its result.
    """
    track_event(event_name, metadata=metadata)
    return func(*args, **kwargs)
register_metadata_store(self, key, metadata_store)
Register a metadata store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata_store |
BaseMetadataStore |
Metadata store to be registered. |
required |
key |
str |
Unique key for the metadata store. |
required |
Source code in zenml/core/local_service.py
def register_metadata_store(
    self, key: str, metadata_store: "BaseMetadataStore"
) -> None:
    """Add a metadata store to the service under a new key.

    Args:
        metadata_store: Metadata store to be registered.
        key: Unique key for the metadata store.

    Raises:
        AlreadyExistsException: If the key is already taken.
    """
    logger.debug(
        f"Registering metadata store with key {key}, details: {metadata_store.dict()}"
    )
    if key in self.metadata_store_map:
        raise AlreadyExistsException(
            message=f"Metadata store `{key}` already exists!"
        )

    # The component's own update() runs before it is recorded in the map.
    metadata_store.update()
    resolved_source = source_utils.resolve_class(metadata_store.__class__)
    entry = UUIDSourceTuple(uuid=metadata_store.uuid, source=resolved_source)
    self.metadata_store_map[key] = entry
    self.update()

    # Telemetry: report which metadata store flavour was registered.
    from zenml.core.component_factory import metadata_store_factory

    component_type = metadata_store_factory.get_component_key(
        metadata_store.__class__
    )
    track_event(REGISTERED_METADATA_STORE, {"type": component_type})
register_orchestrator(self, key, orchestrator)
Register an orchestrator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
orchestrator |
BaseOrchestrator |
Orchestrator to be registered. |
required |
key |
str |
Unique key for the orchestrator. |
required |
Source code in zenml/core/local_service.py
def register_orchestrator(
    self, key: str, orchestrator: "BaseOrchestrator"
) -> None:
    """Add an orchestrator to the service under a new key.

    Args:
        orchestrator: Orchestrator to be registered.
        key: Unique key for the orchestrator.

    Raises:
        AlreadyExistsException: If the key is already taken.
    """
    logger.debug(
        f"Registering orchestrator with key {key}, details: {orchestrator.dict()}"
    )
    if key in self.orchestrator_map:
        raise AlreadyExistsException(
            message=f"Orchestrator `{key}` already exists!"
        )

    # The component's own update() runs before it is recorded in the map.
    orchestrator.update()
    resolved_source = source_utils.resolve_class(orchestrator.__class__)
    entry = UUIDSourceTuple(uuid=orchestrator.uuid, source=resolved_source)
    self.orchestrator_map[key] = entry
    self.update()

    # Telemetry: report which orchestrator flavour was registered.
    from zenml.core.component_factory import orchestrator_store_factory

    component_type = orchestrator_store_factory.get_component_key(
        orchestrator.__class__
    )
    track_event(REGISTERED_ORCHESTRATOR, {"type": component_type})
register_stack(*args, **kwargs)
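Register a stack. The signature and source shown here are those of the `track` decorator's inner wrapper function (`inner_func`), not of the decorated method itself.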
Source code in zenml/core/local_service.py
def inner_func(*args: Any, **kwargs: Any) -> Any:
    """Inner decorator function.

    Emits the `event_name` telemetry event (with `metadata`) and then
    delegates to the wrapped `func`, returning its result.
    """
    track_event(event_name, metadata=metadata)
    return func(*args, **kwargs)
set_active_stack_key(self, stack_key)
Sets the active stack key.
Source code in zenml/core/local_service.py
def set_active_stack_key(self, stack_key: str) -> None:
    """Mark the stack registered under `stack_key` as active.

    On success the key is stored and `update()` is called.

    Args:
        stack_key: Key of a registered stack.

    Raises:
        DoesNotExistException: If no stack is registered under `stack_key`.
    """
    if stack_key in self.stacks:
        self.active_stack_key = stack_key
        self.update()
        return
    raise DoesNotExistException(
        f"Unable to set active stack for key `{stack_key}` because no "
        f"stack is registered for this key. Available keys: "
        f"{set(self.stacks)}"
    )
mapping_utils
UUIDSourceTuple (BaseModel)
pydantic-model
Container used to store UUID and source information of a single BaseComponent subclass.
Attributes:
Name | Type | Description |
---|---|---|
uuid |
UUID |
Identifier of the BaseComponent |
source |
str |
Contains the fully qualified class name and information about a git hash/tag. E.g. foo.bar.BaseComponentSubclass@git_tag |
Source code in zenml/core/mapping_utils.py
class UUIDSourceTuple(BaseModel):
    """Container used to store UUID and source information
    of a single BaseComponent subclass.

    Instances are the values of the component maps kept by `LocalService`.
    `get_component_from_key` loads the class named by `source` and
    instantiates it with `uuid`.

    Attributes:
        uuid: Identifier of the BaseComponent
        source: Contains the fully qualified class name and information
            about a git hash/tag. E.g. foo.bar.BaseComponentSubclass@git_tag
    """

    # Identifier handed back to the component class when it is loaded.
    uuid: UUID
    # Import path of the component class, optionally suffixed with `@<pin>`.
    source: str
get_component_from_key(key, mapping, repo_path)
Given a key and a mapping, return an initialized component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
Unique key. |
required |
mapping |
Dict[str, zenml.core.mapping_utils.UUIDSourceTuple] |
Dict of type str -> UUIDSourceTuple. |
required |
repo_path |
str |
Path to the repo from which to load the component. |
required |
Returns:
Type | Description |
---|---|
BaseComponent |
An object which is a subclass of type BaseComponent. |
Source code in zenml/core/mapping_utils.py
def get_component_from_key(
    key: str, mapping: Dict[str, UUIDSourceTuple], repo_path: str
) -> BaseComponent:
    """Given a key and a mapping, return an initialized component.

    Args:
        key: Unique key.
        mapping: Dict of type str -> UUIDSourceTuple.
        repo_path: Path to the repo from which to load the component.

    Returns:
        An object which is a subclass of type BaseComponent.

    Raises:
        KeyError: If `key` is not present in `mapping`.
        TypeError: If the class loaded from the stored source is not a
            subclass of BaseComponent.
    """
    tuple_ = mapping[key]
    class_ = source_utils.load_source_path_class(tuple_.source)
    if not issubclass(class_, BaseComponent):
        # Previously raised with an empty message, which made the failure
        # impossible to diagnose from logs.
        raise TypeError(
            f"Source `{tuple_.source}` registered under key `{key}` does "
            f"not resolve to a subclass of BaseComponent (got {class_!r})."
        )
    return class_(uuid=tuple_.uuid, repo_path=repo_path)  # type: ignore[call-arg] # noqa
get_components_from_store(store_name, mapping, repo_path)
Returns a dict of components from a store, keyed by their registered keys.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
store_name |
str |
Name of the store. |
required |
mapping |
Dict[str, zenml.core.mapping_utils.UUIDSourceTuple] |
Dict of type str -> UUIDSourceTuple. |
required |
repo_path |
str |
Path to the repo from which to load the components. |
required |
Returns:
Type | Description |
---|---|
Dict[str, zenml.core.base_component.BaseComponent] |
A dict of objects which are a subclass of type BaseComponent. |
Source code in zenml/core/mapping_utils.py
def get_components_from_store(
    store_name: str, mapping: Dict[str, UUIDSourceTuple], repo_path: str
) -> Dict[str, BaseComponent]:
    """Returns a dict of components from a store, keyed by registered key.

    Every file in `<zenml config dir>/<store_name>` is treated as a
    serialized component whose file stem is its UUID. That UUID is mapped
    back to its key and the component is instantiated.

    Args:
        store_name: Name of the store.
        mapping: Dict of type str -> UUIDSourceTuple.
        repo_path: Path to the repo from which to load the components.

    Returns:
        A dict of objects which are a subclass of type BaseComponent.

    Raises:
        ValueError: If a file stem in the store directory is not a UUID.
        KeyError: If a file's UUID does not appear in `mapping`.
    """
    store_dir = os.path.join(
        zenml.io.utils.get_zenml_config_dir(repo_path),
        store_name,
    )
    comps = {}
    for file_name in fileio.list_dir(store_dir, only_file_names=True):
        component_uuid = UUID(Path(file_name).stem)
        key = get_key_from_uuid(component_uuid, mapping)
        comps[key] = get_component_from_key(key, mapping, repo_path)
    return comps
get_key_from_uuid(uuid, mapping)
Return the key that points to a certain uuid in a mapping.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uuid |
UUID |
uuid to query. |
required |
mapping |
Dict[str, zenml.core.mapping_utils.UUIDSourceTuple] |
Dict mapping keys to UUIDs and source information. |
required |
Returns:
Type | Description |
---|---|
str |
Returns the key from the mapping. |
Source code in zenml/core/mapping_utils.py
def get_key_from_uuid(uuid: UUID, mapping: Dict[str, UUIDSourceTuple]) -> str:
    """Find the key whose entry in `mapping` carries the given uuid.

    If several keys share a uuid, the one iterated last wins.

    Args:
        uuid: uuid to query.
        mapping: Dict mapping keys to UUIDs and source information.

    Returns:
        Returns the key from the mapping.

    Raises:
        KeyError: If no entry has this uuid.
    """
    key_by_uuid = {}
    for key, entry in mapping.items():
        key_by_uuid[entry.uuid] = key
    return key_by_uuid[uuid]
repo
Base ZenML repository
Repository
ZenML repository definition.
Every ZenML project exists inside a ZenML repository.
Source code in zenml/core/repo.py
class Repository:
    """ZenML repository definition.

    Every ZenML project exists inside a ZenML repository.
    """

    def __init__(self, path: Optional[str] = None):
        """Construct a reference to a ZenML repository.

        Args:
            path: Path to root of repository. It is resolved through
                `zenml.io.utils.get_zenml_dir` before use.
        """
        self.path = zenml.io.utils.get_zenml_dir(path)
        self.service = LocalService(repo_path=self.path)
        try:
            self.git_wrapper = GitWrapper(self.path)
        except InvalidGitRepositoryError:
            # No git repository could be opened at `self.path`; git support
            # is disabled and `get_git_wrapper` will return None.
            self.git_wrapper = None  # type: ignore[assignment]

    @staticmethod
    def init_repo(path: Optional[str] = None) -> None:
        """Initializes a ZenML repository.

        Creates the ZenML directory inside `path` and registers a default
        local setup with a fresh `LocalService`: a local artifact store and a
        SQLite metadata store (both under the global config directory, in a
        folder named after the service uuid), a local orchestrator, and a
        `local_stack` combining the three, which is then set as active.

        Args:
            path: Path where the ZenML repository should be created. Defaults
                to the current working directory at the time of the call.

        Raises:
            InitializationException: If a ZenML repository already exists at
                the given path.
        """
        # Resolve the default here instead of in the signature: a default of
        # `os.getcwd()` would be evaluated once, when the module is imported.
        if path is None:
            path = os.getcwd()

        if zenml.io.utils.is_zenml_dir(path):
            raise InitializationException(
                f"A ZenML repository already exists at path '{path}'."
            )

        # Create the base dir
        zen_dir = os.path.join(path, ZENML_DIR_NAME)
        fileio.create_dir_recursive_if_not_exists(zen_dir)

        from zenml.artifact_stores import LocalArtifactStore
        from zenml.metadata_stores import SQLiteMetadataStore
        from zenml.orchestrators import LocalOrchestrator

        service = LocalService(repo_path=path)

        artifact_store_path = os.path.join(
            zenml.io.utils.get_global_config_directory(),
            "local_stores",
            str(service.uuid),
        )
        metadata_store_path = os.path.join(artifact_store_path, "metadata.db")

        service.register_artifact_store(
            "local_artifact_store",
            LocalArtifactStore(path=artifact_store_path, repo_path=path),
        )

        service.register_metadata_store(
            "local_metadata_store",
            SQLiteMetadataStore(uri=metadata_store_path, repo_path=path),
        )

        service.register_orchestrator(
            "local_orchestrator", LocalOrchestrator(repo_path=path)
        )

        service.register_stack(
            "local_stack",
            BaseStack(
                metadata_store_name="local_metadata_store",
                artifact_store_name="local_artifact_store",
                orchestrator_name="local_orchestrator",
            ),
        )

        service.set_active_stack_key("local_stack")

    def get_git_wrapper(self) -> GitWrapper:
        """Returns the git wrapper for the repo.

        Returns:
            The `GitWrapper` created in `__init__`, or None if constructing
            it raised `InvalidGitRepositoryError`.
        """
        return self.git_wrapper

    def get_service(self) -> LocalService:
        """Returns the active service. For now, always local.

        Returns:
            The `LocalService` created in `__init__`.
        """
        return self.service

    @track(event=SET_STACK)
    def set_active_stack(self, stack_key: str) -> None:
        """Set the active stack for the repo. This change is local for the
        machine.

        Args:
            stack_key: Key of the stack to set active.
        """
        self.service.set_active_stack_key(stack_key)

    def get_active_stack_key(self) -> str:
        """Get the key of the currently active stack.

        Returns:
            The active stack key, as reported by the repository's service.
        """
        return self.service.get_active_stack_key()

    def get_active_stack(self) -> BaseStack:
        """Get the currently active stack.

        Returns:
            The stack registered in the service under the active stack key.
        """
        active_key = self.get_active_stack_key()
        return self.service.get_stack(active_key)

    @track(event=GET_PIPELINES)
    def get_pipelines(
        self, stack_key: Optional[str] = None
    ) -> List[PipelineView]:
        """Returns a list of all pipelines.

        Args:
            stack_key: If specified, pipelines in the metadata store of the
                given stack are returned. Otherwise pipelines in the metadata
                store of the currently active stack are returned.

        Returns:
            The pipelines known to the selected stack's metadata store.
        """
        stack_key = stack_key or self.get_active_stack_key()
        metadata_store = self.service.get_stack(stack_key).metadata_store
        return metadata_store.get_pipelines()

    @track(event=GET_PIPELINE)
    def get_pipeline(
        self, pipeline_name: str, stack_key: Optional[str] = None
    ) -> Optional[PipelineView]:
        """Returns a pipeline for the given name or `None` if it doesn't exist.

        Args:
            pipeline_name: Name of the pipeline.
            stack_key: If specified, pipelines in the metadata store of the
                given stack are returned. Otherwise pipelines in the metadata
                store of the currently active stack are returned.

        Returns:
            The matching pipeline from the selected stack's metadata store,
            or `None` if it doesn't exist.
        """
        stack_key = stack_key or self.get_active_stack_key()
        metadata_store = self.service.get_stack(stack_key).metadata_store
        return metadata_store.get_pipeline(pipeline_name)

    def clean(self) -> None:
        """Deletes associated metadata store, pipelines dir and artifacts.

        Raises:
            NotImplementedError: Always; cleaning is not implemented yet.
        """
        raise NotImplementedError
__init__(self, path=None)
special
Construct reference to a ZenML repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
Path to root of repository |
None |
Source code in zenml/core/repo.py
def __init__(self, path: Optional[str] = None):
    """Construct a reference to a ZenML repository.

    Args:
        path: Path to root of repository. It is resolved through
            `zenml.io.utils.get_zenml_dir` before use.
    """
    self.path = zenml.io.utils.get_zenml_dir(path)
    self.service = LocalService(repo_path=self.path)
    try:
        self.git_wrapper = GitWrapper(self.path)
    except InvalidGitRepositoryError:
        # No git repository could be opened at `self.path`; git support
        # is disabled and `get_git_wrapper` will return None.
        self.git_wrapper = None  # type: ignore[assignment]
clean(self)
Deletes associated metadata store, pipelines dir and artifacts
Source code in zenml/core/repo.py
def clean(self) -> None:
    """Deletes associated metadata store, pipelines dir and artifacts.

    Raises:
        NotImplementedError: Always; cleaning is not implemented yet.
    """
    raise NotImplementedError()
get_active_stack(self)
Get the active stack from global config.
Returns:
Type | Description |
---|---|
BaseStack |
Currently active stack. |
Source code in zenml/core/repo.py
def get_active_stack(self) -> BaseStack:
    """Get the currently active stack.

    Returns:
        The stack registered in the service under the active stack key.
    """
    active_key = self.get_active_stack_key()
    return self.service.get_stack(active_key)
get_active_stack_key(self)
Get the active stack key from global config.
Returns:
Type | Description |
---|---|
str |
Currently active stacks key. |
Source code in zenml/core/repo.py
def get_active_stack_key(self) -> str:
    """Get the key of the currently active stack.

    Returns:
        The active stack key, as reported by the repository's service.
    """
    service = self.service
    return service.get_active_stack_key()
get_git_wrapper(self)
Returns the git wrapper for the repo.
Source code in zenml/core/repo.py
def get_git_wrapper(self) -> GitWrapper:
    """Returns the git wrapper for the repo.

    Returns:
        The `GitWrapper` created in `__init__`, or None if constructing
        it raised `InvalidGitRepositoryError`.
    """
    return self.git_wrapper
get_pipeline(*args, **kwargs)
Returns a pipeline for the given name or `None` if it doesn't exist.
Source code in zenml/core/repo.py
def inner_func(*args: Any, **kwargs: Any) -> Any:
    """Inner decorator function.

    Records the tracking event first, then calls the wrapped function and
    hands back its result unchanged.
    """
    # NOTE(review): docs render decorated methods with this function's name
    # and docstring, which suggests the enclosing decorator does not apply
    # functools.wraps -- confirm.
    track_event(event_name, metadata=metadata)
    return func(*args, **kwargs)
get_pipelines(*args, **kwargs)
Returns a list of all pipelines.
Source code in zenml/core/repo.py
def inner_func(*args: Any, **kwargs: Any) -> Any:
    """Inner decorator function.

    Records the tracking event first, then calls the wrapped function and
    hands back its result unchanged.
    """
    # NOTE(review): docs render decorated methods with this function's name
    # and docstring, which suggests the enclosing decorator does not apply
    # functools.wraps -- confirm.
    track_event(event_name, metadata=metadata)
    return func(*args, **kwargs)
get_service(self)
Returns the active service. For now, always local.
Source code in zenml/core/repo.py
def get_service(self) -> LocalService:
    """Returns the active service. For now, always local.

    Returns:
        The `LocalService` created in `__init__`.
    """
    service = self.service
    return service
init_repo(path='/home/apenner/PycharmProjects/zenml')
staticmethod
Initializes a ZenML repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
Path where the ZenML repository should be created. |
'/home/apenner/PycharmProjects/zenml' |
Exceptions:
Type | Description |
---|---|
InitializationException |
If a ZenML repository already exists at the given path. |
Source code in zenml/core/repo.py
@staticmethod
def init_repo(path: Optional[str] = None) -> None:
    """Initializes a ZenML repository.

    Creates the ZenML directory inside `path` and registers a default
    local setup with a fresh `LocalService`: a local artifact store and a
    SQLite metadata store (both under the global config directory, in a
    folder named after the service uuid), a local orchestrator, and a
    `local_stack` combining the three, which is then set as active.

    Args:
        path: Path where the ZenML repository should be created. Defaults
            to the current working directory at the time of the call.

    Raises:
        InitializationException: If a ZenML repository already exists at
            the given path.
    """
    # Resolve the default here instead of in the signature: a default of
    # `os.getcwd()` would be evaluated once, when the module is imported.
    if path is None:
        path = os.getcwd()

    if zenml.io.utils.is_zenml_dir(path):
        raise InitializationException(
            f"A ZenML repository already exists at path '{path}'."
        )

    # Create the base dir
    zen_dir = os.path.join(path, ZENML_DIR_NAME)
    fileio.create_dir_recursive_if_not_exists(zen_dir)

    from zenml.artifact_stores import LocalArtifactStore
    from zenml.metadata_stores import SQLiteMetadataStore
    from zenml.orchestrators import LocalOrchestrator

    service = LocalService(repo_path=path)

    artifact_store_path = os.path.join(
        zenml.io.utils.get_global_config_directory(),
        "local_stores",
        str(service.uuid),
    )
    metadata_store_path = os.path.join(artifact_store_path, "metadata.db")

    service.register_artifact_store(
        "local_artifact_store",
        LocalArtifactStore(path=artifact_store_path, repo_path=path),
    )

    service.register_metadata_store(
        "local_metadata_store",
        SQLiteMetadataStore(uri=metadata_store_path, repo_path=path),
    )

    service.register_orchestrator(
        "local_orchestrator", LocalOrchestrator(repo_path=path)
    )

    service.register_stack(
        "local_stack",
        BaseStack(
            metadata_store_name="local_metadata_store",
            artifact_store_name="local_artifact_store",
            orchestrator_name="local_orchestrator",
        ),
    )

    service.set_active_stack_key("local_stack")
set_active_stack(*args, **kwargs)
Set the active stack for the repo. This change is local for the machine.
Source code in zenml/core/repo.py
def inner_func(*args: Any, **kwargs: Any) -> Any:
    """Inner decorator function.

    Records the tracking event first, then calls the wrapped function and
    hands back its result unchanged.
    """
    # NOTE(review): docs render decorated methods with this function's name
    # and docstring, which suggests the enclosing decorator does not apply
    # functools.wraps -- confirm.
    track_event(event_name, metadata=metadata)
    return func(*args, **kwargs)
utils
define_json_config_settings_source(config_dir, config_name)
Define a function to essentially deserialize a model from a serialized json config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config_dir |
str |
A path to a dir where we want the config file to exist. |
required |
config_name |
str |
Full name of config file. |
required |
Returns:
Type | Description |
---|---|
Callable[[BaseSettings], Dict[str, Any]] |
A `json_config_settings_source` callable reading from the passed path. |
Source code in zenml/core/utils.py
def define_json_config_settings_source(
    config_dir: str, config_name: str
) -> SettingsSourceCallable:
    """Define a function to essentially deserialize a model from a serialized
    json config.

    Args:
        config_dir: A path to a dir where we want the config file to exist.
        config_name: Full name of config file.

    Returns:
        A `json_config_settings_source` callable reading from the passed path.
    """
    # The path depends only on the arguments above, so build it once here
    # rather than on every invocation of the returned settings source.
    full_path = str(Path(config_dir) / config_name)

    def json_config_settings_source(settings: BaseSettings) -> Dict[str, Any]:
        """Settings source that loads values from the JSON file located at
        `config_dir / config_name`.

        Args:
            settings: BaseSettings from pydantic. Not used here; it is part of
                the settings-source callable signature.

        Returns:
            A dict with all configuration, empty dict if config not found.
        """
        logger.debug(f"Parsing file: {full_path}")
        if fileio.file_exists(full_path):
            return cast(Dict[str, Any], yaml_utils.read_json(full_path))
        return {}

    return json_config_settings_source
generate_customise_sources(file_dir, file_name)
Generate a customise_sources function as defined here:
https://pydantic-docs.helpmanual.io/usage/settings/. This function
generates a function that configures the priorities of the sources through
which the model is loaded. The important thing to note here is that the
define_json_config_settings_source
is dynamically generated with the
provided file_dir and file_name. This allows us to dynamically generate
a file name for the serialization and deserialization of the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_dir |
str |
Dir where file is stored. |
required |
file_name |
str |
Name of the file to persist. |
required |
Returns:
Type | Description |
---|---|
Callable[[Type[pydantic.env_settings.BaseSettings.Config], Callable[[BaseSettings], Dict[str, Any]], Callable[[BaseSettings], Dict[str, Any]], Callable[[BaseSettings], Dict[str, Any]]], Tuple[Callable[[BaseSettings], Dict[str, Any]], ...]] |
A `customise_sources` class method to be defined on a Pydantic BaseSettings inner Config class. |
Source code in zenml/core/utils.py
def generate_customise_sources(
    file_dir: str, file_name: str
) -> Callable[
    [
        Type[BaseSettings.Config],
        SettingsSourceCallable,
        SettingsSourceCallable,
        SettingsSourceCallable,
    ],
    Tuple[SettingsSourceCallable, ...],
]:
    """Generate a customise_sources function as defined here:
    https://pydantic-docs.helpmanual.io/usage/settings/. This function
    generates a function that configures the priorities of the sources through
    which the model is loaded. The important thing to note here is that the
    `define_json_config_settings_source` is dynamically generated with the
    provided file_dir and file_name. This allows us to dynamically generate
    a file name for the serialization and deserialization of the model.

    Args:
        file_dir: Dir where file is stored.
        file_name: Name of the file to persist.

    Returns:
        A `customise_sources` class method to be defined on a Pydantic
        BaseSettings inner Config class.
    """
    # The JSON source depends only on file_dir/file_name, so create it once
    # instead of every time pydantic calls `customise_sources`.
    json_config_settings = define_json_config_settings_source(
        file_dir,
        file_name,
    )

    def customise_sources(
        cls: Type[BaseSettings.Config],
        init_settings: SettingsSourceCallable,
        env_settings: SettingsSourceCallable,
        file_secret_settings: SettingsSourceCallable,
    ) -> Tuple[SettingsSourceCallable, ...]:
        """Defines precedence of sources to read/write settings from.

        Earlier entries take priority: init kwargs, then environment
        variables, then the JSON config file, then secret files.
        """
        return (
            init_settings,
            env_settings,
            json_config_settings,
            file_secret_settings,
        )

    return classmethod(customise_sources)  # type: ignore[return-value]