Github
zenml.integrations.github
special
Initialization of the GitHub ZenML integration.
GitHubIntegration (Integration)
Definition of GitHub integration for ZenML.
Source code in zenml/integrations/github/__init__.py
class GitHubIntegration(Integration):
    """ZenML integration providing GitHub support.

    Registered under the ``GITHUB`` identifier; ``pygithub`` is listed as
    its only requirement.
    """

    NAME = GITHUB
    REQUIREMENTS: List[str] = ["pygithub"]
code_repositories
special
Initialization of the ZenML GitHub code repository.
github_code_repository
GitHub code repository.
GitHubCodeRepository (BaseCodeRepository)
GitHub code repository.
Source code in zenml/integrations/github/code_repositories/github_code_repository.py
class GitHubCodeRepository(BaseCodeRepository):
    """GitHub code repository."""

    @property
    def config(self) -> GitHubCodeRepositoryConfig:
        """Returns the `GitHubCodeRepositoryConfig` config.

        Returns:
            The configuration.
        """
        return GitHubCodeRepositoryConfig(**self._config)

    @property
    def github_repo(self) -> Repository:
        """The GitHub repository object from the GitHub API.

        Uses the session stored on `self._github_session` by `login`.

        Returns:
            The GitHub repository.
        """
        return self._github_session.get_repo(
            f"{self.config.owner}/{self.config.repository}"
        )

    def login(self) -> None:
        """Logs in to GitHub using the token provided in the config.

        On success, the authenticated client is stored on
        `self._github_session`.

        Raises:
            RuntimeError: If the login fails. The original exception is
                attached as the cause.
        """
        try:
            self._github_session = Github(self.config.token)
            # Fetching the user forces an API round-trip, so bad credentials
            # surface here rather than on first use.
            user = self._github_session.get_user().login
            logger.debug("Logged in as %s", user)
        except Exception as e:
            # Chain the cause so the underlying API/network error stays
            # visible in tracebacks.
            raise RuntimeError(
                f"An error occurred while logging in: {e}"
            ) from e

    def download_files(
        self, commit: str, directory: str, repo_sub_directory: Optional[str]
    ) -> None:
        """Downloads files from a commit to a local directory.

        Subdirectories are downloaded recursively. If writing an individual
        file fails, the error is logged and that file is skipped.

        Args:
            commit: The commit to download.
            directory: The directory to download to.
            repo_sub_directory: The sub directory to download from.

        Raises:
            RuntimeError: If the repository sub directory is invalid.
        """
        contents = self.github_repo.get_contents(
            repo_sub_directory or "", ref=commit
        )
        # `get_contents` yields a list only for directories; a single object
        # means the path pointed at a file.
        if not isinstance(contents, list):
            raise RuntimeError("Invalid repository subdirectory.")

        os.makedirs(directory, exist_ok=True)
        for content in contents:
            local_path = os.path.join(directory, content.name)
            if content.type == "dir":
                self.download_files(
                    commit=commit,
                    directory=local_path,
                    repo_sub_directory=content.path,
                )
            else:
                try:
                    with open(local_path, "wb") as f:
                        f.write(content.decoded_content)
                except (GithubException, IOError) as e:
                    logger.error("Error processing %s: %s", content.path, e)

    def get_local_context(self, path: str) -> Optional[LocalRepositoryContext]:
        """Gets the local repository context.

        Args:
            path: The path to the local repository.

        Returns:
            The local repository context.
        """
        return LocalGitRepositoryContext.at(
            path=path,
            code_repository_id=self.id,
            remote_url_validation_callback=self.check_remote_url,
        )

    def check_remote_url(self, url: str) -> bool:
        """Checks whether the remote url matches the code repository.

        Accepts the exact HTTPS clone URL or an SSH URL of the form
        `<user>@<host>:<owner>/<repository>.git`.

        Args:
            url: The remote url.

        Returns:
            Whether the remote url is correct.
        """
        host = str(self.config.host)
        owner = str(self.config.owner)
        repository = str(self.config.repository)

        https_url = f"https://{host}/{owner}/{repository}.git"
        if url == https_url:
            return True

        # Escape the configured values: hosts and repository names contain
        # `.`, which unescaped would match any character (and `.git` would
        # accept e.g. `Xgit`).
        ssh_regex = re.compile(
            rf".*@{re.escape(host)}:{re.escape(owner)}/{re.escape(repository)}\.git"
        )
        return ssh_regex.fullmatch(url) is not None
config: GitHubCodeRepositoryConfig
property
readonly
Returns the `GitHubCodeRepositoryConfig` config.
Returns:
| Type | Description |
|---|---|
| `GitHubCodeRepositoryConfig` | The configuration. |
github_repo: Repository
property
readonly
The GitHub repository object from the GitHub API.
Returns:
| Type | Description |
|---|---|
| `Repository` | The GitHub repository. |
check_remote_url(self, url)
Checks whether the remote url matches the code repository.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `url` | `str` | The remote url. | required |
Returns:
| Type | Description |
|---|---|
| `bool` | Whether the remote url is correct. |
Source code in zenml/integrations/github/code_repositories/github_code_repository.py
def check_remote_url(self, url: str) -> bool:
    """Checks whether the remote url matches the code repository.

    Accepts the exact HTTPS clone URL or an SSH URL of the form
    `<user>@<host>:<owner>/<repository>.git`.

    Args:
        url: The remote url.

    Returns:
        Whether the remote url is correct.
    """
    host = str(self.config.host)
    owner = str(self.config.owner)
    repository = str(self.config.repository)

    https_url = f"https://{host}/{owner}/{repository}.git"
    if url == https_url:
        return True

    # Escape the configured values: hosts and repository names contain `.`,
    # which unescaped would match any character (and `.git` would accept
    # e.g. `Xgit`).
    ssh_regex = re.compile(
        rf".*@{re.escape(host)}:{re.escape(owner)}/{re.escape(repository)}\.git"
    )
    return ssh_regex.fullmatch(url) is not None
download_files(self, commit, directory, repo_sub_directory)
Downloads files from a commit to a local directory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `commit` | `str` | The commit to download. | required |
| `directory` | `str` | The directory to download to. | required |
| `repo_sub_directory` | `Optional[str]` | The sub directory to download from. | required |
Exceptions:
| Type | Description |
|---|---|
| `RuntimeError` | If the repository sub directory is invalid. |
Source code in zenml/integrations/github/code_repositories/github_code_repository.py
def download_files(
    self, commit: str, directory: str, repo_sub_directory: Optional[str]
) -> None:
    """Downloads files from a commit to a local directory.

    Subdirectories are downloaded recursively. If writing an individual file
    fails, the error is logged and that file is skipped.

    Args:
        commit: The commit to download.
        directory: The directory to download to.
        repo_sub_directory: The sub directory to download from.

    Raises:
        RuntimeError: If the repository sub directory is invalid.
    """
    contents = self.github_repo.get_contents(
        repo_sub_directory or "", ref=commit
    )
    # `get_contents` yields a list only for directories; a single object
    # means the path pointed at a file.
    if not isinstance(contents, list):
        raise RuntimeError("Invalid repository subdirectory.")

    os.makedirs(directory, exist_ok=True)
    for content in contents:
        local_path = os.path.join(directory, content.name)
        if content.type == "dir":
            self.download_files(
                commit=commit,
                directory=local_path,
                repo_sub_directory=content.path,
            )
        else:
            try:
                with open(local_path, "wb") as f:
                    f.write(content.decoded_content)
            except (GithubException, IOError) as e:
                logger.error("Error processing %s: %s", content.path, e)
get_local_context(self, path)
Gets the local repository context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str` | The path to the local repository. | required |
Returns:
| Type | Description |
|---|---|
| `Optional[zenml.code_repositories.local_repository_context.LocalRepositoryContext]` | The local repository context. |
Source code in zenml/integrations/github/code_repositories/github_code_repository.py
def get_local_context(self, path: str) -> Optional[LocalRepositoryContext]:
    """Builds the local repository context for a path.

    Delegates to `LocalGitRepositoryContext.at`, passing this repository's id
    and `check_remote_url` as the callback for validating the remote URL.

    Args:
        path: The path to the local repository.

    Returns:
        The local repository context.
    """
    local_context = LocalGitRepositoryContext.at(
        path=path,
        code_repository_id=self.id,
        remote_url_validation_callback=self.check_remote_url,
    )
    return local_context
login(self)
Logs in to GitHub using the token provided in the config.
Exceptions:
| Type | Description |
|---|---|
| `RuntimeError` | If the login fails. |
Source code in zenml/integrations/github/code_repositories/github_code_repository.py
def login(self) -> None:
    """Logs in to GitHub using the token provided in the config.

    On success, the authenticated client is stored on
    `self._github_session`.

    Raises:
        RuntimeError: If the login fails. The original exception is attached
            as the cause.
    """
    try:
        self._github_session = Github(self.config.token)
        # Fetching the user forces an API round-trip, so bad credentials
        # surface here rather than on first use.
        user = self._github_session.get_user().login
        logger.debug("Logged in as %s", user)
    except Exception as e:
        # Chain the cause so the underlying API/network error stays visible
        # in tracebacks.
        raise RuntimeError(f"An error occurred while logging in: {e}") from e
GitHubCodeRepositoryConfig (BaseCodeRepositoryConfig)
pydantic-model
Config for GitHub code repositories.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `url` | `Optional[str]` | The URL of the GitHub instance. | required |
| `owner` | `str` | The owner of the repository. | required |
| `repository` | `str` | The name of the repository. | required |
| `host` | `Optional[str]` | The host of the repository. | `"github.com"` |
| `token` | `str` | The token to access the repository. | required |
Source code in zenml/integrations/github/code_repositories/github_code_repository.py
class GitHubCodeRepositoryConfig(BaseCodeRepositoryConfig):
    """Config for GitHub code repositories.

    Args:
        url: The URL of the GitHub instance.
        owner: The owner of the repository.
        repository: The name of the repository.
        host: The host of the repository.
        token: The token to access the repository.
    """

    # NOTE(review): `Optional[str]` with no default -- pydantic v1 treats this
    # as defaulting to None, pydantic v2 makes it required; confirm intent.
    # The GitHubCodeRepository methods in this listing never read `url`; they
    # build URLs from `host` instead.
    url: Optional[str]
    # `owner` and `repository` are joined as "owner/repository" to fetch the
    # repo from the API, and are also used when validating remote URLs.
    owner: str
    repository: str
    # Host used when validating remote URLs. NOTE(review): if set to None, the
    # URL checks format it as the literal string "None".
    host: Optional[str] = "github.com"
    # Passed to the GitHub client on login. `SecretField` presumably marks the
    # value as sensitive -- confirm against its definition.
    token: str = SecretField()