Code Repositories
zenml.code_repositories
special
Initialization of the ZenML code repository base abstraction.
base_code_repository
Base class for code repositories.
BaseCodeRepository (ABC)
Base class for code repositories.
Code repositories are used to connect to a remote code repository and store information about the repository, such as the URL, the owner, the repository name, and the host. They also provide methods to download files from the repository when a pipeline is run remotely.
Source code in zenml/code_repositories/base_code_repository.py
class BaseCodeRepository(ABC):
    """Abstract interface that every code repository must implement.

    A code repository connects to a remote code host and keeps track of
    information about the repository, such as the URL, the owner, the
    repository name, and the host. It also offers a way to download files
    from the repository when a pipeline is run remotely.
    """

    def __init__(
        self,
        id: UUID,
        config: Dict[str, Any],
    ) -> None:
        """Stores the ID and raw config, then logs in.

        Args:
            id: The ID of the code repository.
            config: The config of the code repository.
        """
        self._id, self._config = id, config
        # Both attributes are set before logging in so that `login()`
        # implementations can already read them.
        self.login()

    @property
    def config(self) -> "BaseCodeRepositoryConfig":
        """Config object of the code repository.

        Returns:
            A `BaseCodeRepositoryConfig` built from the stored config values.
        """
        raw_config = self._config
        return BaseCodeRepositoryConfig(**raw_config)

    @classmethod
    def from_model(cls, model: CodeRepositoryResponse) -> "BaseCodeRepository":
        """Builds a code repository object from its model.

        Args:
            model: The CodeRepositoryResponse to load from.

        Returns:
            The loaded code repository object.
        """
        # The concrete implementation is resolved from the model source and
        # must be a `BaseCodeRepository` subclass.
        repository_class: Type[BaseCodeRepository] = (
            source_utils.load_and_validate_class(
                source=model.source, expected_class=BaseCodeRepository
            )
        )
        return repository_class(id=model.id, config=model.config)

    @property
    def id(self) -> UUID:
        """ID of the code repository.

        Returns:
            The ID that was passed in at construction time.
        """
        return self._id

    @property
    def requirements(self) -> Set[str]:
        """PyPI requirements needed by this repository implementation.

        Returns:
            A set of PyPI requirements for the repository.
        """
        from zenml.integrations.utils import get_requirements_for_module

        module_requirements = get_requirements_for_module(self.__module__)
        return set(module_requirements)

    @abstractmethod
    def login(self) -> None:
        """Authenticates with the code repository.

        The constructor calls this method, so it runs whenever a code
        repository object is created.

        Raises:
            RuntimeError: If the login fails.
        """

    @abstractmethod
    def download_files(
        self,
        commit: str,
        directory: str,
        repo_sub_directory: Optional[str],
    ) -> None:
        """Downloads files of a commit into a local directory.

        Args:
            commit: The commit hash to download files from.
            directory: The directory to download files to.
            repo_sub_directory: The subdirectory in the repository to
                download files from.

        Raises:
            RuntimeError: If the download fails.
        """

    @abstractmethod
    def get_local_context(
        self, path: str
    ) -> Optional["LocalRepositoryContext"]:
        """Builds the local repository context for a path.

        Args:
            path: The path to the local repository.

        Returns:
            The local repository context object.
        """
config: BaseCodeRepositoryConfig
property
readonly
Config object of the code repository.
Returns:
Type | Description |
---|---|
BaseCodeRepositoryConfig |
The config object, built from the stored config values. |
id: UUID
property
readonly
ID of the code repository.
Returns:
Type | Description |
---|---|
UUID |
The ID of the code repository. |
requirements: Set[str]
property
readonly
Set of PyPI requirements for the repository.
Returns:
Type | Description |
---|---|
Set[str] |
A set of PyPI requirements for the repository. |
__init__(self, id, config)
special
Initializes a code repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id |
UUID |
The ID of the code repository. |
required |
config |
Dict[str, Any] |
The config of the code repository. |
required |
Source code in zenml/code_repositories/base_code_repository.py
def __init__(
self,
id: UUID,
config: Dict[str, Any],
) -> None:
"""Initializes a code repository.
Args:
id: The ID of the code repository.
config: The config of the code repository.
"""
self._id = id
self._config = config
self.login()
download_files(self, commit, directory, repo_sub_directory)
Downloads files from the code repository to a local directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
commit |
str |
The commit hash to download files from. |
required |
directory |
str |
The directory to download files to. |
required |
repo_sub_directory |
Optional[str] |
The subdirectory in the repository to download files from. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the download fails. |
Source code in zenml/code_repositories/base_code_repository.py
@abstractmethod
def download_files(
    self,
    commit: str,
    directory: str,
    repo_sub_directory: Optional[str],
) -> None:
    """Downloads files of a commit into a local directory.

    Args:
        commit: The commit hash to download files from.
        directory: The directory to download files to.
        repo_sub_directory: The subdirectory in the repository to
            download files from.

    Raises:
        RuntimeError: If the download fails.
    """
from_model(model)
classmethod
Loads a code repository from a model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
CodeRepositoryResponse |
The `CodeRepositoryResponse` model to load from. |
required |
Returns:
Type | Description |
---|---|
BaseCodeRepository |
The loaded code repository object. |
Source code in zenml/code_repositories/base_code_repository.py
@classmethod
def from_model(cls, model: CodeRepositoryResponse) -> "BaseCodeRepository":
    """Builds a code repository object from its model.

    Args:
        model: The CodeRepositoryResponse to load from.

    Returns:
        The loaded code repository object.
    """
    # The concrete implementation is resolved from the model source and
    # must be a `BaseCodeRepository` subclass.
    repository_class: Type[BaseCodeRepository] = (
        source_utils.load_and_validate_class(
            source=model.source, expected_class=BaseCodeRepository
        )
    )
    return repository_class(id=model.id, config=model.config)
get_local_context(self, path)
Gets a local repository context from a path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
The path to the local repository. |
required |
Returns:
Type | Description |
---|---|
Optional[LocalRepositoryContext] |
The local repository context object. |
Source code in zenml/code_repositories/base_code_repository.py
@abstractmethod
def get_local_context(self, path: str) -> Optional["LocalRepositoryContext"]:
    """Builds the local repository context for a path.

    Args:
        path: The path to the local repository.

    Returns:
        The local repository context object.
    """
login(self)
Logs into the code repository.
This method is called when the code repository is initialized. It should be used to authenticate with the code repository.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the login fails. |
Source code in zenml/code_repositories/base_code_repository.py
@abstractmethod
def login(self) -> None:
    """Authenticates with the code repository.

    This method runs when the code repository is initialized and is the
    place where implementations authenticate with the code repository.

    Raises:
        RuntimeError: If the login fails.
    """
BaseCodeRepositoryConfig (SecretReferenceMixin, ABC)
pydantic-model
Base config for code repositories.
Source code in zenml/code_repositories/base_code_repository.py
class BaseCodeRepositoryConfig(SecretReferenceMixin, ABC):
    """Base config for code repositories.

    This base combines `SecretReferenceMixin` with `ABC` and declares no
    fields itself. `BaseCodeRepository.config` instantiates it from the
    repository's stored config values.
    """
git
special
Initialization of the local git repository context.
local_git_repository_context
Implementation of the Local git repository context.
LocalGitRepositoryContext (LocalRepositoryContext)
Local git repository context.
Source code in zenml/code_repositories/git/local_git_repository_context.py
class LocalGitRepositoryContext(LocalRepositoryContext):
    """Local git repository context."""

    def __init__(
        self, code_repository_id: UUID, git_repo: "Repo", remote_name: str
    ):
        """Initializes a local git repository context.

        Args:
            code_repository_id: The ID of the code repository.
            git_repo: The git repo.
            remote_name: Name of the remote.
        """
        super().__init__(code_repository_id=code_repository_id)
        self._git_repo = git_repo
        self._remote = git_repo.remote(name=remote_name)

    @classmethod
    def at(
        cls,
        path: str,
        code_repository_id: UUID,
        remote_url_validation_callback: Callable[[str], bool],
    ) -> Optional["LocalGitRepositoryContext"]:
        """Returns a local git repository at the given path.

        Args:
            path: The path to the local git repository.
            code_repository_id: The ID of the code repository.
            remote_url_validation_callback: A callback that validates the
                remote URL of the git repository.

        Returns:
            A local git repository if the path is a valid git repository
            and the remote URL is valid, otherwise None. None is also
            returned when git is not available or the path does not exist.
        """
        try:
            # These imports fail when git is not installed on the machine
            from git.exc import InvalidGitRepositoryError, NoSuchPathError
            from git.repo.base import Repo
        except ImportError:
            return None

        try:
            git_repo = Repo(path=path, search_parent_directories=True)
        except (InvalidGitRepositoryError, NoSuchPathError):
            # A path that does not exist cannot contain a repository either,
            # so it is reported as "no repository" instead of crashing.
            return None

        # Pick the first remote whose URL the callback accepts; the callback
        # is not invoked for any remote after the first match.
        remote_name = next(
            (
                remote.name
                for remote in git_repo.remotes
                if remote_url_validation_callback(remote.url)
            ),
            None,
        )
        if not remote_name:
            return None

        return LocalGitRepositoryContext(
            code_repository_id=code_repository_id,
            git_repo=git_repo,
            remote_name=remote_name,
        )

    @property
    def git_repo(self) -> "Repo":
        """The git repo.

        Returns:
            The git repo object of the local git repository.
        """
        return self._git_repo

    @property
    def remote(self) -> "Remote":
        """The git remote.

        Returns:
            The remote of the git repo object of the local git repository.
        """
        return self._remote

    @property
    def root(self) -> str:
        """The root of the git repo.

        Returns:
            The root of the git repo.
        """
        assert self.git_repo.working_dir
        return str(self.git_repo.working_dir)

    @property
    def is_dirty(self) -> bool:
        """Whether the git repo is dirty.

        A repository counts as dirty if it has any untracked or uncommitted
        changes.

        Returns:
            True if the git repo is dirty, False otherwise.
        """
        return self.git_repo.is_dirty(untracked_files=True)

    @property
    def has_local_changes(self) -> bool:
        """Whether the git repo has local changes.

        A repository has local changes if it is dirty or there are some commits
        which have not been pushed yet.

        Returns:
            True if the git repo has local changes, False otherwise.

        Raises:
            RuntimeError: If the git repo is in a detached head state.
        """
        if self.is_dirty:
            return True

        # Resolve the branch before touching the network: a detached HEAD is
        # a local condition and should be reported even when the remote is
        # unreachable.
        try:
            active_branch = self.git_repo.active_branch
        except TypeError as e:
            raise RuntimeError(
                "Git repo in detached head state is not allowed."
            ) from e

        # Refresh the remote refs so the comparison below uses the current
        # state of the remote branch.
        self.remote.fetch()

        local_commit_object = self.git_repo.head.commit
        try:
            remote_commit_object = self.remote.refs[active_branch.name].commit
        except IndexError:
            # Branch doesn't exist on remote
            return True

        return cast("Commit", remote_commit_object) != local_commit_object

    @property
    def current_commit(self) -> str:
        """The current commit.

        Returns:
            The current commit sha.
        """
        return cast(str, self.git_repo.head.object.hexsha)
current_commit: str
property
readonly
The current commit.
Returns:
Type | Description |
---|---|
str |
The current commit sha. |
git_repo: Repo
property
readonly
The git repo.
Returns:
Type | Description |
---|---|
Repo |
The git repo object of the local git repository. |
has_local_changes: bool
property
readonly
Whether the git repo has local changes.
A repository has local changes if it is dirty or there are some commits which have not been pushed yet.
Returns:
Type | Description |
---|---|
bool |
True if the git repo has local changes, False otherwise. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the git repo is in a detached head state. |
is_dirty: bool
property
readonly
Whether the git repo is dirty.
A repository counts as dirty if it has any untracked or uncommitted changes.
Returns:
Type | Description |
---|---|
bool |
True if the git repo is dirty, False otherwise. |
remote: Remote
property
readonly
The git remote.
Returns:
Type | Description |
---|---|
Remote |
The remote of the git repo object of the local git repository. |
root: str
property
readonly
The root of the git repo.
Returns:
Type | Description |
---|---|
str |
The root of the git repo. |
__init__(self, code_repository_id, git_repo, remote_name)
special
Initializes a local git repository context.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
code_repository_id |
UUID |
The ID of the code repository. |
required |
git_repo |
Repo |
The git repo. |
required |
remote_name |
str |
Name of the remote. |
required |
Source code in zenml/code_repositories/git/local_git_repository_context.py
def __init__(
    self,
    code_repository_id: UUID,
    git_repo: "Repo",
    remote_name: str,
):
    """Sets up the context for a git repo and one of its remotes.

    Args:
        code_repository_id: The ID of the code repository.
        git_repo: The git repo.
        remote_name: Name of the remote.
    """
    super().__init__(code_repository_id=code_repository_id)
    self._git_repo = git_repo
    # Look the remote up once so later accesses reuse the same object.
    named_remote = git_repo.remote(name=remote_name)
    self._remote = named_remote
at(path, code_repository_id, remote_url_validation_callback)
classmethod
Returns a local git repository at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
The path to the local git repository. |
required |
code_repository_id |
UUID |
The ID of the code repository. |
required |
remote_url_validation_callback |
Callable[[str], bool] |
A callback that validates the remote URL of the git repository. |
required |
Returns:
Type | Description |
---|---|
Optional[LocalGitRepositoryContext] |
A local git repository if the path is a valid git repository and the remote URL is valid, otherwise None. |
Source code in zenml/code_repositories/git/local_git_repository_context.py
@classmethod
def at(
    cls,
    path: str,
    code_repository_id: UUID,
    remote_url_validation_callback: Callable[[str], bool],
) -> Optional["LocalGitRepositoryContext"]:
    """Returns a local git repository at the given path.

    Args:
        path: The path to the local git repository.
        code_repository_id: The ID of the code repository.
        remote_url_validation_callback: A callback that validates the
            remote URL of the git repository.

    Returns:
        A local git repository if the path is a valid git repository
        and the remote URL is valid, otherwise None. None is also
        returned when git is not available or the path does not exist.
    """
    try:
        # These imports fail when git is not installed on the machine
        from git.exc import InvalidGitRepositoryError, NoSuchPathError
        from git.repo.base import Repo
    except ImportError:
        return None

    try:
        git_repo = Repo(path=path, search_parent_directories=True)
    except (InvalidGitRepositoryError, NoSuchPathError):
        # A path that does not exist cannot contain a repository either,
        # so it is reported as "no repository" instead of crashing.
        return None

    # Pick the first remote whose URL the callback accepts; the callback
    # is not invoked for any remote after the first match.
    remote_name = next(
        (
            remote.name
            for remote in git_repo.remotes
            if remote_url_validation_callback(remote.url)
        ),
        None,
    )
    if not remote_name:
        return None

    return LocalGitRepositoryContext(
        code_repository_id=code_repository_id,
        git_repo=git_repo,
        remote_name=remote_name,
    )
local_repository_context
Base class for local code repository contexts.
LocalRepositoryContext (ABC)
Base class for local repository contexts.
This class is used to represent a local repository. It is used to track the current state of the repository and to provide information about the repository, such as the root path, the current commit, and whether the repository is dirty.
Source code in zenml/code_repositories/local_repository_context.py
class LocalRepositoryContext(ABC):
    """Abstract description of a repository checked out locally.

    Implementations track the current state of a local repository and
    expose information about it, such as the root path, the current
    commit, and whether the repository is dirty.
    """

    def __init__(self, code_repository_id: UUID) -> None:
        """Remembers which code repository this context belongs to.

        Args:
            code_repository_id: The ID of the code repository.
        """
        self._code_repository_id = code_repository_id

    @property
    def code_repository_id(self) -> UUID:
        """ID of the code repository.

        Returns:
            The ID that was passed in at construction time.
        """
        return self._code_repository_id

    @property
    @abstractmethod
    def root(self) -> str:
        """Root path of the local repository.

        Returns:
            The root path of the local repository.
        """

    @property
    @abstractmethod
    def is_dirty(self) -> bool:
        """Whether the local repository is dirty.

        A repository counts as dirty if it has any untracked or uncommitted
        changes.

        Returns:
            Whether the local repository is dirty.
        """

    @property
    @abstractmethod
    def has_local_changes(self) -> bool:
        """Whether the local repository has local changes.

        A repository has local changes if it is dirty or there are some
        commits which have not been pushed yet.

        Returns:
            Whether the local repository has local changes.
        """

    @property
    @abstractmethod
    def current_commit(self) -> str:
        """Current commit of the local repository.

        Returns:
            The current commit of the local repository.
        """
code_repository_id: UUID
property
readonly
Returns the ID of the code repository.
Returns:
Type | Description |
---|---|
UUID |
The ID of the code repository. |
current_commit: str
property
readonly
Returns the current commit of the local repository.
Returns:
Type | Description |
---|---|
str |
The current commit of the local repository. |
has_local_changes: bool
property
readonly
Returns whether the local repository has local changes.
A repository has local changes if it is dirty or there are some commits which have not been pushed yet.
Returns:
Type | Description |
---|---|
bool |
Whether the local repository has local changes. |
is_dirty: bool
property
readonly
Returns whether the local repository is dirty.
A repository counts as dirty if it has any untracked or uncommitted changes.
Returns:
Type | Description |
---|---|
bool |
Whether the local repository is dirty. |
root: str
property
readonly
Returns the root path of the local repository.
Returns:
Type | Description |
---|---|
str |
The root path of the local repository. |
__init__(self, code_repository_id)
special
Initializes a local repository context.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
code_repository_id |
UUID |
The ID of the code repository. |
required |
Source code in zenml/code_repositories/local_repository_context.py
def __init__(self, code_repository_id: UUID) -> None:
"""Initializes a local repository context.
Args:
code_repository_id: The ID of the code repository.
"""
self._code_repository_id = code_repository_id