Utils
zenml.utils
special
Initialization of the utils module.
The `utils` module contains utility functions for handling analytics, reading
and writing YAML data, and other general-purpose tasks.
analytics_utils
Analytics code for ZenML.
AnalyticsContext
Context manager for analytics.
Source code in zenml/utils/analytics_utils.py
class AnalyticsContext:
    """Context manager for analytics.

    Telemetry is only sent when the user opted in, and any exception raised
    inside the context is suppressed by `__exit__`.
    """

    def __init__(self) -> None:
        """Context manager for analytics.

        Use this as a context manager to ensure that analytics are initialized
        properly, only tracked when configured to do so and that any errors
        are handled gracefully.
        """
        import analytics

        from zenml.config.global_config import GlobalConfiguration

        # Safe defaults: if reading the global configuration fails below, the
        # tracking methods still find both attributes and send nothing.
        self.analytics_opt_in: bool = False
        self.user_id: str = ""

        try:
            gc = GlobalConfiguration()
            self.analytics_opt_in = gc.analytics_opt_in
            self.user_id = str(gc.user_id)

            # That means user opted out of analytics
            if not gc.analytics_opt_in:
                return

            if analytics.write_key is None:
                analytics.write_key = get_segment_key()

            # Explicit check instead of `assert`, which is stripped under -O.
            if analytics.write_key is None:
                raise RuntimeError(
                    "Analytics key not set but trying to make telemetry call."
                )

            # Set this to 1 to avoid backoff loop
            analytics.max_retries = 1
        except Exception as e:
            self.analytics_opt_in = False
            logger.debug(f"Analytics initialization failed: {e}")

    def __enter__(self) -> "AnalyticsContext":
        """Enter context manager.

        Returns:
            Self.
        """
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        """Exit context manager.

        Args:
            exc_type: Exception type.
            exc_val: Exception value.
            exc_tb: Exception traceback.

        Returns:
            Always True, so any exception raised inside the context is
            suppressed.
        """
        if exc_val is not None:
            logger.debug(f"Sending telemetry data failed: {exc_val}")

        # We should never fail main thread
        return True

    def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool:
        """Identify the user.

        Args:
            traits: Traits of the user.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        logger.debug(
            f"Attempting to attach metadata to: User: {self.user_id}, "
            f"Metadata: {traits}"
        )

        if not self.analytics_opt_in:
            return False

        analytics.identify(self.user_id, traits)

        logger.debug(f"User data sent: User: {self.user_id},{traits}")
        return True

    def group(
        self,
        group: Union[str, AnalyticsGroup],
        group_id: str,
        traits: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Group the user.

        Args:
            group: Group to which the user belongs.
            group_id: Group ID.
            traits: Traits of the group. The caller's dictionary is not
                modified.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        if isinstance(group, AnalyticsGroup):
            group = group.value

        # Work on a copy so the caller's dictionary is left untouched.
        traits = dict(traits or {})
        traits["group_id"] = group_id

        logger.debug(
            f"Attempting to attach metadata to: User: {self.user_id}, "
            f"Group: {group}, Group ID: {group_id}, Metadata: {traits}"
        )

        if not self.analytics_opt_in:
            return False

        analytics.group(self.user_id, group_id, traits=traits)

        logger.debug(
            f"Group data sent: User: {self.user_id}, Group: {group}, Group ID: "
            f"{group_id}, Metadata: {traits}"
        )
        return True

    def track(
        self,
        event: Union[str, AnalyticsEvent],
        properties: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Track an event.

        Args:
            event: Event to track.
            properties: Event properties. The caller's dictionary is not
                modified.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        from zenml.config.global_config import GlobalConfiguration

        if isinstance(event, AnalyticsEvent):
            event = event.value

        # Work on a copy so the caller's dictionary is left untouched.
        properties = dict(properties or {})

        logger.debug(
            f"Attempting analytics: User: {self.user_id}, "
            f"Event: {event}, "
            f"Metadata: {properties}"
        )

        # Opt-in / opt-out events bypass the opt-in check (presumably so the
        # user's choice itself is recorded).
        if not self.analytics_opt_in and event not in {
            AnalyticsEvent.OPT_OUT_ANALYTICS.value,
            AnalyticsEvent.OPT_IN_ANALYTICS.value,
        }:
            return False

        # add basics
        properties.update(Environment.get_system_info())
        properties.update(
            {
                "environment": get_environment(),
                "python_version": Environment.python_version(),
                "version": __version__,
            }
        )

        gc = GlobalConfiguration()
        # avoid initializing the store in the analytics, to not create an
        # infinite loop
        if gc._zen_store is not None:
            zen_store = gc.zen_store
            user = zen_store.get_user()
            properties.setdefault("client_id", self.user_id)
            properties.setdefault("user_id", str(user.id))
            if (
                zen_store.type == StoreType.REST
                and "server_id" not in properties
            ):
                server_info = zen_store.get_store_info()
                properties.update(
                    {
                        "user_id": str(user.id),
                        "server_id": str(server_info.id),
                        "server_deployment": str(server_info.deployment_type),
                        "database_type": str(server_info.database_type),
                    }
                )

        # Send UUID values as plain strings.
        properties = {
            k: str(v) if isinstance(v, UUID) else v
            for k, v in properties.items()
        }

        analytics.track(self.user_id, event, properties)

        logger.debug(
            f"Analytics sent: User: {self.user_id}, Event: {event}, Metadata: "
            f"{properties}"
        )

        return True
__enter__(self)
special
Enter context manager.
Returns:
Type | Description |
---|---|
AnalyticsContext |
Self. |
Source code in zenml/utils/analytics_utils.py
def __enter__(self) -> "AnalyticsContext":
    """Enter the analytics context.

    Returns:
        The context instance itself, so it can be bound with ``as``.
    """
    context = self
    return context
__exit__(self, exc_type, exc_val, exc_tb)
special
Exit context manager.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
exc_type |
Optional[Type[BaseException]] |
Exception type. |
required |
exc_val |
Optional[BaseException] |
Exception value. |
required |
exc_tb |
Optional[traceback] |
Exception traceback. |
required |
Returns:
Type | Description |
---|---|
bool |
True if exception was handled, False otherwise. |
Source code in zenml/utils/analytics_utils.py
def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> bool:
    """Exit context manager.

    Args:
        exc_type: Exception type.
        exc_val: Exception value.
        exc_tb: Exception traceback.

    Returns:
        Always True, so any exception raised inside the context is
        suppressed.
    """
    if exc_val is not None:
        # f-string so the actual exception text ends up in the log.
        logger.debug(f"Sending telemetry data failed: {exc_val}")

    # We should never fail main thread
    return True
__init__(self)
special
Context manager for analytics.
Use this as a context manager to ensure that analytics are initialized properly, only tracked when configured to do so and that any errors are handled gracefully.
Source code in zenml/utils/analytics_utils.py
def __init__(self) -> None:
    """Context manager for analytics.

    Use this as a context manager to ensure that analytics are initialized
    properly, only tracked when configured to do so and that any errors
    are handled gracefully.
    """
    import analytics

    from zenml.config.global_config import GlobalConfiguration

    # Safe defaults: if reading the global configuration fails below, the
    # tracking methods still find both attributes and send nothing.
    self.analytics_opt_in: bool = False
    self.user_id: str = ""

    try:
        gc = GlobalConfiguration()
        self.analytics_opt_in = gc.analytics_opt_in
        self.user_id = str(gc.user_id)

        # That means user opted out of analytics
        if not gc.analytics_opt_in:
            return

        if analytics.write_key is None:
            analytics.write_key = get_segment_key()

        # Explicit check instead of `assert`, which is stripped under -O.
        if analytics.write_key is None:
            raise RuntimeError(
                "Analytics key not set but trying to make telemetry call."
            )

        # Set this to 1 to avoid backoff loop
        analytics.max_retries = 1
    except Exception as e:
        self.analytics_opt_in = False
        logger.debug(f"Analytics initialization failed: {e}")
group(self, group, group_id, traits=None)
Group the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
group |
Union[str, zenml.utils.analytics_utils.AnalyticsGroup] |
Group to which the user belongs. |
required |
group_id |
str |
Group ID. |
required |
traits |
Optional[Dict[str, Any]] |
Traits of the group. |
None |
Returns:
Type | Description |
---|---|
bool |
True if tracking information was sent, False otherwise. |
Source code in zenml/utils/analytics_utils.py
def group(
    self,
    group: Union[str, AnalyticsGroup],
    group_id: str,
    traits: Optional[Dict[str, Any]] = None,
) -> bool:
    """Group the user.

    Args:
        group: Group to which the user belongs.
        group_id: Group ID.
        traits: Traits of the group. The caller's dictionary is not
            modified.

    Returns:
        True if tracking information was sent, False otherwise.
    """
    import analytics

    if isinstance(group, AnalyticsGroup):
        group = group.value

    # Work on a copy so the caller's dictionary is left untouched.
    traits = dict(traits or {})
    traits["group_id"] = group_id

    logger.debug(
        f"Attempting to attach metadata to: User: {self.user_id}, "
        f"Group: {group}, Group ID: {group_id}, Metadata: {traits}"
    )

    if not self.analytics_opt_in:
        return False

    analytics.group(self.user_id, group_id, traits=traits)

    logger.debug(
        f"Group data sent: User: {self.user_id}, Group: {group}, Group ID: "
        f"{group_id}, Metadata: {traits}"
    )
    return True
identify(self, traits=None)
Identify the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
traits |
Optional[Dict[str, Any]] |
Traits of the user. |
None |
Returns:
Type | Description |
---|---|
bool |
True if tracking information was sent, False otherwise. |
Source code in zenml/utils/analytics_utils.py
def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool:
    """Attach traits to the current user.

    Args:
        traits: Traits of the user.

    Returns:
        True if the traits were handed to the analytics client, False when
        the user has not opted in.
    """
    import analytics

    user = self.user_id
    logger.debug(
        f"Attempting to attach metadata to: User: {user}, "
        f"Metadata: {traits}"
    )

    if self.analytics_opt_in:
        analytics.identify(user, traits)
        logger.debug(f"User data sent: User: {user},{traits}")
        return True

    return False
track(self, event, properties=None)
Track an event.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event |
Union[str, zenml.utils.analytics_utils.AnalyticsEvent] |
Event to track. |
required |
properties |
Optional[Dict[str, Any]] |
Event properties. |
None |
Returns:
Type | Description |
---|---|
bool |
True if tracking information was sent, False otherwise. |
Source code in zenml/utils/analytics_utils.py
def track(
    self,
    event: Union[str, AnalyticsEvent],
    properties: Optional[Dict[str, Any]] = None,
) -> bool:
    """Track an event.

    Args:
        event: Event to track.
        properties: Event properties. The caller's dictionary is not
            modified.

    Returns:
        True if tracking information was sent, False otherwise.
    """
    import analytics

    from zenml.config.global_config import GlobalConfiguration

    if isinstance(event, AnalyticsEvent):
        event = event.value

    # Work on a copy so the caller's dictionary is left untouched.
    properties = dict(properties or {})

    logger.debug(
        f"Attempting analytics: User: {self.user_id}, "
        f"Event: {event}, "
        f"Metadata: {properties}"
    )

    # Opt-in / opt-out events bypass the opt-in check (presumably so the
    # user's choice itself is recorded).
    if not self.analytics_opt_in and event not in {
        AnalyticsEvent.OPT_OUT_ANALYTICS.value,
        AnalyticsEvent.OPT_IN_ANALYTICS.value,
    }:
        return False

    # add basics
    properties.update(Environment.get_system_info())
    properties.update(
        {
            "environment": get_environment(),
            "python_version": Environment.python_version(),
            "version": __version__,
        }
    )

    gc = GlobalConfiguration()
    # avoid initializing the store in the analytics, to not create an
    # infinite loop
    if gc._zen_store is not None:
        zen_store = gc.zen_store
        user = zen_store.get_user()
        properties.setdefault("client_id", self.user_id)
        properties.setdefault("user_id", str(user.id))
        if (
            zen_store.type == StoreType.REST
            and "server_id" not in properties
        ):
            server_info = zen_store.get_store_info()
            properties.update(
                {
                    "user_id": str(user.id),
                    "server_id": str(server_info.id),
                    "server_deployment": str(server_info.deployment_type),
                    "database_type": str(server_info.database_type),
                }
            )

    # Send UUID values as plain strings.
    properties = {
        k: str(v) if isinstance(v, UUID) else v for k, v in properties.items()
    }

    analytics.track(self.user_id, event, properties)

    logger.debug(
        f"Analytics sent: User: {self.user_id}, Event: {event}, Metadata: "
        f"{properties}"
    )

    return True
AnalyticsEvent (str, Enum)
Enum of events to track in segment.
Source code in zenml/utils/analytics_utils.py
class AnalyticsEvent(str, Enum):
    """Enum of events to track in Segment.

    Every member must have a distinct value. Enum silently turns a member
    with a duplicate value into an alias of the earlier member.
    """

    # Pipelines
    RUN_PIPELINE = "Pipeline run"
    GET_PIPELINES = "Pipelines fetched"
    GET_PIPELINE = "Pipeline fetched"
    CREATE_PIPELINE = "Pipeline created"
    UPDATE_PIPELINE = "Pipeline updated"
    DELETE_PIPELINE = "Pipeline deleted"

    # Repo
    INITIALIZE_REPO = "ZenML initialized"
    CONNECT_REPOSITORY = "Repository connected"
    UPDATE_REPOSITORY = "Repository updated"
    DELETE_REPOSITORY = "Repository deleted"

    # Template
    GENERATE_TEMPLATE = "Template generated"

    # Zen store
    INITIALIZED_STORE = "Store initialized"

    # Components
    REGISTERED_STACK_COMPONENT = "Stack component registered"
    UPDATED_STACK_COMPONENT = "Stack component updated"
    COPIED_STACK_COMPONENT = "Stack component copied"
    # Previously "Stack component copied", which made this member an alias of
    # COPIED_STACK_COMPONENT, so deletions were reported as copies.
    DELETED_STACK_COMPONENT = "Stack component deleted"

    # Stack
    REGISTERED_STACK = "Stack registered"
    REGISTERED_DEFAULT_STACK = "Default stack registered"
    SET_STACK = "Stack set"
    UPDATED_STACK = "Stack updated"
    COPIED_STACK = "Stack copied"
    IMPORT_STACK = "Stack imported"
    EXPORT_STACK = "Stack exported"
    DELETED_STACK = "Stack deleted"

    # Model Deployment
    MODEL_DEPLOYED = "Model deployed"

    # Analytics opt in and out
    OPT_IN_ANALYTICS = "Analytics opt-in"
    OPT_OUT_ANALYTICS = "Analytics opt-out"
    OPT_IN_OUT_EMAIL = "Response for Email prompt"

    # Examples
    RUN_ZENML_GO = "ZenML go"
    RUN_EXAMPLE = "Example run"
    PULL_EXAMPLE = "Example pull"

    # Integrations
    INSTALL_INTEGRATION = "Integration installed"

    # Users
    CREATED_USER = "User created"
    CREATED_DEFAULT_USER = "Default user created"
    UPDATED_USER = "User updated"
    DELETED_USER = "User deleted"

    # Teams
    CREATED_TEAM = "Team created"
    UPDATED_TEAM = "Team updated"
    DELETED_TEAM = "Team deleted"

    # Workspaces
    CREATED_WORKSPACE = "Workspace created"
    CREATED_DEFAULT_WORKSPACE = "Default workspace created"
    UPDATED_WORKSPACE = "Workspace updated"
    DELETED_WORKSPACE = "Workspace deleted"
    SET_WORKSPACE = "Workspace set"

    # Role
    CREATED_ROLE = "Role created"
    CREATED_DEFAULT_ROLES = "Default roles created"
    UPDATED_ROLE = "Role updated"
    DELETED_ROLE = "Role deleted"

    # Flavor
    CREATED_FLAVOR = "Flavor created"
    UPDATED_FLAVOR = "Flavor updated"
    DELETED_FLAVOR = "Flavor deleted"

    # Secret
    CREATED_SECRET = "Secret created"
    UPDATED_SECRET = "Secret updated"
    DELETED_SECRET = "Secret deleted"

    # Test event
    EVENT_TEST = "Test event"

    # Stack recipes
    PULL_STACK_RECIPE = "Stack recipes pulled"
    RUN_STACK_RECIPE = "Stack recipe created"
    DESTROY_STACK_RECIPE = "Stack recipe destroyed"

    # ZenML server events
    ZENML_SERVER_STARTED = "ZenML server started"
    ZENML_SERVER_STOPPED = "ZenML server stopped"
    ZENML_SERVER_CONNECTED = "ZenML server connected"
    ZENML_SERVER_DEPLOYED = "ZenML server deployed"
    ZENML_SERVER_DESTROYED = "ZenML server destroyed"
AnalyticsGroup (str, Enum)
Enum of event groups to track in segment.
Source code in zenml/utils/analytics_utils.py
class AnalyticsGroup(str, Enum):
    """Names of the Segment groups that users can be attached to."""

    # Group used for ZenML server deployments.
    ZENML_SERVER_GROUP = "ZenML server group"
AnalyticsTrackedModelMixin (BaseModel)
pydantic-model
Mixin for models that are tracked through analytics events.
Classes that have information tracked in analytics events can inherit
from this mixin and list the attributes to track in `ANALYTICS_FIELDS`. The
`@track` decorator will detect function arguments and return values that
inherit from this class and will include the `ANALYTICS_FIELDS` attributes
as tracking metadata.
Source code in zenml/utils/analytics_utils.py
class AnalyticsTrackedModelMixin(BaseModel):
    """Mixin for models whose attributes are reported in analytics events.

    Subclasses list the names of the attributes to report in
    `ANALYTICS_FIELDS`. The `@track` decorator detects function arguments and
    return values that inherit from this class and includes those
    attributes as tracking metadata.
    """

    # Names of the attributes reported by `get_analytics_metadata`.
    ANALYTICS_FIELDS: ClassVar[List[str]] = []

    def get_analytics_metadata(self) -> Dict[str, Any]:
        """Collect the values of all `ANALYTICS_FIELDS` attributes.

        Returns:
            Mapping from each listed field name to its value, or None when
            the attribute does not exist.
        """
        return {
            name: getattr(self, name, None) for name in self.ANALYTICS_FIELDS
        }
get_analytics_metadata(self)
Get the analytics metadata for the model.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Dict of analytics metadata. |
Source code in zenml/utils/analytics_utils.py
def get_analytics_metadata(self) -> Dict[str, Any]:
    """Collect the values of all `ANALYTICS_FIELDS` attributes.

    Returns:
        Mapping from each listed field name to its value, or None when the
        attribute does not exist.
    """
    return {name: getattr(self, name, None) for name in self.ANALYTICS_FIELDS}
AnalyticsTrackerMixin (ABC)
Abstract base class for analytics trackers.
Use this as a mixin for classes that have methods decorated with
`@track` to add global control over how analytics are tracked. The decorator
will detect that the class has this mixin and will call the class's
`track_event` method.
Source code in zenml/utils/analytics_utils.py
class AnalyticsTrackerMixin(ABC):
    """Abstract interface for objects that track analytics events themselves.

    Mix this into classes whose methods are decorated with `@track` to
    control how their events are tracked. The decorator detects the mixin
    and calls the class's `track_event` method.
    """

    @abstractmethod
    def track_event(
        self,
        event: Union[str, AnalyticsEvent],
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        """Record a single analytics event.

        Args:
            event: Event to track.
            metadata: Metadata to track.
        """
track_event(self, event, metadata)
Track an event.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event |
Union[str, zenml.utils.analytics_utils.AnalyticsEvent] |
Event to track. |
required |
metadata |
Optional[Dict[str, Any]] |
Metadata to track. |
required |
Source code in zenml/utils/analytics_utils.py
@abstractmethod
def track_event(
    self,
    event: Union[str, AnalyticsEvent],
    metadata: Optional[Dict[str, Any]],
) -> None:
    """Record a single analytics event; implemented by subclasses.

    Args:
        event: Event to track.
        metadata: Metadata to track.
    """
event_handler
Context handler to enable tracking the success status of an event.
Source code in zenml/utils/analytics_utils.py
class event_handler(object):
    """Context handler that records whether an analytics event succeeded."""

    def __init__(
        self, event: AnalyticsEvent, metadata: Optional[Dict[str, Any]] = None
    ):
        """Set up the handler for a single event.

        Args:
            event: The type of the analytics event
            metadata: The metadata of the event.
        """
        self.event: AnalyticsEvent = event
        # A missing or empty metadata argument becomes a fresh dict.
        self.metadata: Dict[str, Any] = metadata if metadata else {}
        # When left unset, the module-level `track_event` is used on exit.
        self.tracker: Optional[AnalyticsTrackerMixin] = None

    def __enter__(self) -> "event_handler":
        """Start the tracked block.

        Returns:
            the handler instance.
        """
        handler = self
        return handler

    def __exit__(
        self,
        type_: Optional[Any],
        value: Optional[Any],
        traceback: Optional[Any],
    ) -> Any:
        """Finish the tracked block and send the event.

        Marks the event as successful exactly when no traceback is present,
        records the exception class name if one was raised, then sends the
        event through the tracker if one is set, or `track_event` otherwise.
        Exceptions are not suppressed (the method returns None).

        Args:
            type_: The class of the exception
            value: The instance of the exception
            traceback: The traceback of the exception
        """
        self.metadata.update({"event_success": traceback is None})

        if type_ is not None:
            self.metadata.update({"event_error_type": type_.__name__})

        if not self.tracker:
            track_event(self.event, self.metadata)
        else:
            self.tracker.track_event(self.event, self.metadata)
__enter__(self)
special
Enter function of the event handler.
Returns:
Type | Description |
---|---|
event_handler |
the handler instance. |
Source code in zenml/utils/analytics_utils.py
def __enter__(self) -> "event_handler":
    """Start the tracked block.

    Returns:
        the handler instance.
    """
    handler = self
    return handler
__exit__(self, type_, value, traceback)
special
Exit function of the event handler.
Checks whether there was a traceback and updates the metadata accordingly. Following the check, it calls the function to track the event.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
type_ |
Optional[Any] |
The class of the exception |
required |
value |
Optional[Any] |
The instance of the exception |
required |
traceback |
Optional[Any] |
The traceback of the exception |
required |
Source code in zenml/utils/analytics_utils.py
def __exit__(
    self,
    type_: Optional[Any],
    value: Optional[Any],
    traceback: Optional[Any],
) -> Any:
    """Finish the tracked block and send the event.

    Marks the event as successful exactly when no traceback is present,
    records the exception class name if one was raised, then sends the
    event through the tracker if one is set, or `track_event` otherwise.
    Exceptions are not suppressed (the function returns None).

    Args:
        type_: The class of the exception
        value: The instance of the exception
        traceback: The traceback of the exception
    """
    self.metadata.update({"event_success": traceback is None})

    if type_ is not None:
        self.metadata.update({"event_error_type": type_.__name__})

    if not self.tracker:
        track_event(self.event, self.metadata)
    else:
        self.tracker.track_event(self.event, self.metadata)
__init__(self, event, metadata=None)
special
Initialization of the context manager.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event |
AnalyticsEvent |
The type of the analytics event |
required |
metadata |
Optional[Dict[str, Any]] |
The metadata of the event. |
None |
Source code in zenml/utils/analytics_utils.py
def __init__(
    self, event: AnalyticsEvent, metadata: Optional[Dict[str, Any]] = None
):
    """Set up the handler for a single event.

    Args:
        event: The type of the analytics event
        metadata: The metadata of the event.
    """
    self.event: AnalyticsEvent = event
    # A missing or empty metadata argument becomes a fresh dict.
    self.metadata: Dict[str, Any] = metadata if metadata else {}
    # When left unset, the module-level `track_event` is used on exit.
    self.tracker: Optional[AnalyticsTrackerMixin] = None
get_segment_key()
Get key for authorizing to Segment backend.
Returns:
Type | Description |
---|---|
str |
Segment key as a string. |
Source code in zenml/utils/analytics_utils.py
def get_segment_key() -> str:
    """Pick the Segment write key for the current environment.

    Returns:
        The development key when `IS_DEBUG_ENV` is truthy, otherwise the
        production key.
    """
    return SEGMENT_KEY_DEV if IS_DEBUG_ENV else SEGMENT_KEY_PROD
identify_group(group, group_id, group_metadata=None)
Attach metadata to a segment group.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
group |
Union[str, zenml.utils.analytics_utils.AnalyticsGroup] |
Group to track. |
required |
group_id |
str |
ID of the group. |
required |
group_metadata |
Optional[Dict[str, Any]] |
Metadata to attach to the group. |
None |
Returns:
Type | Description |
---|---|
bool |
True if the event is sent successfully, False if not. |
Source code in zenml/utils/analytics_utils.py
def identify_group(
    group: Union[str, AnalyticsGroup],
    group_id: str,
    group_metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Attach metadata to a Segment group.

    Args:
        group: Group to track.
        group_id: ID of the group.
        group_metadata: Metadata to attach to the group.

    Returns:
        True if event is sent successfully, False if not.
    """
    with AnalyticsContext() as context:
        sent = context.group(group, group_id, traits=group_metadata)
        return sent
    # Reached only when AnalyticsContext.__exit__ suppressed an exception.
    return False
identify_user(user_metadata=None)
Attach metadata to user directly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_metadata |
Optional[Dict[str, Any]] |
Dict of metadata to attach to the user. |
None |
Returns:
Type | Description |
---|---|
bool |
True if the event is sent successfully, False if not. |
Source code in zenml/utils/analytics_utils.py
def identify_user(user_metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Attach metadata directly to the current user.

    Args:
        user_metadata: Dict of metadata to attach to the user.

    Returns:
        True if event is sent successfully, False if not.
    """
    with AnalyticsContext() as context:
        if user_metadata is not None:
            return context.identify(traits=user_metadata)
        return False
    # Reached only when AnalyticsContext.__exit__ suppressed an exception.
    return False
parametrized(dec)
A meta-decorator, that is, a decorator for decorators.
As a decorator is a function, it actually works as a regular decorator with arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dec |
Callable[..., Callable[..., Any]] |
Decorator to be applied to the function. |
required |
Returns:
Type | Description |
---|---|
Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]] |
Decorator that applies the given decorator to the function. |
Source code in zenml/utils/analytics_utils.py
def parametrized(
    dec: Callable[..., Callable[..., Any]]
) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]]:
    """Turn ``dec(func, *args, **kwargs)`` into a decorator factory.

    This is a decorator for decorators. It lets a decorator that needs
    extra arguments be used as ``@decorator(arg, ...)``.

    Args:
        dec: Decorator to be applied to the function.

    Returns:
        Decorator that applies the given decorator to the function.
    """

    def layer(
        *args: Any, **kwargs: Any
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Capture the decorator's extra arguments.

        Args:
            *args: Arguments to be passed to the decorator.
            **kwargs: Keyword arguments to be passed to the decorator.

        Returns:
            Decorator that applies the given decorator to the function.
        """

        def repl(f: Callable[..., Any]) -> Callable[..., Any]:
            """Apply ``dec`` to ``f`` with the captured arguments.

            Args:
                f: Function to be decorated.

            Returns:
                Decorated function.
            """
            decorated = dec(f, *args, **kwargs)
            return decorated

        return repl

    return layer
track(*args, **kwargs)
Internal layer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Arguments to be passed to the decorator. |
() |
**kwargs |
Any |
Keyword arguments to be passed to the decorator. |
{} |
Returns:
Type | Description |
---|---|
Callable[[Callable[..., Any]], Callable[..., Any]] |
Decorator that applies the given decorator to the function. |
Source code in zenml/utils/analytics_utils.py
def layer(
    *args: Any, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Capture the decorator's extra arguments.

    Args:
        *args: Arguments to be passed to the decorator.
        **kwargs: Keyword arguments to be passed to the decorator.

    Returns:
        Decorator that applies the given decorator to the function.
    """

    def repl(f: Callable[..., Any]) -> Callable[..., Any]:
        """Apply ``dec`` to ``f`` with the captured arguments.

        Args:
            f: Function to be decorated.

        Returns:
            Decorated function.
        """
        decorated = dec(f, *args, **kwargs)
        return decorated

    return repl
track_event(event, metadata=None)
Track segment event if user opted-in.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event |
Union[str, zenml.utils.analytics_utils.AnalyticsEvent] |
Name of event to track in segment. |
required |
metadata |
Optional[Dict[str, Any]] |
Dict of metadata to track. |
None |
Returns:
Type | Description |
---|---|
bool |
True if the event is sent successfully, False if not. |
Source code in zenml/utils/analytics_utils.py
def track_event(
    event: Union[str, AnalyticsEvent],
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Track a Segment event if the user opted in.

    Args:
        event: Name of event to track in Segment.
        metadata: Dict of metadata to track. The caller's dictionary is not
            modified.

    Returns:
        True if event is sent successfully, False if not.
    """
    # Copy so the default "event_success" is not written into the caller's
    # dictionary.
    metadata = dict(metadata or {})
    metadata.setdefault("event_success", True)

    with AnalyticsContext() as context:
        return context.track(event, metadata)

    # Reached only when AnalyticsContext.__exit__ suppressed an exception.
    return False
daemon
Utility functions to start/stop daemon processes.
This is only implemented for UNIX systems and therefore doesn't work on Windows. Based on https://www.jejik.com/articles/2007/02/a_simple_unix_linux_daemon_in_python/
check_if_daemon_is_running(pid_file)
Checks whether a daemon process indicated by the PID file is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file |
str |
Path to file containing the PID of the daemon process to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the daemon process is running, otherwise False. |
Source code in zenml/utils/daemon.py
def check_if_daemon_is_running(pid_file: str) -> bool:
    """Report whether the daemon recorded in a PID file is alive.

    Args:
        pid_file: Path to file containing the PID of the daemon
            process to check.

    Returns:
        True if the daemon process is running, otherwise False.
    """
    running_pid = get_daemon_pid_if_running(pid_file)
    return running_pid is not None
daemonize(pid_file, log_file=None, working_directory='/')
Decorator that executes the decorated function as a daemon process.
Use this decorator to easily transform any function into a daemon process.
For example,
import time
from zenml.utils.daemon import daemonize
@daemonize(log_file='/tmp/daemon.log', pid_file='/tmp/daemon.pid')
def sleeping_daemon(period: int) -> None:
print(f"I'm a daemon! I will sleep for {period} seconds.")
time.sleep(period)
print("Done sleeping, flying away.")
sleeping_daemon(period=30)
print("I'm the daemon's parent!.")
time.sleep(10) # just to prove that the daemon is running in parallel
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file |
str |
a file where the PID of the daemon process will be stored. |
required |
log_file |
Optional[str] |
file where stdout and stderr are redirected for the daemon process. If not supplied, the daemon will be silenced (i.e. have its stdout/stderr redirected to /dev/null). |
None |
working_directory |
str |
working directory for the daemon process, defaults to the root directory. |
'/' |
Returns:
Type | Description |
---|---|
Callable[[~F], ~F] |
Decorated function that, when called, will detach from the current process and continue executing in the background, as a daemon process. |
Source code in zenml/utils/daemon.py
def daemonize(
    pid_file: str,
    log_file: Optional[str] = None,
    working_directory: str = "/",
) -> Callable[[F], F]:
    """Decorator that executes the decorated function as a daemon process.

    Use this decorator to easily transform any function into a daemon
    process.

    For example,

    ```python
    import time
    from zenml.utils.daemon import daemonize


    @daemonize(log_file='/tmp/daemon.log', pid_file='/tmp/daemon.pid')
    def sleeping_daemon(period: int) -> None:
        print(f"I'm a daemon! I will sleep for {period} seconds.")
        time.sleep(period)
        print("Done sleeping, flying away.")

    sleeping_daemon(period=30)

    print("I'm the daemon's parent!.")
    time.sleep(10) # just to prove that the daemon is running in parallel
    ```

    Args:
        pid_file: a file where the PID of the daemon process will
            be stored.
        log_file: file where stdout and stderr are redirected for the daemon
            process. If not supplied, the daemon will be silenced (i.e. have
            its stdout/stderr redirected to /dev/null).
        working_directory: working directory for the daemon process,
            defaults to the root directory.

    Returns:
        Decorated function that, when called, will detach from the current
        process and continue executing in the background, as a daemon
        process. It keeps the name and docstring of the wrapped function.
    """
    import functools

    def inner_decorator(_func: F) -> F:
        # `wraps` keeps __name__, __doc__, etc. of the decorated function.
        @functools.wraps(_func)
        def daemon(*args: Any, **kwargs: Any) -> None:
            """Standard daemonization of a process.

            Args:
                *args: Arguments to be passed to the decorated function.
                **kwargs: Keyword arguments to be passed to the decorated
                    function.
            """
            if sys.platform == "win32":
                logger.error(
                    "Daemon functionality is currently not supported on Windows."
                )
            else:
                run_as_daemon(
                    _func,
                    *args,
                    log_file=log_file,
                    pid_file=pid_file,
                    working_directory=working_directory,
                    **kwargs,
                )

        return cast(F, daemon)

    return inner_decorator
get_daemon_pid_if_running(pid_file)
Read and return the PID value from a PID file.
It does this if the daemon process tracked by the PID file is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file |
str |
Path to file containing the PID of the daemon process to check. |
required |
Returns:
Type | Description |
---|---|
Optional[int] |
The PID of the daemon process if it is running, otherwise None. |
Source code in zenml/utils/daemon.py
def get_daemon_pid_if_running(pid_file: str) -> Optional[int]:
    """Read and return the PID value from a PID file.

    It does this if the daemon process tracked by the PID file is running.

    Args:
        pid_file: Path to file containing the PID of the daemon
            process to check.

    Returns:
        The PID of the daemon process if it is running, otherwise None.
        None is also returned when the PID file is missing, unreadable, or
        does not contain an integer.
    """
    try:
        with open(pid_file, "r") as f:
            pid = int(f.read().strip())
    except OSError:
        # IOError is an alias of OSError; FileNotFoundError is a subclass.
        logger.debug(
            f"Daemon PID file '{pid_file}' does not exist or cannot be read."
        )
        return None
    except ValueError:
        # An empty or corrupted PID file must not crash the caller.
        logger.debug(
            f"Daemon PID file '{pid_file}' does not contain a valid PID."
        )
        return None

    if not pid or not psutil.pid_exists(pid):
        logger.debug(f"Daemon with PID '{pid}' is no longer running.")
        return None

    logger.debug(f"Daemon with PID '{pid}' is running.")
    return pid
run_as_daemon(daemon_function, *args, *, pid_file, log_file=None, working_directory='/', **kwargs)
Runs a function as a daemon process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
daemon_function |
~F |
The function to run as a daemon. |
required |
pid_file |
str |
Path to file in which to store the PID of the daemon process. |
required |
log_file |
Optional[str] |
Optional file to which the daemon's stdout/stderr will be redirected. |
None |
working_directory |
str |
Working directory for the daemon process, defaults to the root directory. |
'/' |
args |
Any |
Positional arguments to pass to the daemon function. |
() |
kwargs |
Any |
Keyword arguments to pass to the daemon function. |
{} |
Exceptions:
Type | Description |
---|---|
FileExistsError |
If the PID file already exists. |
Source code in zenml/utils/daemon.py
def run_as_daemon(
    daemon_function: F,
    *args: Any,
    pid_file: str,
    log_file: Optional[str] = None,
    working_directory: str = "/",
    **kwargs: Any,
) -> None:
    """Runs a function as a daemon process.

    The calling process forks once and waits for that child. The child
    detaches into a new session (`setsid`) and forks again. The grandchild
    becomes the daemon that runs `daemon_function`. The calling process
    returns as soon as the intermediate child has exited.

    Args:
        daemon_function: The function to run as a daemon.
        pid_file: Path to file in which to store the PID of the daemon
            process.
        log_file: Optional file to which the daemon's stdout/stderr will be
            redirected.
        working_directory: Working directory for the daemon process,
            defaults to the root directory.
        args: Positional arguments to pass to the daemon function.
        kwargs: Keyword arguments to pass to the daemon function.

    Raises:
        FileExistsError: If the PID file already exists and the daemon
            process it references is still running.
    """
    # convert to absolute paths as we will change working directory later
    if pid_file:
        pid_file = os.path.abspath(pid_file)
        # create the parent directory if necessary; `exist_ok` avoids failing
        # when another process creates it concurrently
        os.makedirs(os.path.dirname(pid_file), exist_ok=True)
    if log_file:
        log_file = os.path.abspath(log_file)

    # check if a PID file exists from a previous daemon
    if pid_file and os.path.exists(pid_file):
        pid = get_daemon_pid_if_running(pid_file)
        if pid:
            raise FileExistsError(
                f"The PID file '{pid_file}' already exists and a daemon "
                f"process with the same PID '{pid}' is already running. "
                f"Please remove the PID file or kill the daemon process "
                f"before starting a new daemon."
            )
        logger.warning(
            f"Removing left over PID file '{pid_file}' from a previous "
            f"daemon process that didn't shut down correctly."
        )
        os.remove(pid_file)

    # first fork
    try:
        pid = os.fork()
    except OSError as e:
        logger.error("Unable to fork (error code: %d)", e.errno)
        sys.exit(1)
    if pid > 0:
        # This is the process that called `run_as_daemon`. Reap exactly the
        # child we just forked, to avoid creating a zombie process: a plain
        # `os.wait()` could reap an unrelated child of the caller instead.
        # Then simply return so the current process can continue what it
        # was doing.
        try:
            os.waitpid(pid, 0)
        except ChildProcessError:
            # the child was already reaped elsewhere (e.g. SIGCHLD ignored)
            pass
        return

    # decouple from parent environment
    os.chdir(working_directory)
    os.setsid()
    os.umask(0o22)

    # second fork
    try:
        pid = os.fork()
    except OSError as e:
        sys.stderr.write(f"Unable to fork (error code: {e.errno})")
        # we use os._exit here to prevent the inherited code from
        # catching the SystemExit exception and doing something else.
        os._exit(1)
    if pid > 0:
        # this is the parent of the future daemon process, kill it
        # so the daemon gets adopted by the init process.
        # we use os._exit here to prevent the inherited code from
        # catching the SystemExit exception and doing something else.
        os._exit(0)

    # redirect standard file descriptors to devnull (or the given logfile)
    devnull_fd = os.open(os.devnull, os.O_RDWR)
    log_fd = (
        # explicit 0o644 so the log file is not created with executable
        # permission bits (os.open defaults to mode 0o777)
        os.open(log_file, os.O_CREAT | os.O_RDWR | os.O_APPEND, 0o644)
        if log_file
        else None
    )
    out_fd = log_fd if log_fd is not None else devnull_fd

    try:
        os.dup2(devnull_fd, sys.stdin.fileno())
    except io.UnsupportedOperation:
        # stdin is not a file descriptor
        pass
    try:
        os.dup2(out_fd, sys.stdout.fileno())
    except io.UnsupportedOperation:
        # stdout is not a file descriptor
        pass
    try:
        os.dup2(out_fd, sys.stderr.fileno())
    except io.UnsupportedOperation:
        # stderr is not a file descriptor
        pass

    if pid_file:
        # write the PID file
        with open(pid_file, "w+") as f:
            f.write(f"{os.getpid()}\n")

    # register actions in case this process exits/gets killed
    def cleanup() -> None:
        """Daemon cleanup: terminate children and remove the PID file."""
        sys.stderr.write("Cleanup: terminating children processes...\n")
        terminate_children()
        if pid_file and os.path.exists(pid_file):
            sys.stderr.write(f"Cleanup: removing PID file {pid_file}...\n")
            os.remove(pid_file)
        sys.stderr.flush()

    def sighndl(signum: int, frame: Optional[types.FrameType]) -> None:
        """Daemon signal handler.

        Args:
            signum: Signal number.
            frame: Frame object.
        """
        sys.stderr.write(f"Handling signal {signum}...\n")
        # NOTE(review): the handler does not exit the process; after cleanup
        # execution resumes in `daemon_function` - confirm this is intended.
        cleanup()

    signal.signal(signal.SIGTERM, sighndl)
    signal.signal(signal.SIGINT, sighndl)
    atexit.register(cleanup)

    # finally run the actual daemon code
    daemon_function(*args, **kwargs)
    sys.exit(0)
stop_daemon(pid_file)
Stops a daemon process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file |
str |
Path to file containing the PID of the daemon process to kill. |
required |
Source code in zenml/utils/daemon.py
def stop_daemon(pid_file: str) -> None:
    """Stops a daemon process.

    Reads the PID from `pid_file` and asks that process to terminate. A
    missing or corrupt PID file, or a PID that no longer exists, only
    results in a warning.

    Args:
        pid_file: Path to file containing the PID of the daemon process to
            kill.
    """
    try:
        with open(pid_file, "r") as f:
            pid = int(f.read().strip())
    except OSError:
        # `IOError` and `FileNotFoundError` are both `OSError` subclasses.
        logger.warning("Daemon PID file '%s' does not exist.", pid_file)
        return
    except ValueError:
        # empty or partially written PID file
        logger.warning(
            "Daemon PID file '%s' does not contain a valid PID.", pid_file
        )
        return

    if not psutil.pid_exists(pid):
        logger.warning("PID from '%s' does not exist.", pid_file)
        return

    try:
        psutil.Process(pid).terminate()
    except psutil.NoSuchProcess:
        # the process exited between the existence check and the signal
        logger.warning("PID from '%s' does not exist.", pid_file)
terminate_children()
Terminate all processes that are children of the currently running process.
Source code in zenml/utils/daemon.py
def terminate_children() -> None:
    """Terminate all processes that are children of the currently running process.

    Only direct children are considered. Each child is first asked to
    terminate. Children still alive after `CHILD_PROCESS_WAIT_TIMEOUT`
    seconds are killed.
    """
    pid = os.getpid()
    try:
        parent = psutil.Process(pid)
    except psutil.Error:
        # could not find parent process id
        return

    children = parent.children(recursive=False)

    for p in children:
        sys.stderr.write(
            f"Terminating child process with PID {p.pid}...\n"
        )
        try:
            p.terminate()
        except psutil.NoSuchProcess:
            # the child exited on its own since it was listed
            pass

    _, alive = psutil.wait_procs(
        children, timeout=CHILD_PROCESS_WAIT_TIMEOUT
    )

    for p in alive:
        sys.stderr.write(f"Killing child process with PID {p.pid}...\n")
        try:
            p.kill()
        except psutil.NoSuchProcess:
            pass

    # give the killed processes a chance to be reaped
    psutil.wait_procs(alive, timeout=CHILD_PROCESS_WAIT_TIMEOUT)
dashboard_utils
Utility class to help with interacting with the dashboard.
get_run_url(run_name, pipeline_id=None)
Computes a dashboard url to directly view the run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name |
str |
Name of the pipeline run. |
required |
pipeline_id |
Optional[uuid.UUID] |
Optional pipeline_id, to be sent when available. |
None |
Returns:
Type | Description |
---|---|
Optional[str] |
A direct URL to the pipeline run details page when connected to a ZenML server, otherwise an empty string. If no matching run exists, the URL points to a runs overview page instead. |
Source code in zenml/utils/dashboard_utils.py
def get_run_url(
    run_name: str, pipeline_id: Optional[UUID] = None
) -> Optional[str]:
    """Computes a dashboard url to directly view the run.

    Args:
        run_name: Name of the pipeline run.
        pipeline_id: Optional pipeline_id, to be sent when available.

    Returns:
        An empty string if the client is not connected to a ZenML server
        (REST store). Otherwise a dashboard URL. It points to the DAG page of
        the first run named `run_name` if such a run exists. If no such run
        exists, it points to the pipeline's runs page (when `pipeline_id` is
        given) or to the all-runs page.
    """
    client = Client()

    # The dashboard is only reachable through a ZenML server (REST store).
    if client.zen_store.type != StoreType.REST:
        return ""

    base_url = client.zen_store.url
    matching_runs = client.depaginate(partial(client.list_runs, name=run_name))

    if pipeline_id:
        path = (
            f"/workspaces/{client.active_workspace.name}"
            f"/pipelines/{str(pipeline_id)}/runs"
        )
    elif matching_runs:
        path = "/runs"
    else:
        path = "/pipelines/all-runs"

    if matching_runs:
        path += f"/{matching_runs[0].id}/dag"

    return base_url + path
print_run_url(run_name, pipeline_id=None)
Logs a dashboard url to directly view the run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name |
str |
Name of the pipeline run. |
required |
pipeline_id |
Optional[uuid.UUID] |
Optional pipeline_id, to be sent when available. |
None |
Source code in zenml/utils/dashboard_utils.py
def print_run_url(run_name: str, pipeline_id: Optional[UUID] = None) -> None:
    """Logs a dashboard url to directly view the run.

    With a ZenML server (REST store), the run's dashboard URL is logged if
    one could be computed. With a local SQL store, a hint to start the
    dashboard via `zenml up` is logged instead.

    Args:
        run_name: Name of the pipeline run.
        pipeline_id: Optional pipeline_id, to be sent when available.
    """
    store_type = Client().zen_store.type

    if store_type == StoreType.REST:
        dashboard_url = get_run_url(run_name, pipeline_id)
        if dashboard_url:
            logger.info(f"Dashboard URL: {dashboard_url}")
    elif store_type == StoreType.SQL:
        # Connected to SQL Store Type, we're local
        logger.info(
            "Pipeline visualization can be seen in the ZenML Dashboard. "
            "Run `zenml up` to see your pipeline!"
        )
deprecation_utils
Deprecation utilities.
deprecate_pydantic_attributes(*attributes)
Utility function for deprecating and migrating pydantic attributes.
Usage: To use this, you can specify it on any pydantic BaseModel subclass like this (all the deprecated attributes need to be non-required):
from pydantic import BaseModel
from typing import Optional
class MyModel(BaseModel):
deprecated: Optional[int] = None
old_name: Optional[str] = None
new_name: str
_deprecation_validator = deprecate_pydantic_attributes(
"deprecated", ("old_name", "new_name")
)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*attributes |
Union[str, Tuple[str, str]] |
List of attributes to deprecate. This is either the name of the attribute to deprecate, or a tuple containing the name of the deprecated attribute and its replacement. |
() |
Returns:
Type | Description |
---|---|
AnyClassMethod |
Pydantic validator class method to be used on BaseModel subclasses to deprecate or migrate attributes. |
Source code in zenml/utils/deprecation_utils.py
def deprecate_pydantic_attributes(
    *attributes: Union[str, Tuple[str, str]]
) -> "AnyClassMethod":
    """Utility function for deprecating and migrating pydantic attributes.

    **Usage**:
    To use this, you can specify it on any pydantic BaseModel subclass like
    this (all the deprecated attributes need to be non-required):

    ```python
    from pydantic import BaseModel
    from typing import Optional

    class MyModel(BaseModel):
        deprecated: Optional[int] = None

        old_name: Optional[str] = None
        new_name: str

        _deprecation_validator = deprecate_pydantic_attributes(
            "deprecated", ("old_name", "new_name")
        )
    ```

    Args:
        *attributes: List of attributes to deprecate. This is either the name
            of the attribute to deprecate, or a tuple containing the name of
            the deprecated attribute and its replacement.

    Returns:
        Pydantic validator class method to be used on BaseModel subclasses
        to deprecate or migrate attributes.
    """

    @root_validator(pre=True, allow_reuse=True)
    def _deprecation_validator(
        cls: Type[BaseModel], values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pydantic validator function for deprecating pydantic attributes.

        For every configured attribute that received a non-`None` value, a
        deprecation warning is emitted. If the attribute has a replacement,
        its value is moved to the replacement attribute, unless the
        replacement was set as well.

        Args:
            cls: The class on which the attributes are defined.
            values: All values passed at model initialization.

        Raises:
            AssertionError: If either the deprecated or replacement attribute
                don't exist.
            TypeError: If the deprecated attribute is a required attribute.
            ValueError: If the deprecated attribute and replacement attribute
                contain different values.

        Returns:
            Input values with potentially migrated values.
        """
        # Attributes that were already logged as deprecated. The set is
        # stored on the class (see `setattr` below) so that later validations
        # don't log the same warning again.
        previous_deprecation_warnings: Set[str] = getattr(
            cls, PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE, set()
        )

        def _warn(message: str, attribute: str) -> None:
            """Logs and issues a warning for a deprecated attribute.

            Args:
                message: The warning message.
                attribute: The name of the attribute.
            """
            if attribute not in previous_deprecation_warnings:
                logger.warning(message)
                previous_deprecation_warnings.add(attribute)

            warnings.warn(
                message,
                DeprecationWarning,
            )

        for attribute in attributes:
            if isinstance(attribute, str):
                deprecated_attribute = attribute
                replacement_attribute = None
            else:
                deprecated_attribute, replacement_attribute = attribute

                assert (
                    replacement_attribute in cls.__fields__
                ), f"Unable to find attribute {replacement_attribute}."

            assert (
                deprecated_attribute in cls.__fields__
            ), f"Unable to find attribute {deprecated_attribute}."

            if cls.__fields__[deprecated_attribute].required:
                raise TypeError(
                    f"Unable to deprecate attribute '{deprecated_attribute}' "
                    f"of class {cls.__name__}. In order to deprecate an "
                    "attribute, it needs to be a non-required attribute. "
                    "To do so, mark the attribute with an `Optional[...]` type "
                    "annotation."
                )

            if values.get(deprecated_attribute, None) is None:
                # deprecated attribute was not used -> nothing to do
                continue

            if replacement_attribute is None:
                _warn(
                    message=f"The attribute `{deprecated_attribute}` of class "
                    f"`{cls.__name__}` will be deprecated soon.",
                    attribute=deprecated_attribute,
                )
                continue

            _warn(
                message=f"The attribute `{deprecated_attribute}` of class "
                f"`{cls.__name__}` will be deprecated soon. Use the "
                f"attribute `{replacement_attribute}` instead.",
                attribute=deprecated_attribute,
            )

            if values.get(replacement_attribute, None) is None:
                logger.debug(
                    "Migrating value of deprecated attribute %s to "
                    "replacement attribute %s.",
                    deprecated_attribute,
                    replacement_attribute,
                )
                values[replacement_attribute] = values.pop(
                    deprecated_attribute
                )
            elif values[deprecated_attribute] != values[replacement_attribute]:
                raise ValueError(
                    "Got different values for deprecated attribute "
                    f"{deprecated_attribute} and replacement "
                    f"attribute {replacement_attribute}."
                )
            else:
                # Both values are identical, no need to do anything
                pass

        setattr(
            cls,
            PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE,
            previous_deprecation_warnings,
        )

        return values

    return _deprecation_validator
dict_utils
Util functions for dictionaries.
recursive_update(original, update)
Recursively updates a dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
original |
Dict[str, Any] |
The dictionary to update. |
required |
update |
Dict[str, Any] |
The dictionary containing the updated values. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The updated dictionary. |
Source code in zenml/utils/dict_utils.py
def recursive_update(
    original: Dict[str, Any], update: Dict[str, Any]
) -> Dict[str, Any]:
    """Recursively updates a dictionary.

    Dictionary values in `update` are merged into the matching dictionary
    values of `original`. A missing or falsy existing value is treated as an
    empty dictionary. All other values simply overwrite. `original` is
    modified in place and also returned.

    Args:
        original: The dictionary to update.
        update: The dictionary containing the updated values.

    Returns:
        The updated dictionary.
    """
    for key, new_value in update.items():
        if not isinstance(new_value, Dict):
            original[key] = new_value
            continue

        existing = original.get(key, None) or {}
        original[key] = (
            recursive_update(existing, new_value)
            if isinstance(existing, Dict)
            else new_value
        )
    return original
remove_none_values(dict_, recursive=False)
Removes all key-value pairs with `None` value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dict_ |
Dict[str, Any] |
The dict from which the key-value pairs should be removed. |
required |
recursive |
bool |
If `True`, will recursively remove `None` values in all child dicts. |
False |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The updated dictionary. |
Source code in zenml/utils/dict_utils.py
def remove_none_values(
    dict_: Dict[str, Any], recursive: bool = False
) -> Dict[str, Any]:
    """Removes all key-value pairs with `None` value.

    A new dictionary is returned; the input is not modified. Nested
    dictionaries that become empty are kept.

    Args:
        dict_: The dict from which the key-value pairs should be removed.
        recursive: If `True`, will recursively remove `None` values in all
            child dicts.

    Returns:
        The updated dictionary.
    """
    cleaned: Dict[str, Any] = {}
    for key, value in dict_.items():
        if value is None:
            continue
        if recursive and isinstance(value, Dict):
            value = remove_none_values(value, recursive=True)
        cleaned[key] = value
    return cleaned
docker_utils
Utility functions relating to Docker.
build_image(image_name, dockerfile, build_context_root=None, dockerignore=None, extra_files=(), **custom_build_options)
Builds a docker image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
The name to use for the built docker image. |
required |
dockerfile |
Union[str, List[str]] |
Path to a dockerfile or a list of strings representing the Dockerfile lines/commands. |
required |
build_context_root |
Optional[str] |
Optional path to a directory that will be sent to the Docker daemon as build context. If left empty, the Docker build context will be empty. |
None |
dockerignore |
Optional[str] |
Optional path to a dockerignore file. If no value is
given, the .dockerignore in the root of the build context will be
used if it exists. Otherwise, all files inside `build_context_root` are included in the build context. |
None |
extra_files |
Sequence[Tuple[str, str]] |
Additional files to include in the build context. The files should be passed as a tuple (filepath_inside_build_context, file_content) and will overwrite existing files in the build context if they share the same path. |
() |
**custom_build_options |
Any |
Additional options that will be passed unmodified to the Docker build call when building the image. You can use this to for example specify build args or a target stage. See https://docker-py.readthedocs.io/en/stable/images.html#docker.models.images.ImageCollection.build for a full list of available options. |
{} |
Source code in zenml/utils/docker_utils.py
def build_image(
    image_name: str,
    dockerfile: Union[str, List[str]],
    build_context_root: Optional[str] = None,
    dockerignore: Optional[str] = None,
    extra_files: Sequence[Tuple[str, str]] = (),
    **custom_build_options: Any,
) -> None:
    """Builds a docker image.

    Args:
        image_name: The name to use for the built docker image.
        dockerfile: Path to a dockerfile or a list of strings representing the
            Dockerfile lines/commands.
        build_context_root: Optional path to a directory that will be sent to
            the Docker daemon as build context. If left empty, the Docker build
            context will be empty.
        dockerignore: Optional path to a dockerignore file. If no value is
            given, the .dockerignore in the root of the build context will be
            used if it exists. Otherwise, all files inside `build_context_root`
            are included in the build context.
        extra_files: Additional files to include in the build context. The
            files should be passed as a tuple
            (filepath_inside_build_context, file_content) and will overwrite
            existing files in the build context if they share the same path.
        **custom_build_options: Additional options that will be passed
            unmodified to the Docker build call when building the image. You
            can use this to for example specify build args or a target stage.
            See https://docker-py.readthedocs.io/en/stable/images.html#docker.models.images.ImageCollection.build
            for a full list of available options.
    """
    if not isinstance(dockerfile, str):
        # inline Dockerfile given as a list of lines
        dockerfile_contents = "\n".join(dockerfile)
    else:
        dockerfile_contents = io_utils.read_file_contents_as_string(dockerfile)
        logger.info("Using Dockerfile `%s`.", os.path.abspath(dockerfile))

    build_context = _create_custom_build_context(
        dockerfile_contents=dockerfile_contents,
        build_context_root=build_context_root,
        dockerignore=dockerignore,
        extra_files=extra_files,
    )

    # Defaults first, so that user-supplied options take precedence:
    # keep intermediate containers to improve caching, always pull parents.
    build_options = {"rm": False, "pull": True}
    build_options.update(custom_build_options)

    logger.info("Building Docker image `%s`.", image_name)
    logger.debug("Docker build options: %s", build_options)
    logger.info("Building the image might take a while...")

    # The low-level API client is used so that the build logs can be streamed.
    api_client = DockerClient.from_env().images.client.api
    output_stream = api_client.build(
        fileobj=build_context,
        custom_context=True,
        tag=image_name,
        **build_options,
    )
    _process_stream(output_stream)

    logger.info("Finished building Docker image `%s`.", image_name)
check_docker()
Checks if Docker is installed and running.
Returns:
Type | Description |
---|---|
bool |
`True` if Docker is installed and running, `False` otherwise. |
Source code in zenml/utils/docker_utils.py
def check_docker() -> bool:
    """Checks if Docker is installed and running.

    Returns:
        `True` if a Docker daemon configured via the environment
        (`DockerClient.from_env`) answers a ping, `False` otherwise.
    """
    try:
        DockerClient.from_env().ping()
    except Exception:
        # Any failure (client unavailable, daemon not running, ...) means
        # Docker is unusable; the details are only logged at debug level.
        logger.debug("Docker is not running.", exc_info=True)
        return False
    return True
get_image_digest(image_name)
Gets the digest of an image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
Name of the image to get the digest for. |
required |
Returns:
Type | Description |
---|---|
Optional[str] |
Returns the repo digest for the given image if there exists exactly one.
If there are zero or multiple repo digests, returns `None`. |
Source code in zenml/utils/docker_utils.py
def get_image_digest(image_name: str) -> Optional[str]:
    """Gets the digest of an image.

    Args:
        image_name: Name of the image to get the digest for.

    Returns:
        Returns the repo digest for the given image if there exists exactly one.
        If there are zero or multiple repo digests, returns `None`.
    """
    docker_client = DockerClient.from_env()
    image = docker_client.images.get(image_name)
    repo_digests = image.attrs["RepoDigests"]

    if len(repo_digests) == 1:
        return cast(str, repo_digests[0])

    # Ambiguous (several registries) or no digest at all (never pushed or
    # pulled): no single digest can be reported.
    logger.debug(
        "Found zero or multiple repo digests for docker image '%s': %s",
        image_name,
        repo_digests,
    )
    return None
is_local_image(image_name)
Returns whether an image is available locally and was built locally rather than pulled from a registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
Name of the image to check. |
required |
Returns:
Type | Description |
---|---|
bool |
`True` if the image is available locally and was not pulled from a registry, `False` otherwise. |
Source code in zenml/utils/docker_utils.py
def is_local_image(image_name: str) -> bool:
    """Returns whether an image exists locally and was not pulled from a registry.

    Args:
        image_name: Name of the image to check.

    Returns:
        `True` if an image with this name is available locally and has no
        repo digest, i.e. it was built locally rather than pulled from (or
        pushed to) a registry. `False` if the image has at least one repo
        digest or no image with this name exists locally.
    """
    docker_client = DockerClient.from_env()
    images = docker_client.images.list(name=image_name)
    if not images:
        # no image with this name found locally
        return False

    # An image with this name is available locally -> check whether it was
    # pulled from a repo or built locally (in which case the list of repo
    # digests is empty). An image known to several registries has several
    # digests and is still not local.
    image = docker_client.images.get(image_name)
    return not image.attrs["RepoDigests"]
push_image(image_name)
Pushes an image to a container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
The full name (including a tag) of the image to push. |
required |
Returns:
Type | Description |
---|---|
str |
The Docker repository digest of the pushed image. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If fetching the repository digest of the image failed. |
Source code in zenml/utils/docker_utils.py
def push_image(image_name: str) -> str:
    """Pushes an image to a container registry.

    Args:
        image_name: The full name (including a tag) of the image to push.

    Returns:
        The Docker repository digest of the pushed image, formatted as
        `<image name without tag>@<digest>`.

    Raises:
        RuntimeError: If fetching the repository digest of the image failed.
    """
    logger.info("Pushing Docker image `%s`.", image_name)
    docker_client = DockerClient.from_env()
    output_stream = docker_client.images.push(image_name, stream=True)
    aux_info = _process_stream(output_stream)
    logger.info("Finished pushing Docker image.")

    # Strip the tag. Only a colon in the last path component introduces a
    # tag: a colon followed by a `/` belongs to a registry port
    # (e.g. `localhost:5000/image`), and an untagged name has nothing to strip.
    name, separator, tag = image_name.rpartition(":")
    if separator and "/" not in tag:
        image_name_without_tag = name
    else:
        image_name_without_tag = image_name

    # search the push output from the end for the reported digest
    for info in reversed(aux_info):
        try:
            repo_digest = info["Digest"]
        except KeyError:
            continue
        return f"{image_name_without_tag}@{repo_digest}"

    raise RuntimeError(
        f"Unable to find repo digest after pushing image {image_name}."
    )
tag_image(image_name, target)
Tags an image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name |
str |
The name of the image to tag. |
required |
target |
str |
The full target name including a tag. |
required |
Source code in zenml/utils/docker_utils.py
def tag_image(image_name: str, target: str) -> None:
    """Tags an image.

    Looks up the local image `image_name` and adds `target` as an
    additional name for it.

    Args:
        image_name: The name of the image to tag.
        target: The full target name including a tag.
    """
    DockerClient.from_env().images.get(image_name).tag(target)
enum_utils
Util functions for enums.
StrEnum (str, Enum)
Base enum type for string enum values.
Source code in zenml/utils/enum_utils.py
class StrEnum(str, Enum):
    """Base enum type for string enum values.

    `str()` of a member yields its value rather than `ClassName.MEMBER`.
    """

    def __str__(self) -> str:
        """Converts the member to its string value.

        Returns:
            The value of the enum member.
        """
        return self.value  # type: ignore

    @classmethod
    def names(cls) -> List[str]:
        """Lists the names of all members, in definition order.

        Returns:
            The member names (aliases excluded).
        """
        return [member.name for member in cls]

    @classmethod
    def values(cls) -> List[str]:
        """Lists the values of all members, in definition order.

        Returns:
            The member values (aliases excluded).
        """
        return [member.value for member in cls]
filesync_model
Filesync utils for ZenML.
FileSyncModel (BaseModel)
pydantic-model
Pydantic model synchronized with a configuration file.
Use this class as a base Pydantic model that is automatically synchronized with a configuration file on disk.
This class overrides the `__setattr__` and `__getattribute__` magic methods to ensure that the FileSyncModel instance acts as an in-memory cache of the information stored in the associated configuration file.
Source code in zenml/utils/filesync_model.py
class FileSyncModel(BaseModel):
    """Pydantic model synchronized with a configuration file.

    Use this class as a base Pydantic model that is automatically synchronized
    with a configuration file on disk.

    This class overrides the `__setattr__` and `__getattribute__` magic
    methods to ensure that the FileSyncModel instance acts as an in-memory
    cache of the information stored in the associated configuration file.
    Setting a public attribute writes the file immediately. Reading a public
    attribute first reloads the file if its modification time changed.
    """

    # Path of the backing configuration file.
    _config_file: str
    # Modification time of the configuration file when this instance last
    # wrote or loaded it; `None` until the first write in `__init__`.
    _config_file_timestamp: Optional[float]

    def __init__(self, config_file: str, **kwargs: Any) -> None:
        """Create a FileSyncModel instance synchronized with a configuration file on disk.

        Args:
            config_file: configuration file path. If the file exists, the model
                will be initialized with the values from the file.
            **kwargs: additional keyword arguments to pass to the Pydantic model
                constructor. If supplied, these values will override those
                loaded from the configuration file.
        """
        config_dict = {}
        if fileio.exists(config_file):
            # Fall back to an empty dict for a file without content (the
            # YAML reader presumably returns `None` for it - verify).
            config_dict = yaml_utils.read_yaml(config_file) or {}

        # private attributes are not persisted by `__setattr__`
        self._config_file = config_file
        self._config_file_timestamp = None

        # explicit keyword arguments override values loaded from the file
        config_dict.update(kwargs)
        super(FileSyncModel, self).__init__(**config_dict)

        # write the configuration file to disk, to reflect new attributes
        # and schema changes
        self.write_config()

    def __setattr__(self, key: str, value: Any) -> None:
        """Sets an attribute on the model and persists it in the configuration file.

        Args:
            key: attribute name.
            value: attribute value.
        """
        super(FileSyncModel, self).__setattr__(key, value)
        if key.startswith("_"):
            # private attributes are kept in memory only
            return
        self.write_config()

    def __getattribute__(self, key: str) -> Any:
        """Gets an attribute value for a specific key.

        Args:
            key: attribute name.

        Returns:
            attribute value.
        """
        if not key.startswith("_") and key in self.__dict__:
            # refresh public values from disk if the file changed
            self.load_config()
        return super(FileSyncModel, self).__getattribute__(key)

    def write_config(self) -> None:
        """Writes the model to the configuration file."""
        config_dict = json.loads(self.json())
        yaml_utils.write_yaml(self._config_file, config_dict)
        self._config_file_timestamp = os.path.getmtime(self._config_file)

    def load_config(self) -> None:
        """Loads the model from the configuration file on disk."""
        if not fileio.exists(self._config_file):
            return

        # don't reload the configuration if the file hasn't
        # been updated since the last load
        file_timestamp = os.path.getmtime(self._config_file)
        if file_timestamp == self._config_file_timestamp:
            return

        if self._config_file_timestamp is not None:
            logger.info(f"Reloading configuration file {self._config_file}")

        # refresh the model from the configuration file values; a file
        # without content leaves the current values untouched
        config_dict = yaml_utils.read_yaml(self._config_file) or {}
        for key, value in config_dict.items():
            # bypass our own `__setattr__` so reloading doesn't write back
            super(FileSyncModel, self).__setattr__(key, value)

        self._config_file_timestamp = file_timestamp

    class Config:
        """Pydantic configuration class."""

        # all attributes with leading underscore are private and therefore
        # are mutable and not included in serialization
        underscore_attrs_are_private = True
Config
Pydantic configuration class.
Source code in zenml/utils/filesync_model.py
class Config:
    """Pydantic configuration for the file-synchronized model."""

    # Leading-underscore attributes are private: they may be mutated freely
    # and are left out of serialization.
    underscore_attrs_are_private = True
__getattribute__(self, key)
special
Gets an attribute value for a specific key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
attribute name. |
required |
Returns:
Type | Description |
---|---|
Any |
attribute value. |
Source code in zenml/utils/filesync_model.py
def __getattribute__(self, key: str) -> Any:
    """Gets an attribute value for a specific key.

    When the attribute is public (no leading underscore) and stored in the
    instance `__dict__`, the model is first refreshed via `load_config`.

    Args:
        key: attribute name.

    Returns:
        attribute value.
    """
    is_private = key.startswith("_")
    if not is_private and key in self.__dict__:
        self.load_config()
    return super(FileSyncModel, self).__getattribute__(key)
__init__(self, config_file, **kwargs)
special
Create a FileSyncModel instance synchronized with a configuration file on disk.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config_file |
str |
configuration file path. If the file exists, the model will be initialized with the values from the file. |
required |
**kwargs |
Any |
additional keyword arguments to pass to the Pydantic model constructor. If supplied, these values will override those loaded from the configuration file. |
{} |
Source code in zenml/utils/filesync_model.py
def __init__(self, config_file: str, **kwargs: Any) -> None:
    """Create a FileSyncModel instance synchronized with a configuration file on disk.

    Args:
        config_file: configuration file path. If the file exists, the model
            will be initialized with the values from the file.
        **kwargs: additional keyword arguments to pass to the Pydantic model
            constructor. If supplied, these values will override those
            loaded from the configuration file.
    """
    config_dict = {}
    if fileio.exists(config_file):
        # Fall back to an empty dict for a file without content (the YAML
        # reader presumably returns `None` for it - verify).
        config_dict = yaml_utils.read_yaml(config_file) or {}

    # private attributes are not persisted by `__setattr__`
    self._config_file = config_file
    self._config_file_timestamp = None

    # explicit keyword arguments override values loaded from the file
    config_dict.update(kwargs)
    super(FileSyncModel, self).__init__(**config_dict)

    # write the configuration file to disk, to reflect new attributes
    # and schema changes
    self.write_config()
__setattr__(self, key, value)
special
Sets an attribute on the model and persists it in the configuration file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
attribute name. |
required |
value |
Any |
attribute value. |
required |
Source code in zenml/utils/filesync_model.py
def __setattr__(self, key: str, value: Any) -> None:
    """Sets an attribute on the model and persists it in the configuration file.

    Attributes whose name starts with `_` are only set in memory.

    Args:
        key: attribute name.
        value: attribute value.
    """
    super(FileSyncModel, self).__setattr__(key, value)
    if not key.startswith("_"):
        self.write_config()
load_config(self)
Loads the model from the configuration file on disk.
Source code in zenml/utils/filesync_model.py
def load_config(self) -> None:
    """Loads the model from the configuration file on disk.

    Does nothing if the file does not exist, or if its modification time
    matches the one recorded at the last write/load by this instance.
    """
    if not fileio.exists(self._config_file):
        return

    # don't reload the configuration if the file hasn't
    # been updated since the last load
    file_timestamp = os.path.getmtime(self._config_file)
    if file_timestamp == self._config_file_timestamp:
        return

    if self._config_file_timestamp is not None:
        logger.info(f"Reloading configuration file {self._config_file}")

    # refresh the model from the configuration file values; a file without
    # content (presumably read as `None`) leaves the current values untouched
    config_dict = yaml_utils.read_yaml(self._config_file) or {}
    for key, value in config_dict.items():
        # bypass the model's own `__setattr__` so reloading doesn't write back
        super(FileSyncModel, self).__setattr__(key, value)

    self._config_file_timestamp = file_timestamp
write_config(self)
Writes the model to the configuration file.
Source code in zenml/utils/filesync_model.py
def write_config(self) -> None:
    """Writes the model to the configuration file.

    Also records the file's modification time, which `load_config` compares
    against to skip reloading an unchanged file.
    """
    # Round-trip through JSON, presumably to turn non-JSON-native values
    # (UUIDs, datetimes, ...) into plain YAML-serializable types.
    serialized = json.loads(self.json())
    yaml_utils.write_yaml(self._config_file, serialized)
    self._config_file_timestamp = os.path.getmtime(self._config_file)
io_utils
Various utility functions for the io module.
copy_dir(source_dir, destination_dir, overwrite=False)
Copies dir from source to destination.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source_dir |
str |
Path to copy from. |
required |
destination_dir |
str |
Path to copy to. |
required |
overwrite |
bool |
Boolean. If false, function throws an error before overwrite. |
False |
Source code in zenml/utils/io_utils.py
def copy_dir(
    source_dir: str, destination_dir: str, overwrite: bool = False
) -> None:
    """Copies dir from source to destination.

    The copy is recursive. If the destination directory lies somewhere
    inside the source directory, it is skipped while walking the source, so
    the copy never recurses into its own output.

    Args:
        source_dir: Path to copy from.
        destination_dir: Path to copy to.
        overwrite: Boolean. If false, function throws an error before overwrite.
    """
    # Remember the top-level destination. The self-copy guard must compare
    # against it at every depth. Comparing against the per-level destination
    # subdirectory only catches a destination that is a direct child of
    # `source_dir`.
    destination_root = destination_dir

    def _copy_contents(src: str, dst: str) -> None:
        """Recursively copies the entries of `src` into `dst`."""
        for entry in listdir(src):
            entry_name = convert_to_str(entry)
            source_path = os.path.join(src, entry_name)
            destination_path = os.path.join(dst, entry_name)
            if isdir(source_path):
                if source_path == destination_root:
                    # the destination is inside the source; skip it to
                    # avoid an infinite loop
                    continue
                _copy_contents(source_path, destination_path)
            else:
                create_dir_recursive_if_not_exists(
                    os.path.dirname(destination_path)
                )
                copy(str(source_path), str(destination_path), overwrite)

    _copy_contents(source_dir, destination_dir)
create_dir_if_not_exists(dir_path)
Creates directory if it does not exist.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dir_path |
str |
Local path in filesystem. |
required |
Source code in zenml/utils/io_utils.py
def create_dir_if_not_exists(dir_path: str) -> None:
    """Creates directory if it does not exist.

    Uses `mkdir` rather than `makedirs`, so parent directories are presumably
    expected to exist already; use `create_dir_recursive_if_not_exists`
    otherwise.

    Args:
        dir_path: Local path in filesystem.
    """
    if isdir(dir_path):
        return
    mkdir(dir_path)
create_dir_recursive_if_not_exists(dir_path)
Creates directory recursively if it does not exist.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dir_path |
str |
Local path in filesystem. |
required |
Source code in zenml/utils/io_utils.py
def create_dir_recursive_if_not_exists(dir_path: str) -> None:
    """Creates directory recursively if it does not exist.

    Missing intermediate directories are created via `makedirs`; nothing
    happens when `dir_path` is already a directory.

    Args:
        dir_path: Local path in filesystem.
    """
    if isdir(dir_path):
        return
    makedirs(dir_path)
create_file_if_not_exists(file_path, file_contents='{}')
Creates file if it does not exist.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Local path in filesystem. |
required |
file_contents |
str |
Contents of file. |
'{}' |
Source code in zenml/utils/io_utils.py
def create_file_if_not_exists(
    file_path: str, file_contents: str = "{}"
) -> None:
    """Creates file if it does not exist.

    An existing file is left untouched. Otherwise its parent directory is
    created (recursively) and the file is written with `file_contents`.

    Args:
        file_path: Local path in filesystem.
        file_contents: Contents of file.
    """
    target = Path(file_path)
    if exists(file_path):
        return
    create_dir_recursive_if_not_exists(str(target.parent))
    with open(str(target), "w") as handle:
        handle.write(file_contents)
find_files(dir_path, pattern)
Find files in a directory that match pattern.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dir_path |
PathType |
The path to directory. |
required |
pattern |
str |
pattern like *.png. |
required |
Yields:
Type | Description |
---|---|
Iterable[str] |
All matching filenames in the directory. |
Source code in zenml/utils/io_utils.py
def find_files(dir_path: "PathType", pattern: str) -> Iterable[str]:
    """Find files in a directory that match pattern.

    The directory tree is traversed recursively with `walk`, and each file's
    base name is matched against `pattern` using `fnmatch` shell-style
    wildcards.

    Args:
        dir_path: The path to directory.
        pattern: pattern like *.png.

    Yields:
        All matching filenames in the directory.
    """
    for root, _, files in walk(dir_path):
        for raw_name in files:
            name = convert_to_str(raw_name)
            if not fnmatch.fnmatch(name, pattern):
                continue
            yield os.path.join(convert_to_str(root), name)
get_global_config_directory()
Gets the global config directory for ZenML.
Returns:
Type | Description |
---|---|
str |
The global config directory for ZenML. |
Source code in zenml/utils/io_utils.py
def get_global_config_directory() -> str:
    """Gets the global config directory for ZenML.

    If the `ENV_ZENML_CONFIG_PATH` environment variable is set to a non-empty
    value, that path is resolved to an absolute path and returned. Otherwise
    the application directory reported by `click.get_app_dir` is used.

    Returns:
        The global config directory for ZenML.
    """
    override = os.getenv(ENV_ZENML_CONFIG_PATH)
    if not override:
        return click.get_app_dir(APP_NAME)
    return str(Path(override).resolve())
get_grandparent(dir_path)
Get grandparent of dir.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dir_path |
str |
The path to directory. |
required |
Returns:
Type | Description |
---|---|
str |
The input path's parent's parent. |
Exceptions:
Type | Description |
---|---|
ValueError |
If dir_path does not exist. |
Source code in zenml/utils/io_utils.py
def get_grandparent(dir_path: str) -> str:
    """Get grandparent of dir.

    Note that `Path.stem` drops the last suffix, so a grandparent directory
    named `pkg.v1` is reported as `pkg`.

    Args:
        dir_path: The path to directory.

    Returns:
        The stem of the input path's parent's parent.

    Raises:
        ValueError: If dir_path does not exist.
    """
    if os.path.exists(dir_path):
        return Path(dir_path).parent.parent.stem
    raise ValueError(f"Path '{dir_path}' does not exist.")
get_parent(dir_path)
Get parent of dir.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dir_path |
str |
The path to directory. |
required |
Returns:
Type | Description |
---|---|
str |
Parent (stem) of the dir as a string. |
Exceptions:
Type | Description |
---|---|
ValueError |
If dir_path does not exist. |
Source code in zenml/utils/io_utils.py
def get_parent(dir_path: str) -> str:
    """Get parent of dir.

    Note that `Path.stem` drops the last suffix, so a parent directory named
    `pkg.v1` is reported as `pkg`.

    Args:
        dir_path: The path to directory.

    Returns:
        Parent (stem) of the dir as a string.

    Raises:
        ValueError: If dir_path does not exist.
    """
    if os.path.exists(dir_path):
        return Path(dir_path).parent.stem
    raise ValueError(f"Path '{dir_path}' does not exist.")
is_remote(path)
Returns True if the path points to a remote filesystem (based on its prefix).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
Any path as a string. |
required |
Returns:
Type | Description |
---|---|
bool |
True if remote path, else False. |
Source code in zenml/utils/io_utils.py
def is_remote(path: str) -> bool:
    """Returns True if path exists remotely.

    Only the path's prefix is inspected (against `REMOTE_FS_PREFIX`); the
    path is not checked for existence.

    Args:
        path: Any path as a string.

    Returns:
        True if remote path, else False.
    """
    return path.startswith(tuple(REMOTE_FS_PREFIX))
is_root(path)
Returns True if the path has no parent in the local filesystem.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
Local path in filesystem. |
required |
Returns:
Type | Description |
---|---|
bool |
True if root, else False. |
Source code in zenml/utils/io_utils.py
def is_root(path: str) -> bool:
    """Returns true if path has no parent in local filesystem.

    A path is considered a root when taking its parent yields the same path.
    This is a purely lexical check; note that `"."` also satisfies it.

    Args:
        path: Local path in filesystem.

    Returns:
        True if root, else False.
    """
    as_path = Path(path)
    return as_path == as_path.parent
move(source, destination, overwrite=False)
Moves dir or file from source to destination. Can be used to rename.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
Local path to move from. |
required |
destination |
str |
Local path to move to. |
required |
overwrite |
bool |
boolean, if false, then throws an error before overwrite. |
False |
Source code in zenml/utils/io_utils.py
def move(source: str, destination: str, overwrite: bool = False) -> None:
    """Moves dir or file from source to destination. Can be used to rename.

    This is a thin wrapper that forwards all arguments to `rename`.

    Args:
        source: Local path to move from.
        destination: Local path to move to.
        overwrite: boolean, if false, then throws an error before overwrite.
    """
    rename(source, destination, overwrite)
read_file_contents_as_string(file_path)
Reads contents of file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to file. |
required |
Returns:
Type | Description |
---|---|
str |
Contents of file. |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
If file does not exist. |
Source code in zenml/utils/io_utils.py
def read_file_contents_as_string(file_path: str) -> str:
    """Reads contents of file.

    Args:
        file_path: Path to file.

    Returns:
        Contents of file.

    Raises:
        FileNotFoundError: If file does not exist.
    """
    if exists(file_path):
        with open(file_path) as handle:
            contents = handle.read()
        return contents  # type: ignore[no-any-return]
    raise FileNotFoundError(f"{file_path} does not exist!")
resolve_relative_path(path)
Takes relative path and resolves it absolutely.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
Local path in filesystem. |
required |
Returns:
Type | Description |
---|---|
str |
Resolved path. |
Source code in zenml/utils/io_utils.py
def resolve_relative_path(path: str) -> str:
    """Takes relative path and resolves it absolutely.

    Remote paths (as identified by `is_remote`) are returned unchanged;
    local paths are resolved with `Path.resolve`.

    Args:
        path: Local path in filesystem.

    Returns:
        Resolved path.
    """
    return path if is_remote(path) else str(Path(path).resolve())
write_file_contents_as_string(file_path, content)
Writes contents of file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to file. |
required |
content |
str |
Contents of file. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If content is not of type str. |
Source code in zenml/utils/io_utils.py
def write_file_contents_as_string(file_path: str, content: str) -> None:
    """Writes contents of file.

    Any existing file at `file_path` is overwritten. The type check runs
    before the file is opened, so a rejected call leaves the file untouched.

    Args:
        file_path: Path to file.
        content: Contents of file.

    Raises:
        ValueError: If content is not of type str.
    """
    if isinstance(content, str):
        with open(file_path, "w") as handle:
            handle.write(content)
        return
    raise ValueError(f"Content must be of type str, got {type(content)}")
materializer_utils
Util functions for models and materializers.
load_artifact(artifact)
Load the given artifact into memory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact |
ArtifactResponseModel |
The artifact to load. |
required |
Returns:
Type | Description |
---|---|
Any |
The artifact loaded into memory. |
Source code in zenml/utils/materializer_utils.py
def load_artifact(artifact: "ArtifactResponseModel") -> Any:
    """Load the given artifact into memory.

    The artifact's materializer, data type and URI are forwarded to
    `_load_artifact`.

    Args:
        artifact: The artifact to load.

    Returns:
        The artifact loaded into memory.
    """
    materializer_source = artifact.materializer
    data_type_source = artifact.data_type
    artifact_uri = artifact.uri
    return _load_artifact(
        materializer=materializer_source,
        data_type=data_type_source,
        uri=artifact_uri,
    )
load_model_from_metadata(model_uri)
Load a ZenML model artifact using the metadata stored in a YAML file.
This function is used to load information from a YAML file that was created by the save_model_metadata function. The information in the YAML file is used to load the model into memory in the inference environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_uri |
str |
the artifact to extract the metadata from. |
required |
Returns:
Type | Description |
---|---|
Any |
The ML model object loaded into memory. |
Source code in zenml/utils/materializer_utils.py
def load_model_from_metadata(model_uri: str) -> Any:
    """Load a zenml model artifact using the metadata stored in a YAML file.

    This function is used to load information from a YAML file that was created
    by the save_model_metadata function. The information in the YAML file is
    used to load the model into memory in the inference environment.

    If the loaded model is a `torch.nn.Module` (and torch is installed), it is
    switched to evaluation mode before being returned.

    Args:
        model_uri: the artifact to extract the metadata from.

    Returns:
        The ML model object loaded into memory.

    Raises:
        FileNotFoundError: If the metadata file does not exist in `model_uri`.
    """
    metadata_path = os.path.join(model_uri, MODEL_METADATA_YAML_FILE_NAME)
    # Pass the path straight to `read_yaml`. Opening a handle only to read
    # back its `.name` was redundant I/O, and non-local file objects are not
    # guaranteed to expose the path as `.name`.
    if not fileio.exists(metadata_path):
        raise FileNotFoundError(f"{metadata_path} does not exist.")
    metadata = read_yaml(metadata_path)
    data_type = metadata[METADATA_DATATYPE]
    materializer = metadata[METADATA_MATERIALIZER]
    model = _load_artifact(
        materializer=materializer, data_type=data_type, uri=model_uri
    )

    # Switch to eval mode if the model is a torch model; torch is optional.
    try:
        import torch.nn as nn
    except ImportError:
        return model
    if isinstance(model, nn.Module):
        model.eval()
    return model
save_model_metadata(model_artifact)
Save a zenml model artifact metadata to a YAML file.
This function is used to extract and save information from a zenml model artifact such as the model type and materializer. The extracted information will be the key to loading the model into memory in the inference environment.
datatype: the model type. This is the path to the model class. materializer: the materializer class. This is the path to the materializer class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_artifact |
ArtifactResponseModel |
the artifact to extract the metadata from. |
required |
Returns:
Type | Description |
---|---|
str |
The path to the temporary file where the model metadata is saved |
Source code in zenml/utils/materializer_utils.py
def save_model_metadata(model_artifact: "ArtifactResponseModel") -> str:
    """Save a zenml model artifact metadata to a YAML file.

    The artifact's data type (path to the model class) and materializer (path
    to the materializer class) are written to a new temporary `.yaml` file.
    Together they are the key to loading the model into memory in the
    inference environment.

    The temporary file is created with `delete=False`, so it persists after
    this function returns; removing it is up to the caller.

    Args:
        model_artifact: the artifact to extract the metadata from.

    Returns:
        The path to the temporary file where the model metadata is saved
    """
    metadata = {
        METADATA_DATATYPE: model_artifact.data_type,
        METADATA_MATERIALIZER: model_artifact.materializer,
    }
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".yaml", delete=False
    ) as tmp_file:
        write_yaml(tmp_file.name, metadata)
    return tmp_file.name
networking_utils
Utility functions for networking.
find_available_port()
Finds a local random unoccupied TCP port.
Returns:
Type | Description |
---|---|
int |
A random unoccupied TCP port. |
Source code in zenml/utils/networking_utils.py
def find_available_port() -> int:
    """Finds a local random unoccupied TCP port.

    Binding to port 0 on 127.0.0.1 lets the operating system pick a free
    port. The socket is closed again before returning, so another process
    could in principle claim the port before the caller uses it.

    Returns:
        A random unoccupied TCP port.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        chosen_port = probe.getsockname()[1]
    return cast(int, chosen_port)
port_available(port, address='127.0.0.1')
Checks if a local port is available.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
port |
int |
TCP port number |
required |
address |
str |
IP address on the local machine |
'127.0.0.1' |
Returns:
Type | Description |
---|---|
bool |
True if the port is available, otherwise False |
Source code in zenml/utils/networking_utils.py
def port_available(port: int, address: str = "127.0.0.1") -> bool:
    """Checks if a local port is available.

    The check succeeds if a TCP socket with `SO_REUSEADDR` set (plus
    `SO_REUSEPORT` on non-Windows platforms) can be bound to `address:port`.

    Args:
        port: TCP port number
        address: IP address on the local machine

    Returns:
        True if the port is available, otherwise False
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if sys.platform != "win32":
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            else:
                # Windows has no SO_REUSEPORT. The explicit branch is kept
                # only so that mypy does not complain about missing code paths.
                pass
            probe.bind((address, port))
    except socket.error as err:
        logger.debug("Port %d unavailable on %s: %s", port, address, err)
        return False
    return True
port_is_open(hostname, port)
Check if a TCP port is open on a remote host.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hostname |
str |
hostname of the remote machine |
required |
port |
int |
TCP port number |
required |
Returns:
Type | Description |
---|---|
bool |
True if the port is open, False otherwise |
Source code in zenml/utils/networking_utils.py
def port_is_open(hostname: str, port: int) -> bool:
    """Check if a TCP port is open on a remote host.

    A TCP connection is attempted with `connect_ex`; the port counts as open
    when that call reports success (0). No explicit socket timeout is set.

    Args:
        hostname: hostname of the remote machine
        port: TCP port number

    Returns:
        True if the port is open, False otherwise
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            return probe.connect_ex((hostname, port)) == 0
    except socket.error as err:
        logger.debug(
            f"Error checking TCP port {port} on host {hostname}: {str(err)}"
        )
        return False
replace_internal_hostname_with_localhost(hostname)
Replaces an internal Docker or K3D hostname with localhost.
Localhost URLs that are directly accessible on the host machine are not
accessible from within a Docker or K3D container running on that same
machine, but there are special hostnames featured by both Docker
(`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
access host services from within the containers.
Use this method to replace one of these special hostnames with localhost if used outside a container or in a container where special hostnames are not available.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hostname |
str |
The hostname to replace. |
required |
Returns:
Type | Description |
---|---|
str |
The original or replaced hostname. |
Source code in zenml/utils/networking_utils.py
def replace_internal_hostname_with_localhost(hostname: str) -> str:
    """Replaces an internal Docker or K3D hostname with localhost.

    Localhost URLs that are directly accessible on the host machine are not
    accessible from within a Docker or K3D container running on that same
    machine, but there are special hostnames featured by both Docker
    (`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
    access host services from within the containers.

    Use this method to replace one of these special hostnames with localhost
    if used outside a container or in a container where special hostnames are
    not available.

    Inside a container, the first special hostname that resolves is returned
    (which may differ from the one passed in). If none resolves, or outside a
    container, `"127.0.0.1"` is returned.

    Args:
        hostname: The hostname to replace.

    Returns:
        The original or replaced hostname.
    """
    special_hostnames = ("host.docker.internal", "host.k3d.internal")
    if hostname not in special_hostnames:
        return hostname

    if Environment.in_container():
        for candidate in special_hostnames:
            try:
                socket.gethostbyname(candidate)
            except socket.gaierror:
                # Not resolvable in this container, try the next one.
                continue
            if candidate != hostname:
                logger.debug(
                    f"Replacing internal hostname {hostname} with "
                    f"{candidate}"
                )
            return candidate

    logger.debug(f"Replacing internal hostname {hostname} with localhost.")
    return "127.0.0.1"
replace_localhost_with_internal_hostname(url)
Replaces the localhost with an internal Docker or K3D hostname in a given URL.
Localhost URLs that are directly accessible on the host machine are not
accessible from within a Docker or K3D container running on that same
machine, but there are special hostnames featured by both Docker
(`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
access host services from within the containers.
Use this method to attempt to replace `localhost` in a URL with one of these
special hostnames, if they are available inside a container.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
The URL to update. |
required |
Returns:
Type | Description |
---|---|
str |
The updated URL. |
Source code in zenml/utils/networking_utils.py
def replace_localhost_with_internal_hostname(url: str) -> str:
    """Replaces the localhost with an internal Docker or K3D hostname in a given URL.

    Localhost URLs that are directly accessible on the host machine are not
    accessible from within a Docker or K3D container running on that same
    machine, but there are special hostnames featured by both Docker
    (`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
    access host services from within the containers.

    Use this method to attempt to replace `localhost` in a URL with one of these
    special hostnames, if they are available inside a container.

    The URL is returned unchanged in any of these cases: outside a container,
    when the host is neither `localhost` nor `127.0.0.1`, or when neither
    special hostname resolves.

    Args:
        url: The URL to update.

    Returns:
        The updated URL.
    """
    if not Environment.in_container():
        return url

    parsed_url = urlparse(url)
    if parsed_url.hostname not in ("localhost", "127.0.0.1"):
        return url

    for internal_hostname in ("host.docker.internal", "host.k3d.internal"):
        try:
            socket.gethostbyname(internal_hostname)
        except socket.gaierror:
            continue
        # Swap only the host part inside the netloc; any port is preserved.
        new_netloc = parsed_url.netloc.replace(
            parsed_url.hostname, internal_hostname
        )
        logger.debug(
            f"Replacing localhost with {internal_hostname} in URL: "
            f"{url}"
        )
        return parsed_url._replace(netloc=new_netloc).geturl()

    return url
scan_for_available_port(start=8000, stop=65535)
Scan the local machine for an available TCP port in the given range.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
start |
int |
the beginning of the port range value to scan |
8000 |
stop |
int |
the (inclusive) end of the port range value to scan |
65535 |
Returns:
Type | Description |
---|---|
Optional[int] |
The first available port in the given range, or None if no available port is found. |
Source code in zenml/utils/networking_utils.py
def scan_for_available_port(
    start: int = SCAN_PORT_RANGE[0], stop: int = SCAN_PORT_RANGE[1]
) -> Optional[int]:
    """Scan the local machine for an available TCP port in the given range.

    Ports are probed in ascending order with `port_available`, and the first
    free one is returned.

    Args:
        start: the beginning of the port range value to scan
        stop: the (inclusive) end of the port range value to scan

    Returns:
        The first available port in the given range, or None if no available
        port is found.
    """
    first_free = next(
        (candidate for candidate in range(start, stop + 1)
         if port_available(candidate)),
        None,
    )
    if first_free is None:
        logger.debug(
            "No free TCP ports found in the range %d - %d",
            start,
            stop,
        )
    return first_free
pipeline_docker_image_builder
Implementation of Docker image builds to run ZenML pipelines.
PipelineDockerImageBuilder
Builds Docker images to run a ZenML pipeline.
Source code in zenml/utils/pipeline_docker_image_builder.py
class PipelineDockerImageBuilder:
    """Builds Docker images to run a ZenML pipeline."""

    def build_docker_image(
        self,
        docker_settings: "DockerSettings",
        tag: str,
        stack: "Stack",
        entrypoint: Optional[str] = None,
        extra_files: Optional[Dict[str, str]] = None,
    ) -> str:
        """Builds (and optionally pushes) a Docker image to run a pipeline.

        Use the image name returned by this method whenever you need to uniquely
        reference the pushed image in order to pull or run it.

        Depending on the settings, up to two images are built:

        1. If `docker_settings.dockerfile` is set, an image is built from that
           Dockerfile. It is either the final image or the parent image of
           step 2.
        2. If anything must be installed, copied or configured on top of the
           parent image, an image from a ZenML-generated Dockerfile is built.

        If neither applies, a custom parent image is tagged with the target
        name (and pushed if the stack has a container registry).

        Args:
            docker_settings: The settings for the image build.
            tag: The tag to use for the image.
            stack: The stack on which the pipeline will be deployed.
            entrypoint: Entrypoint to use for the final image. If left empty,
                no entrypoint will be included in the image.
            extra_files: Extra files to add to the build context. Keys are the
                path inside the build context, values are either the file
                content or a file path.

        Returns:
            The Docker image repo digest or local name, depending on whether
            the image was pushed or is just stored locally.

        Raises:
            RuntimeError: If the stack does not contain an image builder.
            ValueError: If no Dockerfile and/or custom parent image is
                specified and the Docker configuration doesn't require an
                image build.
        """
        if docker_settings.skip_build:
            assert (
                docker_settings.parent_image
            )  # checked via validator already

            # Should we tag this here and push it to the container registry of
            # the stack to make sure it's always accessible when running the
            # pipeline?
            return docker_settings.parent_image

        image_builder = stack.image_builder
        if not image_builder:
            raise RuntimeError(
                "Unable to build Docker images without an image builder in the "
                f"stack `{stack.name}`."
            )

        container_registry = stack.container_registry
        build_context_class = image_builder.build_context_class
        target_image_name = self._get_target_image_name(
            docker_settings=docker_settings,
            tag=tag,
            container_registry=container_registry,
        )

        # A ZenML build on top of the parent image is needed as soon as
        # anything has to be installed, copied or configured in the image.
        requires_zenml_build = any(
            [
                docker_settings.requirements,
                docker_settings.required_integrations,
                docker_settings.replicate_local_python_environment,
                docker_settings.install_stack_requirements,
                docker_settings.apt_packages,
                docker_settings.environment,
                docker_settings.copy_files,
                docker_settings.copy_global_config,
                entrypoint,
                extra_files,
            ]
        )

        # Fall back to the default ZenML parent image if the settings don't
        # specify a custom parent image.
        parent_image = (
            docker_settings.parent_image or DEFAULT_DOCKER_PARENT_IMAGE
        )

        if docker_settings.dockerfile:
            if parent_image != DEFAULT_DOCKER_PARENT_IMAGE:
                logger.warning(
                    "You've specified both a Dockerfile and a custom parent "
                    "image, ignoring the parent image."
                )

            # Push if this is the final image, or if the image builder is not
            # local (presumably because a remote builder cannot access a
            # locally stored parent image).
            push = (
                not image_builder.config.is_local or not requires_zenml_build
            )

            if requires_zenml_build:
                # We will build an additional image on top of this one later
                # to include user files and/or install requirements. The image
                # we build now will be used as the parent for the next build.
                user_image_name = (
                    f"{docker_settings.target_repository}:"
                    f"{tag}-intermediate-build"
                )
                if push and container_registry:
                    user_image_name = (
                        f"{container_registry.config.uri}/{user_image_name}"
                    )

                parent_image = user_image_name
            else:
                # The image we'll build from the custom Dockerfile will be
                # used directly, so we tag it with the requested target name.
                user_image_name = target_image_name

            build_context = build_context_class(
                root=docker_settings.build_context_root
            )
            build_context.add_file(
                source=docker_settings.dockerfile, destination="Dockerfile"
            )
            logger.info("Building Docker image `%s`.", user_image_name)
            image_name_or_digest = image_builder.build(
                image_name=user_image_name,
                build_context=build_context,
                docker_build_options=docker_settings.build_options,
                container_registry=container_registry if push else None,
            )

        elif not requires_zenml_build:
            if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
                raise ValueError(
                    "Unable to run a ZenML pipeline with the given Docker "
                    "settings: No Dockerfile or custom parent image "
                    "specified and no files will be copied or requirements "
                    "installed."
                )
            else:
                # The parent image will be used directly to run the pipeline and
                # needs to be tagged/pushed
                docker_utils.tag_image(parent_image, target=target_image_name)
                if container_registry:
                    image_name_or_digest = container_registry.push_image(
                        target_image_name
                    )
                else:
                    image_name_or_digest = target_image_name

        if requires_zenml_build:
            logger.info("Building Docker image `%s`.", target_image_name)
            # Leave the build context empty if we don't want to copy any files
            build_context_root = (
                source_utils.get_source_root_path()
                if docker_settings.copy_files
                else None
            )
            build_context = build_context_class(
                root=build_context_root,
                dockerignore_file=docker_settings.dockerignore,
            )

            requirements_file_names = self._add_requirements_files(
                docker_settings=docker_settings,
                build_context=build_context,
                stack=stack,
            )

            # Work on a copy: extending the settings' own list in place with
            # `+=` would mutate `docker_settings` and make the stack's apt
            # packages accumulate across repeated builds.
            apt_packages = list(docker_settings.apt_packages)
            if docker_settings.install_stack_requirements:
                apt_packages += stack.apt_packages

            if apt_packages:
                logger.info(
                    "Including apt packages: %s",
                    ", ".join(f"`{p}`" for p in apt_packages),
                )

            if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
                # The default parent image is static and doesn't require a pull
                # each time
                pull_parent_image = False
            elif docker_settings.dockerfile and not container_registry:
                # We built a custom parent image and there was no container
                # registry in the stack to push to, this is a local image
                pull_parent_image = False
            else:
                # If the image is local, we don't need to pull it. Otherwise
                # we play it safe and always pull in case the user pushed a new
                # image for the given name and tag
                pull_parent_image = not docker_utils.is_local_image(
                    parent_image
                )

            build_options = {"pull": pull_parent_image, "rm": False}

            dockerfile = self._generate_zenml_pipeline_dockerfile(
                parent_image=parent_image,
                docker_settings=docker_settings,
                requirements_files=requirements_file_names,
                apt_packages=apt_packages,
                entrypoint=entrypoint,
            )
            build_context.add_file(destination="Dockerfile", source=dockerfile)

            if docker_settings.copy_global_config:
                with tempfile.TemporaryDirectory() as tmpdir:
                    GlobalConfiguration().copy_configuration(
                        tmpdir,
                        load_config_path=PurePosixPath(
                            DOCKER_IMAGE_ZENML_CONFIG_PATH
                        ),
                    )
                    # NOTE(review): `tmpdir` is deleted when this `with` block
                    # exits, before `image_builder.build` runs. This assumes
                    # `add_directory` reads the files eagerly — confirm in
                    # `BuildContext`.
                    build_context.add_directory(
                        source=tmpdir,
                        destination=DOCKER_IMAGE_ZENML_CONFIG_DIR,
                    )

            if extra_files:
                for destination, source in extra_files.items():
                    build_context.add_file(
                        destination=destination, source=source
                    )

            image_name_or_digest = image_builder.build(
                image_name=target_image_name,
                build_context=build_context,
                docker_build_options=build_options,
                container_registry=container_registry,
            )

        return image_name_or_digest

    @staticmethod
    def _get_target_image_name(
        docker_settings: "DockerSettings",
        tag: str,
        container_registry: Optional["BaseContainerRegistry"] = None,
    ) -> str:
        """Returns the target image name.

        If a container registry is given, the image name will include the
        registry URI.

        Args:
            docker_settings: The settings for the image build.
            tag: The tag to use for the image.
            container_registry: Optional container registry to which this
                image will be pushed.

        Returns:
            The docker image name.
        """
        target_image_name = f"{docker_settings.target_repository}:{tag}"
        if container_registry:
            target_image_name = (
                f"{container_registry.config.uri}/{target_image_name}"
            )

        return target_image_name

    @classmethod
    def _add_requirements_files(
        cls,
        docker_settings: DockerSettings,
        build_context: "BuildContext",
        stack: "Stack",
    ) -> List[str]:
        """Adds requirements files to the build context.

        Args:
            docker_settings: Docker settings that specifies which
                requirements to install.
            build_context: Build context to add the requirements files to.
            stack: The stack on which the pipeline will run.

        Returns:
            Name of the requirements files in the build context.
            The files will be in the following order:
            - Packages installed in the local Python environment
            - User-defined requirements
            - Requirements defined by user-defined and/or stack integrations
        """
        requirements_file_names = []
        requirements_files = cls._gather_requirements_files(
            docker_settings=docker_settings, stack=stack
        )

        for filename, file_content in requirements_files:
            build_context.add_file(source=file_content, destination=filename)
            requirements_file_names.append(filename)

        return requirements_file_names

    @staticmethod
    def _gather_requirements_files(
        docker_settings: DockerSettings, stack: "Stack", log: bool = True
    ) -> List[Tuple[str, str]]:
        """Gathers and/or generates pip requirements files.

        Args:
            docker_settings: Docker settings that specifies which
                requirements to install.
            stack: The stack on which the pipeline will run.
            log: If `True`, will log the requirements.

        Raises:
            RuntimeError: If the command to export the local python packages
                failed.

        Returns:
            List of tuples (filename, file_content) of all requirements files.
            The files will be in the following order:
            - Packages installed in the local Python environment
            - User-defined requirements
            - Requirements defined by user-defined and/or stack integrations
        """
        requirements_files = []

        # Generate requirements file for the local environment if configured
        if docker_settings.replicate_local_python_environment:
            if isinstance(
                docker_settings.replicate_local_python_environment,
                PythonEnvironmentExportMethod,
            ):
                command = (
                    docker_settings.replicate_local_python_environment.command
                )
            else:
                command = " ".join(
                    docker_settings.replicate_local_python_environment
                )

            # NOTE(review): `command` comes from the user's own settings and
            # is executed through a shell; do not feed it untrusted input.
            try:
                local_requirements = subprocess.check_output(
                    command, shell=True
                ).decode()
            except subprocess.CalledProcessError as e:
                raise RuntimeError(
                    "Unable to export local python packages."
                ) from e

            requirements_files.append(
                (".zenml_local_requirements", local_requirements)
            )
            if log:
                logger.info(
                    "- Including python packages from local environment"
                )

        # Generate/Read requirements file for user-defined requirements
        if isinstance(docker_settings.requirements, str):
            # A string is treated as the path of a requirements file.
            user_requirements = io_utils.read_file_contents_as_string(
                docker_settings.requirements
            )
            if log:
                logger.info(
                    "- Including user-defined requirements from file `%s`",
                    os.path.abspath(docker_settings.requirements),
                )
        elif isinstance(docker_settings.requirements, list):
            user_requirements = "\n".join(docker_settings.requirements)
            if log:
                logger.info(
                    "- Including user-defined requirements: %s",
                    ", ".join(f"`{r}`" for r in docker_settings.requirements),
                )
        else:
            user_requirements = None

        if user_requirements:
            requirements_files.append(
                (".zenml_user_requirements", user_requirements)
            )

        # Generate requirements file for all required integrations
        integration_requirements = set(
            itertools.chain.from_iterable(
                integration_registry.select_integration_requirements(
                    integration
                )
                for integration in docker_settings.required_integrations
            )
        )

        if docker_settings.install_stack_requirements:
            integration_requirements.update(stack.requirements())

        if integration_requirements:
            # Sorted so the generated file is deterministic.
            integration_requirements_list = sorted(integration_requirements)
            integration_requirements_file = "\n".join(
                integration_requirements_list
            )
            requirements_files.append(
                (
                    ".zenml_integration_requirements",
                    integration_requirements_file,
                )
            )
            if log:
                logger.info(
                    "- Including integration requirements: %s",
                    ", ".join(f"`{r}`" for r in integration_requirements_list),
                )

        return requirements_files

    @staticmethod
    def _generate_zenml_pipeline_dockerfile(
        parent_image: str,
        docker_settings: DockerSettings,
        requirements_files: Sequence[str] = (),
        apt_packages: Sequence[str] = (),
        entrypoint: Optional[str] = None,
    ) -> str:
        """Generates a Dockerfile.

        Args:
            parent_image: The image to use as parent for the Dockerfile.
            docker_settings: Docker settings for this image build.
            requirements_files: Paths of requirements files to install.
            apt_packages: APT packages to install.
            entrypoint: The default entrypoint command that gets executed when
                running a container of an image created by this Dockerfile.

        Returns:
            The generated Dockerfile.
        """
        lines = [f"FROM {parent_image}", f"WORKDIR {DOCKER_IMAGE_WORKDIR}"]

        if apt_packages:
            apt_packages_str = " ".join(f"'{p}'" for p in apt_packages)
            lines.append(
                "RUN apt-get update && apt-get install -y "
                f"--no-install-recommends {apt_packages_str}"
            )

        for file in requirements_files:
            lines.append(f"COPY {file} .")
            lines.append(
                f"RUN pip install --default-timeout=60 --no-cache-dir -r {file}"
            )

        lines.append(f"ENV {ENV_ZENML_ENABLE_REPO_INIT_WARNINGS}=False")
        if docker_settings.copy_global_config:
            lines.append(
                f"ENV {ENV_ZENML_CONFIG_PATH}={DOCKER_IMAGE_ZENML_CONFIG_PATH}"
            )

        for key, value in docker_settings.environment.items():
            lines.append(f"ENV {key.upper()}={value}")

        if docker_settings.copy_files:
            lines.append("COPY . .")
        elif docker_settings.copy_global_config:
            lines.append(f"COPY {DOCKER_IMAGE_ZENML_CONFIG_DIR} .")

        lines.append("RUN chmod -R a+rw .")

        if docker_settings.user:
            # Change ownership while still running as the parent image's
            # user (presumably root). After `USER` switches to a non-root
            # user, `chown` on root-owned files is not permitted and the
            # build would fail.
            lines.append(f"RUN chown -R {docker_settings.user} .")
            lines.append(f"USER {docker_settings.user}")

        if entrypoint:
            lines.append(f"ENTRYPOINT {entrypoint}")

        return "\n".join(lines)
build_docker_image(self, docker_settings, tag, stack, entrypoint=None, extra_files=None)
Builds (and optionally pushes) a Docker image to run a pipeline.
Use the image name returned by this method whenever you need to uniquely reference the pushed image in order to pull or run it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
docker_settings |
DockerSettings |
The settings for the image build. |
required |
tag |
str |
The tag to use for the image. |
required |
stack |
Stack |
The stack on which the pipeline will be deployed. |
required |
entrypoint |
Optional[str] |
Entrypoint to use for the final image. If left empty, no entrypoint will be included in the image. |
None |
extra_files |
Optional[Dict[str, str]] |
Extra files to add to the build context. Keys are the path inside the build context, values are either the file content or a file path. |
None |
Returns:
Type | Description |
---|---|
str |
The Docker image repo digest or local name, depending on whether the image was pushed or is just stored locally. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the stack does not contain an image builder. |
ValueError |
If no Dockerfile and/or custom parent image is specified and the Docker configuration doesn't require an image build. |
Source code in zenml/utils/pipeline_docker_image_builder.py
def build_docker_image(
    self,
    docker_settings: "DockerSettings",
    tag: str,
    stack: "Stack",
    entrypoint: Optional[str] = None,
    extra_files: Optional[Dict[str, str]] = None,
) -> str:
    """Builds (and optionally pushes) a Docker image to run a pipeline.

    Use the image name returned by this method whenever you need to uniquely
    reference the pushed image in order to pull or run it.

    Args:
        docker_settings: The settings for the image build.
        tag: The tag to use for the image.
        stack: The stack on which the pipeline will be deployed.
        entrypoint: Entrypoint to use for the final image. If left empty,
            no entrypoint will be included in the image.
        extra_files: Extra files to add to the build context. Keys are the
            path inside the build context, values are either the file
            content or a file path.

    Returns:
        The Docker image repo digest or local name, depending on whether
        the image was pushed or is just stored locally.

    Raises:
        RuntimeError: If the stack does not contain an image builder.
        ValueError: If no Dockerfile and/or custom parent image is
            specified and the Docker configuration doesn't require an
            image build.
    """
    if docker_settings.skip_build:
        assert (
            docker_settings.parent_image
        )  # checked via validator already

        # Should we tag this here and push it to the container registry of
        # the stack to make sure it's always accessible when running the
        # pipeline?
        return docker_settings.parent_image

    image_builder = stack.image_builder
    if not image_builder:
        raise RuntimeError(
            "Unable to build Docker images without an image builder in the "
            f"stack `{stack.name}`."
        )

    container_registry = stack.container_registry
    build_context_class = image_builder.build_context_class
    target_image_name = self._get_target_image_name(
        docker_settings=docker_settings,
        tag=tag,
        container_registry=container_registry,
    )

    # A ZenML build on top of the parent (or user-built) image is needed
    # whenever anything has to be installed, copied or configured.
    requires_zenml_build = any(
        [
            docker_settings.requirements,
            docker_settings.required_integrations,
            docker_settings.replicate_local_python_environment,
            docker_settings.install_stack_requirements,
            docker_settings.apt_packages,
            docker_settings.environment,
            docker_settings.copy_files,
            docker_settings.copy_global_config,
            entrypoint,
            extra_files,
        ]
    )

    # Fall back to ZenML's default parent image if the Docker settings don't
    # configure a custom parent image.
    parent_image = (
        docker_settings.parent_image or DEFAULT_DOCKER_PARENT_IMAGE
    )

    if docker_settings.dockerfile:
        if parent_image != DEFAULT_DOCKER_PARENT_IMAGE:
            logger.warning(
                "You've specified both a Dockerfile and a custom parent "
                "image, ignoring the parent image."
            )

        # Only skip pushing when a local image builder produces an
        # intermediate image that the ZenML build below consumes locally.
        push = (
            not image_builder.config.is_local or not requires_zenml_build
        )

        if requires_zenml_build:
            # We will build an additional image on top of this one later
            # to include user files and/or install requirements. The image
            # we build now will be used as the parent for the next build.
            user_image_name = (
                f"{docker_settings.target_repository}:"
                f"{tag}-intermediate-build"
            )
            if push and container_registry:
                user_image_name = (
                    f"{container_registry.config.uri}/{user_image_name}"
                )

            parent_image = user_image_name
        else:
            # The image we'll build from the custom Dockerfile will be
            # used directly, so we tag it with the requested target name.
            user_image_name = target_image_name

        build_context = build_context_class(
            root=docker_settings.build_context_root
        )
        build_context.add_file(
            source=docker_settings.dockerfile, destination="Dockerfile"
        )
        logger.info("Building Docker image `%s`.", user_image_name)
        image_name_or_digest = image_builder.build(
            image_name=user_image_name,
            build_context=build_context,
            docker_build_options=docker_settings.build_options,
            container_registry=container_registry if push else None,
        )

    elif not requires_zenml_build:
        if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
            raise ValueError(
                "Unable to run a ZenML pipeline with the given Docker "
                "settings: No Dockerfile or custom parent image "
                "specified and no files will be copied or requirements "
                "installed."
            )
        else:
            # The parent image will be used directly to run the pipeline and
            # needs to be tagged/pushed
            docker_utils.tag_image(parent_image, target=target_image_name)
            if container_registry:
                image_name_or_digest = container_registry.push_image(
                    target_image_name
                )
            else:
                image_name_or_digest = target_image_name

    if requires_zenml_build:
        logger.info("Building Docker image `%s`.", target_image_name)
        # Leave the build context empty if we don't want to copy any files
        build_context_root = (
            source_utils.get_source_root_path()
            if docker_settings.copy_files
            else None
        )
        build_context = build_context_class(
            root=build_context_root,
            dockerignore_file=docker_settings.dockerignore,
        )

        requirements_file_names = self._add_requirements_files(
            docker_settings=docker_settings,
            build_context=build_context,
            stack=stack,
        )

        # Work on a copy: `+=` on the settings' own list would extend
        # `docker_settings.apt_packages` in place, mutating the caller's
        # settings and accumulating stack packages if they are reused.
        apt_packages = list(docker_settings.apt_packages)
        if docker_settings.install_stack_requirements:
            apt_packages += stack.apt_packages

        if apt_packages:
            logger.info(
                "Including apt packages: %s",
                ", ".join(f"`{p}`" for p in apt_packages),
            )

        if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
            # The default parent image is static and doesn't require a pull
            # each time
            pull_parent_image = False
        elif docker_settings.dockerfile and not container_registry:
            # We built a custom parent image and there was no container
            # registry in the stack to push to, this is a local image
            pull_parent_image = False
        else:
            # If the image is local, we don't need to pull it. Otherwise
            # we play it safe and always pull in case the user pushed a new
            # image for the given name and tag
            pull_parent_image = not docker_utils.is_local_image(
                parent_image
            )

        build_options = {"pull": pull_parent_image, "rm": False}

        dockerfile = self._generate_zenml_pipeline_dockerfile(
            parent_image=parent_image,
            docker_settings=docker_settings,
            requirements_files=requirements_file_names,
            apt_packages=apt_packages,
            entrypoint=entrypoint,
        )
        build_context.add_file(destination="Dockerfile", source=dockerfile)

        if docker_settings.copy_global_config:
            with tempfile.TemporaryDirectory() as tmpdir:
                GlobalConfiguration().copy_configuration(
                    tmpdir,
                    load_config_path=PurePosixPath(
                        DOCKER_IMAGE_ZENML_CONFIG_PATH
                    ),
                )
                build_context.add_directory(
                    source=tmpdir,
                    destination=DOCKER_IMAGE_ZENML_CONFIG_DIR,
                )

        if extra_files:
            for destination, source in extra_files.items():
                build_context.add_file(
                    destination=destination, source=source
                )

        image_name_or_digest = image_builder.build(
            image_name=target_image_name,
            build_context=build_context,
            docker_build_options=build_options,
            container_registry=container_registry,
        )

    return image_name_or_digest
proxy_utils
Proxy design pattern utils.
make_proxy_class(interface, attribute)
Proxy class decorator.
Use this decorator to transform the decorated class into a proxy that
forwards all calls defined in the `interface` interface to the `attribute`
class attribute that implements the same interface.
This decorator is useful in cases where you need to have a base class that acts as a proxy or facade for one or more other classes. Both the decorated class and the class attribute must inherit from the same ABC interface for this to work. Only regular methods are supported, not class methods or attributes.
Example: Let's say you have an interface called `BodyBuilder`, a base class
called `FatBob` and another class called `BigJim`. `BigJim` implements the
`BodyBuilder` interface, but `FatBob` does not. And let's say you want
`FatBob` to look as if it implements the `BodyBuilder` interface, but in
fact it just forwards all calls to `BigJim`. You could do this:
from abc import ABC, abstractmethod
class BodyBuilder(ABC):
@abstractmethod
def build_body(self):
pass
class BigJim(BodyBuilder):
def build_body(self):
print("Looks fit!")
class FatBob(BodyBuilder):
def __init__(self):
self.big_jim = BigJim()
def build_body(self):
self.big_jim.build_body()
fat_bob = FatBob()
fat_bob.build_body()
But this leads to a lot of boilerplate code with bigger interfaces and makes everything harder to maintain. This is where the proxy class decorator comes in handy. Here's how to use it:
from zenml.utils.proxy_utils import make_proxy_class
from typing import Optional
@make_proxy_class(BodyBuilder, "big_jim")
class FatBob(BodyBuilder):
big_jim: Optional[BodyBuilder] = None
def __init__(self):
self.big_jim = BigJim()
fat_bob = FatBob()
fat_bob.build_body()
This is the same as implementing FatBob to call BigJim explicitly, but it has the advantage that you don't need to write a lot of boilerplate code or modify the FatBob class every time you change something in the BodyBuilder interface.
This proxy decorator also allows extending classes dynamically at runtime:
if the `attribute` class attribute is set to None, the proxy class
will assume that the interface is not implemented by the class and will
raise a NotImplementedError:
@make_proxy_class(BodyBuilder, "big_jim")
class FatBob(BodyBuilder):
big_jim: Optional[BodyBuilder] = None
def __init__(self):
self.big_jim = None
fat_bob = FatBob()
# Raises NotImplementedError, class not extended yet:
fat_bob.build_body()
fat_bob.big_jim = BigJim()
# Now it works:
fat_bob.build_body()
Parameters:
Name | Type | Description | Default |
---|---|---|---|
interface |
Type[abc.ABC] |
The interface to implement. |
required |
attribute |
str |
The attribute of the base class to forward calls to. |
required |
Returns:
Type | Description |
---|---|
Callable[[~C], ~C] |
The proxy class. |
Source code in zenml/utils/proxy_utils.py
def make_proxy_class(interface: Type[ABC], attribute: str) -> Callable[[C], C]:
    """Proxy class decorator.

    Use this decorator to transform the decorated class into a proxy that
    forwards all calls defined in the `interface` interface to the `attribute`
    class attribute that implements the same interface.

    This decorator is useful in cases where you need to have a base class that
    acts as a proxy or facade for one or more other classes. Both the decorated
    class and the class attribute must inherit from the same ABC interface for
    this to work. Only the interface's abstract methods are proxied, and only
    regular methods are supported, not class methods or attributes.

    Example: Let's say you have an interface called `BodyBuilder`, a base class
    called `FatBob` and another class called `BigJim`. `BigJim` implements the
    `BodyBuilder` interface, but `FatBob` does not. And let's say you want
    `FatBob` to look as if it implements the `BodyBuilder` interface, but in
    fact it just forwards all calls to `BigJim`. You could do this:

    ```python
    from abc import ABC, abstractmethod

    class BodyBuilder(ABC):

        @abstractmethod
        def build_body(self):
            pass

    class BigJim(BodyBuilder):

        def build_body(self):
            print("Looks fit!")

    class FatBob(BodyBuilder):

        def __init__(self):
            self.big_jim = BigJim()

        def build_body(self):
            self.big_jim.build_body()

    fat_bob = FatBob()
    fat_bob.build_body()
    ```

    But this leads to a lot of boilerplate code with bigger interfaces and
    makes everything harder to maintain. This is where the proxy class
    decorator comes in handy. Here's how to use it:

    ```python
    from zenml.utils.proxy_utils import make_proxy_class
    from typing import Optional

    @make_proxy_class(BodyBuilder, "big_jim")
    class FatBob(BodyBuilder):
        big_jim: Optional[BodyBuilder] = None

        def __init__(self):
            self.big_jim = BigJim()

    fat_bob = FatBob()
    fat_bob.build_body()
    ```

    This is the same as implementing FatBob to call BigJim explicitly, but it
    has the advantage that you don't need to write a lot of boilerplate code
    or modify the FatBob class every time you change something in the
    BodyBuilder interface.

    This proxy decorator also allows extending classes dynamically at runtime:
    if the `attribute` class attribute is set to None, the proxy class
    will assume that the interface is not implemented by the class and will
    raise a NotImplementedError:

    ```python
    @make_proxy_class(BodyBuilder, "big_jim")
    class FatBob(BodyBuilder):
        big_jim: Optional[BodyBuilder] = None

        def __init__(self):
            self.big_jim = None

    fat_bob = FatBob()

    # Raises NotImplementedError, class not extended yet:
    fat_bob.build_body()

    fat_bob.big_jim = BigJim()
    # Now it works:
    fat_bob.build_body()
    ```

    Args:
        interface: The interface to implement.
        attribute: The attribute of the base class to forward calls to.

    Returns:
        The proxy class.
    """

    def make_proxy_method(cls: C, _method: F) -> F:
        """Proxy method decorator.

        Used to transform a method into a proxy that forwards all calls to the
        given class attribute.

        Args:
            cls: The class to use as the base.
            _method: The method to replace.

        Returns:
            The proxy method.
        """

        @wraps(_method)
        def proxy_method(*args: Any, **kw: Any) -> Any:
            """Proxy method.

            Args:
                *args: The arguments to pass to the method.
                **kw: The keyword arguments to pass to the method.

            Returns:
                The return value of the proxied method.

            Raises:
                TypeError: If the class does not have the attribute specified
                    in the decorator or if the attribute does not implement
                    the specified interface.
                NotImplementedError: If the attribute specified in the
                    decorator is None, i.e. the interface is not implemented.
            """
            # The proxy replaces an instance method, so the first positional
            # argument is the instance of the decorated class.
            self = args[0]
            if not hasattr(self, attribute):
                raise TypeError(
                    f"Class '{cls.__name__}' does not have a '{attribute}' "
                    f"as specified in the 'make_proxy_class' decorator."
                )
            proxied_obj = getattr(self, attribute)
            if proxied_obj is None:
                raise NotImplementedError(
                    f"This '{cls.__name__}' instance does not implement the "
                    f"'{interface.__name__}' interface."
                )
            if not isinstance(proxied_obj, interface):
                raise TypeError(
                    f"Interface '{interface.__name__}' must be implemented by "
                    f"the '{cls.__name__}' '{attribute}' attribute."
                )
            # Forward the call without `self`; the proxied object's bound
            # method supplies its own instance.
            proxied_method = getattr(proxied_obj, _method.__name__)
            return proxied_method(*args[1:], **kw)

        return cast(F, proxy_method)

    def _inner_decorator(_cls: C) -> C:
        """Inner proxy class decorator.

        Args:
            _cls: The class to decorate.

        Returns:
            The decorated class.

        Raises:
            TypeError: If the decorated class does not implement the specified
                interface.
        """
        if not issubclass(_cls, interface):
            raise TypeError(
                f"Interface '{interface.__name__}' must be implemented by "
                f"the '{_cls.__name__}' class."
            )

        for method_name in interface.__abstractmethods__:
            original_method = getattr(_cls, method_name)
            method_proxy = make_proxy_method(_cls, original_method)
            # Make sure the proxy method is not considered abstract:
            # `functools.wraps` copies the wrapped function's `__dict__`,
            # which carries `__isabstractmethod__ = True` when the class
            # inherited the interface's abstract method.
            method_proxy.__isabstractmethod__ = False
            setattr(_cls, method_name, method_proxy)

        # Remove the abstract methods in the interface from the decorated
        # class. `__abstractmethods__` was computed when the class was
        # created, so it must be updated manually after replacing the methods
        # or the class could not be instantiated.
        _cls.__abstractmethods__ = frozenset(
            method_name
            for method_name in _cls.__abstractmethods__
            if method_name not in interface.__abstractmethods__
        )

        return cast(C, _cls)

    return _inner_decorator
pydantic_utils
Utilities for pydantic models.
TemplateGenerator
Class to generate templates for pydantic models or classes.
Source code in zenml/utils/pydantic_utils.py
class TemplateGenerator:
    """Class to generate templates for pydantic models or classes."""

    def __init__(
        self, instance_or_class: Union[BaseModel, Type[BaseModel]]
    ) -> None:
        """Initializes the template generator.

        Args:
            instance_or_class: The pydantic model or model class for which to
                generate a template.
        """
        self.instance_or_class = instance_or_class

    def run(self) -> Dict[str, Any]:
        """Generates the template.

        Returns:
            The template dictionary.
        """
        if isinstance(self.instance_or_class, BaseModel):
            template = self._generate_template_for_model(
                self.instance_or_class
            )
        else:
            template = self._generate_template_for_model_class(
                self.instance_or_class
            )

        # Convert to json in an intermediate step so we can leverage Pydantic's
        # encoder to support types like UUID and datetime
        json_string = json.dumps(template, default=pydantic_encoder)
        return cast(Dict[str, Any], json.loads(json_string))

    def _generate_template_for_model(self, model: BaseModel) -> Dict[str, Any]:
        """Generates a template for a pydantic model.

        Starts from the class-level template and overrides the entries for
        all fields that were explicitly set on the instance.

        Args:
            model: The model for which to generate the template.

        Returns:
            The model template.
        """
        template = self._generate_template_for_model_class(model.__class__)

        for name in model.__fields_set__:
            value = getattr(model, name)
            template[name] = self._generate_template_for_value(value)

        return template

    def _generate_template_for_model_class(
        self,
        model_class: Type[BaseModel],
    ) -> Dict[str, Any]:
        """Generates a template for a pydantic model class.

        Nested model fields are expanded recursively; all other fields are
        represented by pydantic's display string for their type.

        Args:
            model_class: The model class for which to generate the template.

        Returns:
            The model class template.
        """
        template: Dict[str, Any] = {}

        for name, field in model_class.__fields__.items():
            # NOTE(review): pydantic v1 unwraps `Optional[X]` during field
            # analysis so that `outer_type_` is `X` (with `allow_none=True`),
            # which means optional nested models are handled by this check
            # as well. Confirm when upgrading pydantic.
            if self._is_model_class(field.outer_type_):
                template[name] = self._generate_template_for_model_class(
                    field.outer_type_
                )
            else:
                template[name] = field._type_display()

        return template

    def _generate_template_for_value(self, value: Any) -> Any:
        """Generates a template for an arbitrary value.

        Dictionaries and sequences are processed element-wise, pydantic models
        are expanded, and any other value is returned unchanged.

        Args:
            value: The value for which to generate the template.

        Returns:
            The value template.
        """
        if isinstance(value, dict):
            return {
                k: self._generate_template_for_value(v)
                for k, v in value.items()
            }
        elif sequence_like(value):
            return [self._generate_template_for_value(v) for v in value]
        elif isinstance(value, BaseModel):
            return self._generate_template_for_model(value)
        else:
            return value

    @staticmethod
    def _is_model_class(value: Any) -> bool:
        """Checks if the given value is a pydantic model class.

        Args:
            value: The value to check.

        Returns:
            If the value is a pydantic model class.
        """
        return isinstance(value, type) and issubclass(value, BaseModel)
__init__(self, instance_or_class)
special
Initializes the template generator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
instance_or_class |
Union[pydantic.main.BaseModel, Type[pydantic.main.BaseModel]] |
The pydantic model or model class for which to generate a template. |
required |
Source code in zenml/utils/pydantic_utils.py
def __init__(
    self, instance_or_class: Union[BaseModel, Type[BaseModel]]
) -> None:
    """Store the model (or model class) to build a template for.

    Args:
        instance_or_class: Either a pydantic model instance or a pydantic
            model class; the template is derived from it later.
    """
    self.instance_or_class = instance_or_class
run(self)
Generates the template.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The template dictionary. |
Source code in zenml/utils/pydantic_utils.py
def run(self) -> Dict[str, Any]:
    """Build the template for the configured model or model class.

    Returns:
        The template as a dictionary of JSON-compatible values.
    """
    source = self.instance_or_class
    if isinstance(source, BaseModel):
        raw_template = self._generate_template_for_model(source)
    else:
        raw_template = self._generate_template_for_model_class(source)

    # Round-trip through JSON so pydantic's encoder turns values such as
    # UUIDs and datetimes into plain JSON types.
    serialized = json.dumps(raw_template, default=pydantic_encoder)
    return cast(Dict[str, Any], json.loads(serialized))
YAMLSerializationMixin (BaseModel)
pydantic-model
Class to serialize/deserialize pydantic models to/from YAML.
Source code in zenml/utils/pydantic_utils.py
class YAMLSerializationMixin(BaseModel):
    """Mixin adding YAML (de)serialization to pydantic models."""

    def yaml(self, sort_keys: bool = False, **kwargs: Any) -> str:
        """Serialize the model to a YAML string.

        Args:
            sort_keys: Whether the keys in the YAML output should be sorted.
            **kwargs: Extra keyword arguments for pydantic's `json(...)`.

        Returns:
            The YAML representation of the model.
        """
        # Go through pydantic's JSON serialization first so values are
        # converted to plain types before being dumped as YAML.
        model_json = self.json(**kwargs, sort_keys=sort_keys)
        plain_data = json.loads(model_json)
        return cast(str, yaml.dump(plain_data, sort_keys=sort_keys))

    @classmethod
    def from_yaml(cls: Type[M], path: str) -> M:
        """Load a model instance from a YAML file.

        Args:
            path: Location of the YAML file to read.

        Returns:
            The model instance parsed from the file contents.
        """
        return cls.parse_obj(yaml_utils.read_yaml(path))
from_yaml(path)
classmethod
Creates an instance from a YAML file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
Path to a YAML file. |
required |
Returns:
Type | Description |
---|---|
~M |
The model instance. |
Source code in zenml/utils/pydantic_utils.py
@classmethod
def from_yaml(cls: Type[M], path: str) -> M:
    """Load a model instance from a YAML file.

    Args:
        path: Location of the YAML file to read.

    Returns:
        The model instance parsed from the file contents.
    """
    return cls.parse_obj(yaml_utils.read_yaml(path))
yaml(self, sort_keys=False, **kwargs)
YAML string representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sort_keys |
bool |
Whether to sort the keys in the YAML representation. |
False |
**kwargs |
Any |
Kwargs to pass to the pydantic json(...) method. |
{} |
Returns:
Type | Description |
---|---|
str |
YAML string representation. |
Source code in zenml/utils/pydantic_utils.py
def yaml(self, sort_keys: bool = False, **kwargs: Any) -> str:
    """Serialize the model to a YAML string.

    Args:
        sort_keys: Whether the keys in the YAML output should be sorted.
        **kwargs: Extra keyword arguments for pydantic's `json(...)`.

    Returns:
        The YAML representation of the model.
    """
    # Go through pydantic's JSON serialization first so values are
    # converted to plain types before being dumped as YAML.
    model_json = self.json(**kwargs, sort_keys=sort_keys)
    plain_data = json.loads(model_json)
    return cast(str, yaml.dump(plain_data, sort_keys=sort_keys))
update_model(original, update, recursive=True, exclude_none=True)
Updates a pydantic model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
original |
~M |
The model to update. |
required |
update |
Union[BaseModel, Dict[str, Any]] |
The update values. |
required |
recursive |
bool |
If `True`, dictionary values will be updated recursively. |
True |
exclude_none |
bool |
If `True`, `None` values in the update dictionary will be removed. |
True |
Returns:
Type | Description |
---|---|
~M |
The updated model. |
Source code in zenml/utils/pydantic_utils.py
def update_model(
    original: M,
    update: Union["BaseModel", Dict[str, Any]],
    recursive: bool = True,
    exclude_none: bool = True,
) -> M:
    """Build a new model instance from `original` with `update` applied.

    Args:
        original: The model to update.
        update: The update values, either as a dictionary or as a model of
            which only the explicitly set fields are used.
        recursive: If `True`, dictionary values will be updated recursively.
        exclude_none: If `True`, `None` values in the update dictionary
            will be removed. Only applies when `update` is a dictionary.

    Returns:
        A new instance of `original`'s class holding the merged values.
    """
    if not isinstance(update, dict):
        update_dict = update.dict(exclude_unset=True)
    elif exclude_none:
        update_dict = dict_utils.remove_none_values(
            update, recursive=recursive
        )
    else:
        update_dict = update

    original_dict = original.dict(exclude_unset=True)
    merged_values = (
        dict_utils.recursive_update(original_dict, update_dict)
        if recursive
        else {**original_dict, **update_dict}
    )
    return original.__class__(**merged_values)
secret_utils
Utility functions for secrets and secret references.
SecretReference (tuple)
Class representing a secret reference.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
The secret name. |
key |
str |
The secret key. |
Source code in zenml/utils/secret_utils.py
class SecretReference(NamedTuple):
    """Reference to one key stored inside a named secret.

    Attributes:
        name: Name of the referenced secret.
        key: Key inside that secret whose value is referenced.
    """

    # Field order defines the tuple layout: (name, key).
    name: str
    key: str
__getnewargs__(self)
special
Return self as a plain tuple. Used by copy and pickle.
Source code in zenml/utils/secret_utils.py
def __getnewargs__(self):
    'Return self as a plain tuple. Used by copy and pickle.'
    # NOTE(review): this is the method the standard library generates for
    # NamedTuple classes; `_tuple` is not defined in this file and presumably
    # refers to the builtin `tuple` in the generating module — confirm.
    return _tuple(self)
__new__(_cls, name, key)
special
staticmethod
Create new instance of SecretReference(name, key)
__repr__(self)
special
Return a nicely formatted representation string
Source code in zenml/utils/secret_utils.py
def __repr__(self):
    'Return a nicely formatted representation string'
    # NOTE(review): generated by the standard library for NamedTuple classes;
    # `repr_fmt` is not defined in this file and is presumably a
    # "(field=%r, ...)" format string built from the field names — confirm.
    return self.__class__.__name__ + repr_fmt % self
ClearTextField(*args, **kwargs)
Marks a pydantic field to prevent secret references.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments which will be forwarded to `pydantic.Field(...)`. |
() |
**kwargs |
Any |
Keyword arguments which will be forwarded to `pydantic.Field(...)`. |
{} |
Returns:
Type | Description |
---|---|
Any |
Pydantic field info. |
Source code in zenml/utils/secret_utils.py
def ClearTextField(*args: Any, **kwargs: Any) -> Any:
    """Create a pydantic field that must not contain secret references.

    Args:
        *args: Positional arguments passed through to `pydantic.Field(...)`.
        **kwargs: Keyword arguments passed through to `pydantic.Field(...)`.

    Returns:
        The pydantic field info, tagged with the clear-text marker.
    """
    # The marker always wins over a caller-supplied value for the same key.
    field_kwargs = {**kwargs, PYDANTIC_CLEAR_TEXT_FIELD_MARKER: True}
    return Field(*args, **field_kwargs)
SecretField(*args, **kwargs)
Marks a pydantic field as something containing sensitive information.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments which will be forwarded to `pydantic.Field(...)`. |
() |
**kwargs |
Any |
Keyword arguments which will be forwarded to `pydantic.Field(...)`. |
{} |
Returns:
Type | Description |
---|---|
Any |
Pydantic field info. |
Source code in zenml/utils/secret_utils.py
def SecretField(*args: Any, **kwargs: Any) -> Any:
    """Create a pydantic field flagged as holding sensitive information.

    Args:
        *args: Positional arguments passed through to `pydantic.Field(...)`.
        **kwargs: Keyword arguments passed through to `pydantic.Field(...)`.

    Returns:
        The pydantic field info, tagged with the sensitive-field marker.
    """
    # The marker always wins over a caller-supplied value for the same key.
    field_kwargs = {**kwargs, PYDANTIC_SENSITIVE_FIELD_MARKER: True}
    return Field(*args, **field_kwargs)
is_clear_text_field(field)
Returns whether a pydantic field prevents secret references or not.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
field |
ModelField |
The field to check. |
required |
Returns:
Type | Description |
---|---|
bool |
`True` if the field prevents secret references, `False` otherwise. |
Source code in zenml/utils/secret_utils.py
def is_clear_text_field(field: "ModelField") -> bool:
    """Check whether a pydantic field carries the clear-text marker.

    Args:
        field: The pydantic model field to inspect.

    Returns:
        `True` if the field prevents secret references, `False` otherwise.
    """
    extra_info = field.field_info.extra
    return extra_info.get(PYDANTIC_CLEAR_TEXT_FIELD_MARKER, False)
is_secret_field(field)
Returns whether a pydantic field contains sensitive information or not.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
field |
ModelField |
The field to check. |
required |
Returns:
Type | Description |
---|---|
bool |
`True` if the field contains sensitive information, `False` otherwise. |
Source code in zenml/utils/secret_utils.py
def is_secret_field(field: "ModelField") -> bool:
    """Check whether a pydantic field carries the sensitive-field marker.

    Args:
        field: The pydantic model field to inspect.

    Returns:
        `True` if the field contains sensitive information, `False` otherwise.
    """
    extra_info = field.field_info.extra
    return extra_info.get(PYDANTIC_SENSITIVE_FIELD_MARKER, False)
is_secret_reference(value)
Checks whether any value is a secret reference.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value |
Any |
The value to check. |
required |
Returns:
Type | Description |
---|---|
bool |
`True` if the value is a secret reference, `False` otherwise. |
Source code in zenml/utils/secret_utils.py
def is_secret_reference(value: Any) -> bool:
    """Determine whether an arbitrary value is a secret reference string.

    Args:
        value: The value to inspect; non-string values never match.

    Returns:
        `True` if the value is a secret reference, `False` otherwise.
    """
    return isinstance(value, str) and bool(
        _secret_reference_expression.fullmatch(value)
    )
parse_secret_reference(reference)
Parses a secret reference.
This function assumes the input string is a valid secret reference and does not perform any additional checks. If you pass an invalid secret reference here, this will most likely crash.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference |
str |
The string representing a valid secret reference. |
required |
Returns:
Type | Description |
---|---|
SecretReference |
The parsed secret reference. |
Source code in zenml/utils/secret_utils.py
def parse_secret_reference(reference: str) -> SecretReference:
    """Split a secret reference string into its secret name and key.

    The input is assumed to already be a **valid** secret reference; no
    validation happens here, so invalid input will most likely raise.

    Args:
        reference: The string representing a **valid** secret reference.

    Returns:
        The parsed secret reference.
    """
    # Strip the two delimiter characters on each side of the reference.
    body = reference[2:-2]
    # Only the first "." separates name and key; the key may contain dots.
    secret_name, secret_key = body.split(".", 1)
    return SecretReference(name=secret_name, key=secret_key)
settings_utils
Utility functions for ZenML settings.
get_flavor_setting_key(flavor)
Gets the setting key for a flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor |
Flavor |
The flavor for which to get the key. |
required |
Returns:
Type | Description |
---|---|
str |
The setting key for the flavor. |
Source code in zenml/utils/settings_utils.py
def get_flavor_setting_key(flavor: "Flavor") -> str:
    """Build the settings key identifying a flavor.

    Args:
        flavor: The flavor whose settings key should be built.

    Returns:
        The key in the form `<flavor type>.<flavor name>`.
    """
    return "{}.{}".format(flavor.type, flavor.name)
get_general_settings()
Returns all general settings.
Returns:
Type | Description |
---|---|
Dict[str, Type[BaseSettings]] |
Dictionary mapping general settings keys to their type. |
Source code in zenml/utils/settings_utils.py
def get_general_settings() -> Dict[str, Type["BaseSettings"]]:
    """Collect the settings classes that are not tied to a stack component.

    Returns:
        Dictionary mapping general settings keys to their type.
    """
    # Imported locally, as in the rest of this module's settings helpers.
    from zenml.config import DockerSettings, ResourceSettings

    general_settings: Dict[str, Type["BaseSettings"]] = {}
    general_settings[DOCKER_SETTINGS_KEY] = DockerSettings
    general_settings[RESOURCE_SETTINGS_KEY] = ResourceSettings
    return general_settings
get_stack_component_for_settings_key(key, stack)
Gets the stack component of a stack for a given settings key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
The settings key for which to get the component. |
required |
stack |
Stack |
The stack from which to get the component. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the key is invalid or the stack does not contain a component of the correct flavor. |
Returns:
Type | Description |
---|---|
StackComponent |
The stack component. |
Source code in zenml/utils/settings_utils.py
def get_stack_component_for_settings_key(
    key: str, stack: "Stack"
) -> "StackComponent":
    """Looks up the stack component addressed by a settings key.

    Args:
        key: The settings key for which to get the component.
        stack: The stack from which to get the component.

    Raises:
        ValueError: If the key is invalid or the stack does not contain a
            component of the correct flavor.

    Returns:
        The stack component.
    """
    if not is_stack_component_setting_key(key):
        raise ValueError(
            f"Settings key {key} does not refer to a stack component."
        )

    # Only the first `.` is treated as the type/flavor separator.
    component_type, flavor = key.split(".", 1)
    requested_type = StackComponentType(component_type)
    stack_component = stack.components.get(requested_type)

    component_missing = not stack_component
    if component_missing or stack_component.flavor != flavor:
        raise ValueError(
            f"Component of type {component_type} in stack {stack} is not "
            f"of the flavor {flavor} specified by the settings key {key}."
        )
    return stack_component
get_stack_component_setting_key(stack_component)
Gets the setting key for a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_component |
StackComponent |
The stack component for which to get the key. |
required |
Returns:
Type | Description |
---|---|
str |
The setting key for the stack component. |
Source code in zenml/utils/settings_utils.py
def get_stack_component_setting_key(stack_component: "StackComponent") -> str:
    """Builds the settings key that belongs to a stack component.

    Args:
        stack_component: The stack component for which to get the key.

    Returns:
        The setting key for the stack component, formatted as
        `<type>.<flavor>`.
    """
    component_type = stack_component.type
    component_flavor = stack_component.flavor
    return f"{component_type}.{component_flavor}"
is_general_setting_key(key)
Checks whether the key refers to a general setting.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
The key to check. |
required |
Returns:
Type | Description |
---|---|
bool |
If the key refers to a general setting. |
Source code in zenml/utils/settings_utils.py
def is_general_setting_key(key: str) -> bool:
    """Tells whether a key is one of the general (non-component) settings keys.

    Args:
        key: The key to check.

    Returns:
        If the key refers to a general setting.
    """
    general_settings = get_general_settings()
    return key in general_settings
is_stack_component_setting_key(key)
Checks whether a settings key refers to a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
The key to check. |
required |
Returns:
Type | Description |
---|---|
bool |
If the key refers to a stack component. |
Source code in zenml/utils/settings_utils.py
def is_stack_component_setting_key(key: str) -> bool:
    """Tells whether a settings key addresses a stack component.

    Args:
        key: The key to check.

    Returns:
        If the key refers to a stack component.
    """
    # The whole key must match the pattern, not just a prefix of it.
    matched = STACK_COMPONENT_REGEX.fullmatch(key)
    return matched is not None
is_valid_setting_key(key)
Checks whether a settings key is valid.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
The key to check. |
required |
Returns:
Type | Description |
---|---|
bool |
If the key is valid. |
Source code in zenml/utils/settings_utils.py
def is_valid_setting_key(key: str) -> bool:
    """Tells whether a settings key is acceptable.

    A key is valid if it is either a general settings key or a stack
    component settings key.

    Args:
        key: The key to check.

    Returns:
        If the key is valid.
    """
    if is_general_setting_key(key):
        return True
    return is_stack_component_setting_key(key)
validate_setting_keys(setting_keys)
Validates settings keys.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
setting_keys |
Sequence[str] |
The keys to validate. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If any key is invalid. |
Source code in zenml/utils/settings_utils.py
def validate_setting_keys(setting_keys: Sequence[str]) -> None:
    """Ensures every given settings key is valid.

    Args:
        setting_keys: The keys to validate.

    Raises:
        ValueError: If any key is invalid (reported for the first one found).
    """
    for setting_key in setting_keys:
        if is_valid_setting_key(setting_key):
            continue
        raise ValueError(
            f"Invalid setting key `{setting_key}`. Setting keys can either refer "
            "to general settings (available keys: "
            f"{set(get_general_settings())}) or stack component specific "
            "settings. Stack component specific keys are of the format "
            "`<STACK_COMPONENT_TYPE>.<STACK_COMPONENT_FLAVOR>`."
        )
singleton
Utility class to turn classes into singleton classes.
SingletonMetaClass (type)
Singleton metaclass.
Use this metaclass to make any class into a singleton class:
# Usage example: every class created with `SingletonMetaClass` hands out a
# single shared instance.
class OneRing(metaclass=SingletonMetaClass):
    def __init__(self, owner):
        self._owner = owner

    @property
    def owner(self):
        return self._owner

# First call constructs the instance.
the_one_ring = OneRing('Sauron')
# Second call returns the cached instance; __init__ does not run again, so
# the 'Frodo' argument is ignored.
the_lost_ring = OneRing('Frodo')
print(the_lost_ring.owner) # Sauron
OneRing._clear() # ring destroyed
Source code in zenml/utils/singleton.py
class SingletonMetaClass(type):
    """Singleton metaclass.

    Use this metaclass to make any class into a singleton class:

    ```python
    class OneRing(metaclass=SingletonMetaClass):
        def __init__(self, owner):
            self._owner = owner

        @property
        def owner(self):
            return self._owner

    the_one_ring = OneRing('Sauron')
    the_lost_ring = OneRing('Frodo')
    print(the_lost_ring.owner)  # Sauron
    OneRing._clear()  # ring destroyed
    ```

    Only the first call constructs an instance. Later calls return that same
    instance and ignore their arguments until `_clear()` is called.
    """

    def __init__(cls, *args: Any, **kwargs: Any) -> None:
        """Initialize a singleton class.

        Args:
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(*args, **kwargs)
        # Every class created with this metaclass (including subclasses of a
        # singleton class) gets its own, initially empty, instance slot.
        cls.__singleton_instance: Optional["SingletonMetaClass"] = None

    def __call__(cls, *args: Any, **kwargs: Any) -> "SingletonMetaClass":
        """Create or return the singleton instance.

        Args:
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The singleton instance.
        """
        # Compare against `None` rather than relying on truthiness: an
        # instance whose `__bool__`/`__len__` reports False must still be
        # reused instead of being rebuilt on every call.
        if cls.__singleton_instance is None:
            cls.__singleton_instance = cast(
                "SingletonMetaClass", super().__call__(*args, **kwargs)
            )
        return cls.__singleton_instance

    def _clear(cls) -> None:
        """Clear the singleton instance."""
        cls.__singleton_instance = None
__call__(cls, *args, **kwargs)
special
Create or return the singleton instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
SingletonMetaClass |
The singleton instance. |
Source code in zenml/utils/singleton.py
def __call__(cls, *args: Any, **kwargs: Any) -> "SingletonMetaClass":
    """Create or return the singleton instance.

    Args:
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        The singleton instance.
    """
    # Compare against `None` rather than relying on truthiness: an instance
    # whose `__bool__`/`__len__` reports False must still be reused.
    if cls.__singleton_instance is None:
        cls.__singleton_instance = cast(
            "SingletonMetaClass", super().__call__(*args, **kwargs)
        )
    return cls.__singleton_instance
__init__(cls, *args, **kwargs)
special
Initialize a singleton class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Additional arguments. |
() |
**kwargs |
Any |
Additional keyword arguments. |
{} |
Source code in zenml/utils/singleton.py
def __init__(cls, *args: Any, **kwargs: Any) -> None:
    """Set up a class that is created with the singleton metaclass.

    Args:
        *args: Positional arguments forwarded to the parent metaclass.
        **kwargs: Keyword arguments forwarded to the parent metaclass.
    """
    super().__init__(*args, **kwargs)
    # No instance exists yet; the first call of the class will create it.
    cls.__singleton_instance: Optional["SingletonMetaClass"] = None
source_utils
Utility functions for source code.
These utils are predicated on the following definitions:
- class_source: This is a python-import type path to a class, e.g. some.mod.class
- module_source: This is a python-import type path to a module, e.g. some.mod
- file_path, relative_path, absolute_path: These are file system paths.
- source: This is a class_source or module_source. If it is a class_source, it can also be optionally pinned.
- pin: Whatever comes after the `@` symbol in a source, usually the git SHA or the ZenML version as a string.
get_hashed_source(value)
Returns a hash of the objects source code.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value |
Any |
object to get source from. |
required |
Returns:
Type | Description |
---|---|
str |
Hash of source code. |
Exceptions:
Type | Description |
---|---|
TypeError |
If unable to compute the hash. |
Source code in zenml/utils/source_utils.py
def get_hashed_source(value: Any) -> str:
    """Returns a hash of the objects source code.

    Args:
        value: object to get source from.

    Returns:
        Hash of source code (hex-encoded SHA-256 digest).

    Raises:
        TypeError: If unable to compute the hash.
    """
    try:
        source_code = get_source(value)
    except TypeError as e:
        # Chain the original error so the underlying failure is reported as
        # the direct cause in tracebacks.
        raise TypeError(
            f"Unable to compute the hash of source code of object: {value}."
        ) from e

    return hashlib.sha256(source_code.encode("utf-8")).hexdigest()
get_main_module_source()
Gets the source of the main module.
Returns:
Type | Description |
---|---|
str |
The main module source. |
Source code in zenml/utils/source_utils.py
def get_main_module_source() -> str:
    """Resolves the source of the `__main__` module.

    Returns:
        The main module source.
    """
    return get_module_source_from_module(sys.modules["__main__"])
get_module_source_from_module(module)
Gets the source of the supplied module.
E.g.:
- a `/home/myrepo/src/run.py` module running as the main module returns `run` if no repository root is specified.
- a `/home/myrepo/src/run.py` module running as the main module returns `src.run` if the repository root is configured in `/home/myrepo`.
- a `/home/myrepo/src/pipeline.py` module not running as the main module returns `src.pipeline` if the repository root is configured in `/home/myrepo`.
- a `/home/myrepo/src/pipeline.py` module not running as the main module returns `pipeline` if no repository root is specified and the main module is also in `/home/myrepo/src`.
- a `/home/step.py` module not running as the main module returns `step` if the CWD is `/home` and the repository root or the main module are in a different path (e.g. `/home/myrepo/src`).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module |
module |
the module to get the source of. |
required |
Returns:
Type | Description |
---|---|
str |
The source of the supplied module. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the module is not loaded from a file |
Source code in zenml/utils/source_utils.py
def _path_is_within_root(path: str, root: str) -> bool:
    """Returns whether `path` is `root` itself or located below it.

    A plain `str.startswith` check would also accept sibling directories that
    merely share a prefix (`/home/repo2` vs. `/home/repo`). This helper
    requires a path-separator boundary after the root.

    Args:
        path: The path to check.
        root: The directory that should contain `path`.

    Returns:
        True if `path` equals `root` or lies inside it, False otherwise.
    """
    if path == root:
        return True
    # Strip trailing separators so roots like `/` or `/home/repo/` do not end
    # up with a doubled separator in the prefix.
    return path.startswith(root.rstrip(os.sep) + os.sep)


def get_module_source_from_module(module: ModuleType) -> str:
    """Gets the source of the supplied module.

    E.g.:

    * a `/home/myrepo/src/run.py` module running as the main module returns
      `run` if no repository root is specified.
    * a `/home/myrepo/src/run.py` module running as the main module returns
      `src.run` if the repository root is configured in `/home/myrepo`
    * a `/home/myrepo/src/pipeline.py` module not running as the main module
      returns `src.pipeline` if the repository root is configured in
      `/home/myrepo`
    * a `/home/myrepo/src/pipeline.py` module not running as the main module
      returns `pipeline` if no repository root is specified and the main
      module is also in `/home/myrepo/src`.
    * a `/home/step.py` module not running as the main module
      returns `step` if the CWD is /home and the repository root or the main
      module are in a different path (e.g. `/home/myrepo/src`).

    Args:
        module: the module to get the source of.

    Returns:
        The source of the supplied module.

    Raises:
        RuntimeError: if the `__main__` module is not loaded from a file, or
            if the module file cannot be expressed relative to the resolved
            root, or if it is not a `.py` file.
    """
    if not hasattr(module, "__file__") or not module.__file__:
        if module.__name__ == "__main__":
            raise RuntimeError(
                f"{module} module was not loaded from a file. Cannot "
                "determine the module root path."
            )
        # Modules without a backing file are referenced by their import name.
        return module.__name__

    module_path = os.path.abspath(module.__file__)

    root_path = get_source_root_path()
    # Separator-aware containment check (see `_path_is_within_root`).
    if not _path_is_within_root(module_path, root_path):
        logger.warning(
            "User module %s is not in the source root %s. Using current "
            "directory %s instead to resolve module source.",
            module,
            root_path,
            os.getcwd(),
        )
        root_path = os.getcwd()

    root_path = os.path.abspath(root_path)

    # Remove root_path from module_path to get relative path left over
    relative_path = os.path.relpath(module_path, root_path)
    if relative_path == os.pardir or relative_path.startswith(
        os.pardir + os.sep
    ):
        raise RuntimeError(
            f"Unable to resolve source for module {module}. The module file "
            f"'{relative_path}' does not seem to be inside the source root "
            f"'{root_path}'."
        )

    # Remove the file extension and replace the os specific path separators
    # with `.` to get the module source
    path_without_extension, file_extension = os.path.splitext(relative_path)
    if file_extension != ".py":
        # Report the path including its extension so the message shows which
        # file was rejected.
        raise RuntimeError(
            f"Unable to resolve source for module {module}. The module file "
            f"'{relative_path}' does not seem to be a python file."
        )

    module_source = path_without_extension.replace(os.path.sep, ".")
    logger.debug(
        f"Resolved module source for module {module} to: `{module_source}`"
    )
    return module_source
get_source(value)
Returns the source code of an object.
If executing within an IPython kernel environment, then this monkey-patches the `inspect` module temporarily with a workaround to get source from the cell.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value |
Any |
object to get source from. |
required |
Returns:
Type | Description |
---|---|
str |
Source code of object. |
Source code in zenml/utils/source_utils.py
def get_source(value: Any) -> str:
    """Returns the source code of an object.

    If executing within an IPython kernel environment, then this
    monkey-patches the `inspect` module temporarily with a workaround to get
    source from the cell.

    Args:
        value: object to get source from.

    Returns:
        Source code of object.
    """
    if Environment.in_notebook():
        # Monkey patch inspect.getfile temporarily to make getsource work.
        # Source: https://stackoverflow.com/questions/51566497/
        def _new_getfile(
            object: Any,
            # Default is bound when `_new_getfile` is defined, i.e. while
            # `inspect.getfile` is still the original implementation, so the
            # replacement can delegate to it without recursing into itself.
            _old_getfile: Callable[
                [
                    Union[
                        ModuleType,
                        Type[Any],
                        MethodType,
                        FunctionType,
                        TracebackType,
                        FrameType,
                        CodeType,
                        Callable[..., Any],
                    ]
                ],
                str,
            ] = inspect.getfile,
        ) -> Any:
            # Non-class objects are handled by the original implementation.
            if not inspect.isclass(object):
                return _old_getfile(object)

            # Lookup by parent module (as in current inspect)
            if hasattr(object, "__module__"):
                object_ = sys.modules.get(object.__module__)
                if hasattr(object_, "__file__"):
                    return object_.__file__  # type: ignore[union-attr]

            # If parent module is __main__, lookup by methods
            for name, member in inspect.getmembers(object):
                # A function whose qualname is `<class qualname>.<name>` was
                # defined in this class's body. At this point
                # `inspect.getfile` is the patched version, but `member` is
                # not a class, so the call delegates to the original.
                if (
                    inspect.isfunction(member)
                    and object.__qualname__ + "." + member.__name__
                    == member.__qualname__
                ):
                    return inspect.getfile(member)
            else:
                # for/else: only reached when no member matched above.
                raise TypeError(f"Source for {object!r} not found.")

        # Monkey patch, compute source, then revert monkey patch.
        _old_getfile = inspect.getfile
        inspect.getfile = _new_getfile
        try:
            src = inspect.getsource(value)
        finally:
            # Restore the original even if getsource raised.
            inspect.getfile = _old_getfile
    else:
        # Use standard inspect if running outside a notebook
        src = inspect.getsource(value)
    return src
get_source_root_path()
Gets repository root path or the source root path of the current process.
E.g.:
- if the process was started by running a `run.py` file under `full/path/to/my/run.py`, and the repository root is configured at `full/path`, the source root path is `full/path`.
- same case as above, but when there is no repository root configured, the source root path is `full/path/to/my`.
Returns:
Type | Description |
---|---|
str |
The source root path of the current process. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the main module was not started or determined. |
Source code in zenml/utils/source_utils.py
def get_source_root_path() -> str:
    """Gets repository root path or the source root path of the current process.

    Resolution order: a custom source root (if set), then the ZenML
    repository root (if found), then the directory of the `__main__` file.

    E.g.:

    * if the process was started by running a `run.py` file under
      `full/path/to/my/run.py`, and the repository root is configured at
      `full/path`, the source root path is `full/path`.
    * same case as above, but when there is no repository root configured,
      the source root path is `full/path/to/my`.

    Returns:
        The source root path of the current process.

    Raises:
        RuntimeError: if the main module was not started or determined.
    """
    # An explicitly configured root always wins.
    if _CUSTOM_SOURCE_ROOT:
        return _CUSTOM_SOURCE_ROOT

    from zenml.client import Client

    repository_root = Client.find_repository()
    if repository_root:
        logger.debug(
            "Using repository root as source root: %s", repository_root
        )
        return str(repository_root.resolve())

    main_module = sys.modules.get("__main__")
    if main_module is None:
        raise RuntimeError(
            "Could not determine the main module used to run the current "
            "process."
        )

    main_file = getattr(main_module, "__file__", None)
    if not main_file:
        raise RuntimeError(
            "Main module was not started from a file. Cannot "
            "determine the module root path."
        )

    main_directory = pathlib.Path(main_file).resolve().parent
    logger.debug(
        "Using main module location as source root: %s", main_directory
    )
    return str(main_directory)
import_by_path(path)
Imports a module attribute.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
The source path of the attribute to import, e.g. `some.module.attribute_name`. |
required |
Returns:
Type | Description |
---|---|
Any |
The imported attribute. |
Source code in zenml/utils/source_utils.py
def import_by_path(path: str) -> Any:
    """Imports a module attribute.

    Args:
        path: The source path of the attribute to import, e.g.
            `some.module.attribute_name`.

    Returns:
        The imported attribute.

    Raises:
        ValueError: If `path` does not contain a `.` separating the module
            from the attribute name.
    """
    module_name, separator, attribute_name = path.rpartition(".")
    if not separator:
        # Without a dot there is no module to import from; fail with a clear
        # message instead of an opaque tuple-unpacking error.
        raise ValueError(
            f"Invalid source path `{path}`: expected the format "
            "`<module>.<attribute>`, e.g. `some.module.attribute_name`."
        )
    module = importlib.import_module(module_name)
    return getattr(module, attribute_name)
import_class_by_path(class_path)
Imports a class based on a given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_path |
str |
Class source, e.g. `this.module.Class`. |
required |
Returns:
Type | Description |
---|---|
Type[Any] |
the given class |
Source code in zenml/utils/source_utils.py
def import_class_by_path(class_path: str) -> Type[Any]:
    """Loads a class given its dotted import path.

    Args:
        class_path: Class source, e.g., `this.module.Class`

    Returns:
        the given class
    """
    # Everything before the last dot is the module, the rest is the class.
    path_parts = class_path.rsplit(".", 1)
    module_name, class_name = path_parts
    imported_module = importlib.import_module(module_name)
    found_class = getattr(imported_module, class_name)
    return found_class  # type: ignore[no-any-return]
import_python_file(file_path, zen_root)
Imports a python file in relationship to the zen root.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to python file that should be imported. |
required |
zen_root |
str |
Path to current zenml root |
required |
Returns:
Type | Description |
---|---|
ModuleType |
The imported module. |
Source code in zenml/utils/source_utils.py
def import_python_file(file_path: str, zen_root: str) -> types.ModuleType:
    """Imports a python file in relationship to the zen root.

    The dotted module name is derived from the file's path relative to
    `zen_root`. Any previously imported module with the same name is evicted
    from `sys.modules` first, so the file is imported afresh.

    Args:
        file_path: Path to python file that should be imported.
        zen_root: Path to current zenml root

    Returns:
        The imported module.
    """
    file_path = os.path.abspath(file_path)
    module_path = os.path.relpath(file_path, zen_root)
    module_name = os.path.splitext(module_path)[0].replace(os.path.sep, ".")

    # Drop a cached module of the same name so `import_module` does not just
    # return the stale module object.
    sys.modules.pop(module_name, None)

    # Put the zen root on the Python path so the dotted module name resolves.
    with prepend_python_path([zen_root]):
        return importlib.import_module(module_name)
is_inside_repository(file_path)
Returns whether a file is inside a zenml repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
A file path. |
required |
Returns:
Type | Description |
---|---|
bool |
`True` if the file is inside a zenml repository, else `False`. |
Source code in zenml/utils/source_utils.py
def is_inside_repository(file_path: str) -> bool:
    """Tells whether a file lives below the ZenML repository root.

    Args:
        file_path: A file path.

    Returns:
        `True` if the file is inside a zenml repository, else `False`.
    """
    from zenml.client import Client

    repository = Client.find_repository()
    if not repository:
        return False

    repository_root = repository.resolve()
    resolved_file = pathlib.Path(file_path).resolve()
    return repository_root in resolved_file.parents
is_internal_source(source)
Returns `True` if source is an internal ZenML source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
Python source e.g. this.module.Class |
required |
Returns:
Type | Description |
---|---|
bool |
`True` if source is an internal ZenML source, else `False`. |
Source code in zenml/utils/source_utils.py
def is_internal_source(source: str) -> bool:
    """Tells whether a source path points into the `zenml` package.

    Args:
        source: Python source e.g. this.module.Class

    Returns:
        `True` if source is an internal ZenML source, else `False`.
    """
    top_level_package = source.partition(".")[0]
    return top_level_package == "zenml"
is_third_party_module(file_path)
Returns whether a file belongs to a third party package.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
A file path. |
required |
Returns:
Type | Description |
---|---|
bool |
`True` if the file belongs to a third party package, else `False`. |
Source code in zenml/utils/source_utils.py
def is_third_party_module(file_path: str) -> bool:
    """Tells whether a file belongs to an installed (non-user) package.

    A file counts as third party if it lives below a site-packages
    directory, the user site-packages directory or the standard library, or
    if it is outside the ZenML source root.

    Args:
        file_path: A file path.

    Returns:
        `True` if the file belongs to a third party package, else `False`.
    """
    resolved_file = pathlib.Path(file_path).resolve()
    package_dirs = [
        *site.getsitepackages(),
        site.getusersitepackages(),
        get_python_lib(standard_lib=True),
    ]
    if any(
        pathlib.Path(package_dir).resolve() in resolved_file.parents
        for package_dir in package_dirs
    ):
        return True

    source_root = pathlib.Path(get_source_root_path())
    return source_root not in resolved_file.parents
load_and_validate_class(source, expected_class)
Loads a source class and validates its type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
The source string. |
required |
expected_class |
Type[Any] |
The class that the source should resolve to. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If the source does not resolve to the expected type. |
Returns:
Type | Description |
---|---|
Type[Any] |
The resolved source class. |
Source code in zenml/utils/source_utils.py
def load_and_validate_class(
    source: str, expected_class: Type[Any]
) -> Type[Any]:
    """Loads a class from its source and checks that it has the expected base.

    Args:
        source: The source string.
        expected_class: The class that the source should resolve to.

    Raises:
        TypeError: If the source does not resolve to the expected type.

    Returns:
        The resolved source class.
    """
    loaded = load_source_path(source)
    is_expected_class = isinstance(loaded, type) and issubclass(
        loaded, expected_class
    )
    if not is_expected_class:
        raise TypeError(
            f"Error while loading `{source}`. Expected class "
            f"{expected_class.__name__}, got {loaded} instead."
        )
    return loaded
load_source_path(source, import_path=None)
Loads a python object from the source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
The source, e.g. this.module.Class |
required |
import_path |
Optional[str] |
optional path to add to python path |
None |
Returns:
Type | Description |
---|---|
Any |
The object located at the source path. |
Source code in zenml/utils/source_utils.py
def load_source_path(source: str, import_path: Optional[str] = None) -> Any:
    """Imports the python object a source string points to.

    Args:
        source: The source, e.g. this.module.Class
        import_path: optional path to add to python path

    Returns:
        The object located at the source path.
    """
    if not import_path:
        root = get_source_root_path()
        # Only put the source root on the path if it is not already there.
        if root not in sys.path:
            import_path = root

    source = remove_internal_version_pin(source)

    if import_path is None:
        return import_by_path(source)

    with prepend_python_path([import_path]):
        logger.debug(
            f"Loading class {source} with import path {import_path}"
        )
        return import_by_path(source)
prepend_python_path(paths)
Simple context manager to help import module within the repo.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
paths |
List[str] |
paths to prepend to sys.path |
required |
Yields:
Type | Description |
---|---|
Iterator[NoneType] |
None |
Source code in zenml/utils/source_utils.py
@contextmanager
def prepend_python_path(paths: List[str]) -> Iterator[None]:
    """Simple context manager to help import module within the repo.

    The given paths are placed at the front of `sys.path` in the given order
    (so `paths[0]` is searched first) and removed again on exit, even if the
    body raises.

    Args:
        paths: paths to prepend to sys.path

    Yields:
        None
    """
    # Materialize once: an iterator argument would otherwise be exhausted
    # before cleanup, leaving the entries on sys.path.
    paths = list(paths)
    # Slice assignment keeps the caller's order; inserting one by one at
    # index 0 would reverse it.
    sys.path[:0] = paths
    try:
        yield
    finally:
        for path in paths:
            try:
                sys.path.remove(path)
            except ValueError:
                # The body already removed this entry. Don't let cleanup mask
                # an exception raised inside the block.
                pass
remove_internal_version_pin(source)
Removes an internal version pin of a source string.
This function returns the input source if no pin is found.
Examples:
`zenml.client.Client@0.21.0` -> `zenml.client.Client`
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
The source from which to remove the pin. |
required |
Returns:
Type | Description |
---|---|
str |
The source with removed pin. |
Source code in zenml/utils/source_utils.py
def remove_internal_version_pin(source: str) -> str:
    """Strips an `@<pin>` suffix from a source string.

    Sources without a pin are returned unchanged.

    Example:
        `zenml.client.Client@0.21.0` -> `zenml.client.Client`

    Args:
        source: The source from which to remove the pin.

    Returns:
        The source with removed pin.
    """
    # `partition` yields the whole string as the head when `@` is absent.
    unpinned, _, _ = source.partition("@")
    return unpinned
resolve_class(class_, replace_main_module=True)
Resolves a class into a serializable source string.
For classes that are neither built-in nor imported from a Python package, the `get_source_root_path` function is used to determine the root path relative to which the class source is resolved.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_ |
Type[Any] |
A Python Class reference. |
required |
replace_main_module |
bool |
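If `True`, classes in the main module will have the `__main__` module source replaced with the source relative to the ZenML source root. |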
True |
Returns:
Type | Description |
---|---|
str |
source_path e.g. this.module.Class. |
Source code in zenml/utils/source_utils.py
def resolve_class(class_: Type[Any], replace_main_module: bool = True) -> str:
    """Turns a class into a source string that can be serialized.

    Internal ZenML classes, built-ins and third-party classes keep their
    `<module>.<name>` source. Classes from user files are resolved relative
    to the root returned by `get_source_root_path`.

    Args:
        class_: A Python Class reference.
        replace_main_module: If `True`, classes in the main module will have
            the __main__ module source replaced with the source relative to
            the ZenML source root.

    Returns:
        source_path e.g. this.module.Class.
    """
    module_name = class_.__module__
    class_name = class_.__name__
    initial_source = module_name + "." + class_name

    if is_internal_source(initial_source):
        return initial_source

    try:
        file_path = inspect.getfile(class_)
    except (TypeError, OSError):
        # builtin file
        return initial_source

    if module_name == "__main__":
        # Either keep `__main__` as-is or resolve it relative to the ZenML
        # source root.
        if replace_main_module:
            return f"{get_main_module_source()}.{class_name}"
        return initial_source

    if is_third_party_module(file_path):
        return initial_source

    # Regular user file: build the full module path relative to the
    # source root.
    module_source = get_module_source_from_module(sys.modules[module_name])
    source = f"{module_source}.{class_name}"
    logger.debug(f"Resolved class {class_} to `{source}`.")
    return source
set_custom_source_root(source_root)
Sets a custom source root.
If set this has the highest priority and will always be used as the source root.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source_root |
Optional[str] |
The source root to use. |
required |
Source code in zenml/utils/source_utils.py
def set_custom_source_root(source_root: Optional[str]) -> None:
    """Overrides the source root for the rest of the process.

    When set, this value takes precedence over every other way of
    determining the source root.

    Args:
        source_root: The source root to use.
    """
    global _CUSTOM_SOURCE_ROOT

    logger.debug("Setting custom source root: %s", source_root)
    _CUSTOM_SOURCE_ROOT = source_root
validate_config_source(source, component_type)
Validates a StackComponentConfig class from a given source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
source path of the implementation |
required |
component_type |
StackComponentType |
the type of the stack component |
required |
Returns:
Type | Description |
---|---|
Type[StackComponentConfig] |
The validated config. |
Exceptions:
Type | Description |
---|---|
ValueError |
If ZenML cannot import the config class. |
TypeError |
If the config class is not a subclass of `StackComponentConfig`. |
Source code in zenml/utils/source_utils.py
def validate_config_source(
    source: str, component_type: StackComponentType
) -> Type["StackComponentConfig"]:
    """Validates a StackComponentConfig class from a given source.

    Args:
        source: source path of the implementation
        component_type: the type of the stack component (not referenced by
            the checks below)

    Returns:
        The validated config.

    Raises:
        ValueError: If ZenML cannot import the config class.
        TypeError: If the config class is not a subclass of
            `StackComponentConfig`.
    """
    from zenml.stack.stack_component import StackComponentConfig

    try:
        config_class = load_source_path(source)
    except (ValueError, AttributeError, ImportError) as e:
        # Chain the import failure so its traceback is preserved.
        raise ValueError(
            f"ZenML can not import the config class '{source}': {e}"
        ) from e

    if not (
        inspect.isclass(config_class)
        and issubclass(config_class, StackComponentConfig)
    ):
        raise TypeError(
            f"The source path '{source}' does not point to a subclass of "
            "the ZenML config_class."
        )

    return config_class
validate_flavor_source(source, component_type)
Import a StackComponent class from a given source and validate its type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
source path of the implementation |
required |
component_type |
StackComponentType |
the type of the stack component |
required |
Returns:
Type | Description |
---|---|
Type[Flavor] |
the imported class |
Exceptions:
Type | Description |
---|---|
ValueError |
If ZenML cannot find the given module path |
TypeError |
If the given module path does not point to a subclass of a StackComponent which has the right component type. |
Source code in zenml/utils/source_utils.py
def validate_flavor_source(
    source: str, component_type: StackComponentType
) -> Type["Flavor"]:
    """Import a StackComponent class from a given source and validate its type.

    Args:
        source: source path of the implementation
        component_type: the type of the stack component

    Returns:
        the imported class

    Raises:
        ValueError: If ZenML cannot find the given module path
        TypeError: If the given module path does not point to a subclass of a
            StackComponent which has the right component type.
    """
    from zenml.stack.flavor import Flavor
    from zenml.stack.stack_component import (
        StackComponent,
        StackComponentConfig,
    )

    try:
        flavor_class = load_source_path(source)
    except (ValueError, AttributeError, ImportError) as e:
        raise ValueError(
            f"ZenML can not import the flavor class '{source}': {e}"
        ) from e

    if not (
        inspect.isclass(flavor_class) and issubclass(flavor_class, Flavor)
    ):
        raise TypeError(
            f"The source '{source}' does not point to a subclass of the ZenML "
            "Flavor."
        )

    flavor = flavor_class()

    try:
        impl_class = flavor.implementation_class
    except (ModuleNotFoundError, ImportError, NotImplementedError) as e:
        raise ValueError(
            f"The implementation class defined within the "
            f"'{flavor_class.__name__}' can not be imported."
        ) from e

    if not issubclass(impl_class, StackComponent):
        raise TypeError(
            f"The implementation class '{impl_class.__name__}' of a flavor "
            f"needs to be a subclass of the ZenML StackComponent."
        )

    # The check compares the type declared by the flavor, so that is the
    # value to report (not an attribute looked up on the implementation
    # class).
    if flavor.type != component_type:
        raise TypeError(
            f"The source points to a {flavor.type}, not a "
            f"{component_type}."
        )

    try:
        conf_class = flavor.config_class
    except (ModuleNotFoundError, ImportError, NotImplementedError) as e:
        raise ValueError(
            f"The config class defined within the "
            f"'{flavor_class.__name__}' can not be imported."
        ) from e

    if not issubclass(conf_class, StackComponentConfig):
        raise TypeError(
            f"The config class '{conf_class.__name__}' of a flavor "
            f"needs to be a subclass of the ZenML StackComponentConfig."
        )

    return flavor_class
validate_source_class(source, expected_class)
Validates that a source resolves to a certain type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
str |
The source to validate. |
required |
expected_class |
Type[Any] |
The class that the source should resolve to. |
required |
Returns:
Type | Description |
---|---|
bool |
If the source resolves to the expected class. |
Source code in zenml/utils/source_utils.py
def validate_source_class(source: str, expected_class: Type[Any]) -> bool:
    """Check whether a source path resolves to (a subclass of) a given class.

    Args:
        source: The source to validate.
        expected_class: The class that the source should resolve to.

    Returns:
        True if the source can be loaded and resolves to a class that is
        `expected_class` or one of its subclasses. False if loading the
        source raises any exception, or if it resolves to a non-class object
        or an unrelated class.
    """
    try:
        resolved = load_source_path(source)
    except Exception:
        # Any failure while loading the source means it does not validate.
        return False

    return isinstance(resolved, type) and issubclass(resolved, expected_class)
string_utils
Utils for strings.
b64_decode(input_)
Returns a decoded string of the base 64 encoded input string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ |
str |
Base64 encoded string. |
required |
Returns:
Type | Description |
---|---|
str |
Decoded string. |
Source code in zenml/utils/string_utils.py
def b64_decode(input_: str) -> str:
    """Decode a base64-encoded string back into text.

    Args:
        input_: Base64 encoded string.

    Returns:
        Decoded string.
    """
    return base64.b64decode(input_.encode()).decode()
b64_encode(input_)
Returns a base 64 encoded string of the input string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ |
str |
The input to encode. |
required |
Returns:
Type | Description |
---|---|
str |
Base64 encoded string. |
Source code in zenml/utils/string_utils.py
def b64_encode(input_: str) -> str:
    """Encode a text string as base64.

    Args:
        input_: The input to encode.

    Returns:
        Base64 encoded string.
    """
    return base64.b64encode(input_.encode()).decode()
get_human_readable_filesize(bytes_)
Convert a file size in bytes into a human-readable string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bytes_ |
int |
The number of bytes to convert. |
required |
Returns:
Type | Description |
---|---|
str |
A human-readable string. |
Source code in zenml/utils/string_utils.py
def get_human_readable_filesize(bytes_: int) -> str:
    """Format a byte count as a human-readable size string.

    The absolute value of `bytes_` is repeatedly divided by 1024 and printed
    with two decimals, using binary units from B up to GiB. Sizes of 1024 GiB
    or more are still reported in GiB.

    Args:
        bytes_: The number of bytes to convert.

    Returns:
        A human-readable string.
    """
    units = ("B", "KiB", "MiB", "GiB")
    size = abs(float(bytes_))
    index = 0
    # Step up one unit while the value does not fit the current unit; GiB is
    # the largest unit, so scaling stops there.
    while index < len(units) - 1 and not size < 1024.0:
        size /= 1024.0
        index += 1
    return f"{size:.2f} {units[index]}"
get_human_readable_time(seconds)
Convert seconds into a human-readable string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seconds |
float |
The number of seconds to convert. |
required |
Returns:
Type | Description |
---|---|
str |
A human-readable string. |
Source code in zenml/utils/string_utils.py
def get_human_readable_time(seconds: float) -> str:
    """Format a duration in seconds as a compact human-readable string.

    Durations of at least one minute are shown as whole days/hours/minutes/
    seconds (e.g. `1h1m1s`), dropping leading zero components. Shorter
    durations are shown with millisecond precision (e.g. `0.500s`). Negative
    durations get a leading `-`.

    Args:
        seconds: The number of seconds to convert.

    Returns:
        A human-readable string.
    """
    sign = "-" if seconds < 0 else ""
    magnitude = abs(seconds)

    total_minutes, secs = divmod(int(magnitude), 60)
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)

    if days:
        text = f"{days}d{hours}h{minutes}m{secs}s"
    elif hours:
        text = f"{hours}h{minutes}m{secs}s"
    elif minutes:
        text = f"{minutes}m{secs}s"
    else:
        text = f"{magnitude:.3f}s"
    return sign + text
random_str(length)
Generate a random human readable string of given length.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
length |
int |
Length of string |
required |
Returns:
Type | Description |
---|---|
str |
Random human-readable string. |
Source code in zenml/utils/string_utils.py
def random_str(length: int) -> str:
    """Generate a random human readable string of given length.

    Characters are drawn from ASCII letters using a dedicated
    `random.SystemRandom` instance. The result is therefore unpredictable
    even if user code has seeded the global `random` module. The global RNG
    state is never touched, so calling this does not break reproducibility
    of code that relies on `random.seed(...)`.

    Args:
        length: Length of string

    Returns:
        Random human-readable string.
    """
    # A private OS-entropy-backed generator instead of reseeding the shared
    # module-level RNG with `random.seed()`.
    rng = random.SystemRandom()
    return "".join(rng.choices(string.ascii_letters, k=length))
typed_model
Utility classes for adding type information to Pydantic models.
BaseTypedModel (BaseModel)
pydantic-model
Typed Pydantic model base class.
Use this class as a base class instead of `BaseModel` to automatically
add a `type` literal attribute to the model that stores the fully qualified
name of the class (module path plus qualified class name).
This can be useful when serializing models to JSON and then de-serializing them as part of a submodel union field, e.g.:
```python
class BluePill(BaseTypedModel):
    ...

class RedPill(BaseTypedModel):
    ...

class TheMatrix(BaseTypedModel):
    choice: Union[BluePill, RedPill] = Field(..., discriminator='type')

matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = TheMatrix.parse_obj(d)
assert isinstance(new_matrix.choice, RedPill)
```
It can also facilitate de-serializing objects when their type isn't known:
```python
matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = BaseTypedModel.from_dict(d)
assert isinstance(new_matrix.choice, RedPill)
```
Source code in zenml/utils/typed_model.py
class BaseTypedModel(BaseModel, metaclass=BaseTypedModelMeta):
    """Typed Pydantic model base class.

    Use this class as a base class instead of BaseModel to automatically
    add a `type` literal attribute to the model that stores the name of the
    class.

    This can be useful when serializing models to JSON and then de-serializing
    them as part of a submodel union field, e.g.:

    ```python
    class BluePill(BaseTypedModel):
        ...

    class RedPill(BaseTypedModel):
        ...

    class TheMatrix(BaseTypedModel):
        choice: Union[BluePill, RedPill] = Field(..., discriminator='type')

    matrix = TheMatrix(choice=RedPill())
    d = matrix.dict()
    new_matrix = TheMatrix.parse_obj(d)
    assert isinstance(new_matrix.choice, RedPill)
    ```

    It can also facilitate de-serializing objects when their type isn't known:

    ```python
    matrix = TheMatrix(choice=RedPill())
    d = matrix.dict()
    new_matrix = BaseTypedModel.from_dict(d)
    assert isinstance(new_matrix.choice, RedPill)
    ```
    """

    @classmethod
    def from_dict(
        cls,
        model_dict: Dict[str, Any],
    ) -> "BaseTypedModel":
        """Instantiate a Pydantic model from a serialized JSON-able dict representation.

        The concrete class is looked up from the `type` entry of `model_dict`
        via `load_source_path`, then used to parse the dict.

        Args:
            model_dict: the model attributes serialized as JSON-able dict.

        Returns:
            A BaseTypedModel created from the serialized representation.

        Raises:
            RuntimeError: if the model_dict contains an invalid type.
        """
        model_type = model_dict.get("type")
        if not model_type:
            raise RuntimeError(
                "`type` information is missing from the serialized model dict."
            )

        model_class = load_source_path(model_type)
        # `issubclass` raises TypeError for non-class objects (e.g. when the
        # source resolves to a function), so check for a class first to raise
        # the documented RuntimeError instead.
        if not (
            isinstance(model_class, type)
            and issubclass(model_class, BaseTypedModel)
        ):
            raise RuntimeError(
                f"Class `{model_class}` is not a ZenML BaseTypedModel subclass."
            )

        return model_class.parse_obj(model_dict)

    @classmethod
    def from_json(
        cls,
        json_str: str,
    ) -> "BaseTypedModel":
        """Instantiate a Pydantic model from a serialized JSON representation.

        Args:
            json_str: the model attributes serialized as JSON.

        Returns:
            A BaseTypedModel created from the serialized representation.
        """
        model_dict = json.loads(json_str)
        return cls.from_dict(model_dict)
from_dict(model_dict)
classmethod
Instantiate a Pydantic model from a serialized JSON-able dict representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_dict |
Dict[str, Any] |
the model attributes serialized as JSON-able dict. |
required |
Returns:
Type | Description |
---|---|
BaseTypedModel |
A BaseTypedModel created from the serialized representation. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the model_dict contains an invalid type. |
Source code in zenml/utils/typed_model.py
@classmethod
def from_dict(
    cls,
    model_dict: Dict[str, Any],
) -> "BaseTypedModel":
    """Instantiate a Pydantic model from a serialized JSON-able dict representation.

    The concrete class is looked up from the `type` entry of `model_dict`
    via `load_source_path`, then used to parse the dict.

    Args:
        model_dict: the model attributes serialized as JSON-able dict.

    Returns:
        A BaseTypedModel created from the serialized representation.

    Raises:
        RuntimeError: if the model_dict contains an invalid type.
    """
    model_type = model_dict.get("type")
    if not model_type:
        raise RuntimeError(
            "`type` information is missing from the serialized model dict."
        )

    model_class = load_source_path(model_type)
    # `issubclass` raises TypeError for non-class objects (e.g. when the
    # source resolves to a function), so check for a class first to raise
    # the documented RuntimeError instead.
    if not (
        isinstance(model_class, type)
        and issubclass(model_class, BaseTypedModel)
    ):
        raise RuntimeError(
            f"Class `{model_class}` is not a ZenML BaseTypedModel subclass."
        )

    return model_class.parse_obj(model_dict)
from_json(json_str)
classmethod
Instantiate a Pydantic model from a serialized JSON representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
json_str |
str |
the model attributes serialized as JSON. |
required |
Returns:
Type | Description |
---|---|
BaseTypedModel |
A BaseTypedModel created from the serialized representation. |
Source code in zenml/utils/typed_model.py
@classmethod
def from_json(
    cls,
    json_str: str,
) -> "BaseTypedModel":
    """Build a typed model from its JSON string form.

    The string is parsed into a dict, which is then passed to `from_dict` to
    resolve the concrete class and build the model.

    Args:
        json_str: the model attributes serialized as JSON.

    Returns:
        A BaseTypedModel created from the serialized representation.
    """
    return cls.from_dict(json.loads(json_str))
BaseTypedModelMeta (ModelMetaclass)
Metaclass responsible for adding type information to Pydantic models.
Source code in zenml/utils/typed_model.py
class BaseTypedModelMeta(ModelMetaclass):
    """Metaclass responsible for adding type information to Pydantic models."""

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "BaseTypedModelMeta":
        """Creates a Pydantic BaseModel class.

        This includes a hidden attribute that reflects the full class
        identifier.

        Args:
            name: The name of the class.
            bases: The base classes of the class.
            dct: The class dictionary.

        Returns:
            A Pydantic BaseModel class that includes a hidden attribute that
            reflects the full class identifier.

        Raises:
            TypeError: If the class body assigns or annotates the reserved
                `type` attribute.
        """
        # `type` is reserved: reject both an assigned value and an
        # annotation-only declaration, which would otherwise be silently
        # overwritten below.
        if "type" in dct or "type" in dct.get("__annotations__", {}):
            raise TypeError(
                "`type` is a reserved attribute name for BaseTypedModel "
                "subclasses"
            )

        # Fully qualified identifier, e.g. "package.module.Outer.Inner".
        type_name = f"{dct['__module__']}.{dct['__qualname__']}"
        type_annotation = Literal[type_name]  # type: ignore[valid-type]
        type_field = Field(type_name)

        dct.setdefault("__annotations__", {})["type"] = type_annotation
        dct["type"] = type_field

        cls = cast(
            Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct)
        )
        return cls
__new__(mcs, name, bases, dct)
special
staticmethod
Creates a Pydantic BaseModel class.
This includes a hidden attribute that reflects the full class identifier.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the class. |
required |
bases |
Tuple[Type[Any], ...] |
The base classes of the class. |
required |
dct |
Dict[str, Any] |
The class dictionary. |
required |
Returns:
Type | Description |
---|---|
BaseTypedModelMeta |
A Pydantic BaseModel class that includes a hidden attribute that reflects the full class identifier. |
Exceptions:
Type | Description |
---|---|
TypeError |
If the class body defines the reserved `type` attribute. |
Source code in zenml/utils/typed_model.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseTypedModelMeta":
    """Creates a Pydantic BaseModel class.

    This includes a hidden attribute that reflects the full class
    identifier.

    Args:
        name: The name of the class.
        bases: The base classes of the class.
        dct: The class dictionary.

    Returns:
        A Pydantic BaseModel class that includes a hidden attribute that
        reflects the full class identifier.

    Raises:
        TypeError: If the class body assigns or annotates the reserved
            `type` attribute.
    """
    # `type` is reserved: reject both an assigned value and an
    # annotation-only declaration, which would otherwise be silently
    # overwritten below.
    if "type" in dct or "type" in dct.get("__annotations__", {}):
        raise TypeError(
            "`type` is a reserved attribute name for BaseTypedModel "
            "subclasses"
        )

    # Fully qualified identifier, e.g. "package.module.Outer.Inner".
    type_name = f"{dct['__module__']}.{dct['__qualname__']}"
    type_annotation = Literal[type_name]  # type: ignore[valid-type]
    type_field = Field(type_name)

    dct.setdefault("__annotations__", {})["type"] = type_annotation
    dct["type"] = type_field

    cls = cast(
        Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct)
    )
    return cls
uuid_utils
Utility functions for handling UUIDs.
generate_uuid_from_string(value)
Deterministically generates a UUID from a string seed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value |
str |
The string from which to generate the UUID. |
required |
Returns:
Type | Description |
---|---|
UUID |
The generated UUID. |
Source code in zenml/utils/uuid_utils.py
def generate_uuid_from_string(value: str) -> UUID:
    """Derive a UUID deterministically from a string seed.

    The MD5 digest of the UTF-8 encoded `value` supplies the 128 UUID bits.
    The version (and variant) bits are then forced to version 4, so the same
    input always maps to the same UUID.

    Args:
        value: The string from which to generate the UUID.

    Returns:
        The generated UUID.
    """
    digest = hashlib.md5(value.encode("utf-8")).hexdigest()
    return UUID(hex=digest, version=4)
is_valid_uuid(value, version=4)
Checks if a string is a valid UUID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value |
Any |
String to check. |
required |
version |
int |
Version of UUID to check for. |
4 |
Returns:
Type | Description |
---|---|
bool |
True if string is a valid UUID, False otherwise. |
Source code in zenml/utils/uuid_utils.py
def is_valid_uuid(value: Any, version: int = 4) -> bool:
    """Check whether a value is a UUID or a string that parses as one.

    `UUID` instances are always accepted. Strings are accepted if the `UUID`
    constructor can parse them. Any other type is rejected.

    NOTE(review): `version` is passed to the `UUID` constructor, which
    overwrites the version bits instead of verifying them. Any well-formed
    UUID string is therefore accepted regardless of its actual version.

    Args:
        value: String to check.
        version: Version of UUID to check for.

    Returns:
        True if string is a valid UUID, False otherwise.
    """
    if isinstance(value, UUID):
        return True
    if not isinstance(value, str):
        return False
    try:
        UUID(value, version=version)
    except ValueError:
        return False
    return True
parse_name_or_uuid(name_or_id)
Convert a "name or id" string value to a string or UUID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name_or_id |
Optional[str] |
Name or id to convert. |
required |
Returns:
Type | Description |
---|---|
Union[str, uuid.UUID] |
A UUID if name_or_id is a UUID, string otherwise. |
Source code in zenml/utils/uuid_utils.py
def parse_name_or_uuid(
    name_or_id: Optional[str],
) -> Optional[Union[str, UUID]]:
    """Interpret a "name or id" string as a UUID when possible.

    Args:
        name_or_id: Name or id to convert.

    Returns:
        A `UUID` if `name_or_id` parses as one. Otherwise the input is
        returned unchanged, including `None` and the empty string.
    """
    if not name_or_id:
        return name_or_id
    try:
        return UUID(name_or_id)
    except ValueError:
        # Not UUID-shaped, so treat it as a name.
        return name_or_id
yaml_utils
Utility functions to help with YAML files and data.
UUIDEncoder (JSONEncoder)
JSON encoder for UUID objects.
Source code in zenml/utils/yaml_utils.py
class UUIDEncoder(json.JSONEncoder):
    """JSON encoder that serializes `UUID` objects as 32-character hex strings."""

    def default(self, obj: Any) -> Any:
        """Encode objects the stock JSON encoder cannot handle.

        Args:
            obj: Object to encode.

        Returns:
            The hex string of `obj` if it is a `UUID`. Any other object is
            delegated to `json.JSONEncoder.default`, which raises `TypeError`.
        """
        if not isinstance(obj, UUID):
            return json.JSONEncoder.default(self, obj)
        return obj.hex
default(self, obj)
Default UUID encoder for JSON.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any |
Object to encode. |
required |
Returns:
Type | Description |
---|---|
Any |
Encoded object. |
Source code in zenml/utils/yaml_utils.py
def default(self, obj: Any) -> Any:
    """Encode objects the stock JSON encoder cannot handle.

    Args:
        obj: Object to encode.

    Returns:
        The hex string of `obj` if it is a `UUID`. Any other object is
        delegated to `json.JSONEncoder.default`, which raises `TypeError`.
    """
    if not isinstance(obj, UUID):
        return json.JSONEncoder.default(self, obj)
    return obj.hex
append_yaml(file_path, contents)
Append contents to a YAML file at file_path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to YAML file. |
required |
contents |
Dict[Any, Any] |
Contents of YAML file as dict. |
required |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
if the file does not exist or, for local paths, its parent directory does not exist. |
Source code in zenml/utils/yaml_utils.py
def append_yaml(file_path: str, contents: Dict[Any, Any]) -> None:
    """Merge `contents` into the top-level mapping of an existing YAML file.

    The file is read with `read_yaml`, with an empty document treated as `{}`.
    It is then updated with `contents`, where keys from `contents` win, and
    written back in full.

    Args:
        file_path: Path to YAML file.
        contents: Contents of YAML file as dict.

    Raises:
        FileNotFoundError: if the file does not exist (raised by `read_yaml`)
            or, for local paths, if its parent directory does not exist.
    """
    merged = read_yaml(file_path) or {}
    merged.update(contents)

    if not io_utils.is_remote(file_path):
        parent_dir = str(Path(file_path).parent)
        if not fileio.isdir(parent_dir):
            raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

    io_utils.write_file_contents_as_string(file_path, yaml.dump(merged))
comment_out_yaml(yaml_string)
Comments out a yaml string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
yaml_string |
str |
The yaml string to comment out. |
required |
Returns:
Type | Description |
---|---|
str |
The commented out yaml string. |
Source code in zenml/utils/yaml_utils.py
def comment_out_yaml(yaml_string: str) -> str:
    """Prefix every line of a YAML string with `# `.

    Original line endings are preserved.

    Args:
        yaml_string: The yaml string to comment out.

    Returns:
        The commented out yaml string.
    """
    return "".join(
        f"# {line}" for line in yaml_string.splitlines(keepends=True)
    )
is_yaml(file_path)
Returns True if file_path is YAML, else False.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to YAML file. |
required |
Returns:
Type | Description |
---|---|
bool |
True if is yaml, else False. |
Source code in zenml/utils/yaml_utils.py
def is_yaml(file_path: str) -> bool:
    """Returns True if file_path is YAML, else False.

    A path counts as YAML when it ends with a `.yaml` or `.yml` extension,
    compared case-insensitively. A name that merely ends in the letters
    "yaml" without a dot (e.g. `notyaml`) is not accepted.

    Args:
        file_path: Path to YAML file.

    Returns:
        True if is yaml, else False.
    """
    return file_path.lower().endswith((".yaml", ".yml"))
read_json(file_path)
Read the JSON file at file_path and return its parsed contents.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to JSON file. |
required |
Returns:
Type | Description |
---|---|
Any |
Contents of the file in a dict. |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
if file does not exist. |
Source code in zenml/utils/yaml_utils.py
def read_json(file_path: str) -> Any:
    """Load and parse the JSON file at `file_path`.

    Args:
        file_path: Path to JSON file.

    Returns:
        The parsed JSON contents (e.g. a dict for a JSON object).

    Raises:
        FileNotFoundError: if file does not exist.
    """
    if not fileio.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    return json.loads(io_utils.read_file_contents_as_string(file_path))
read_yaml(file_path)
Read the YAML file at file_path and return its parsed contents.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to YAML file. |
required |
Returns:
Type | Description |
---|---|
Any |
Contents of the file in a dict. |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
if file does not exist. |
Source code in zenml/utils/yaml_utils.py
def read_yaml(file_path: str) -> Any:
    """Load and parse the YAML file at `file_path` using `yaml.safe_load`.

    Args:
        file_path: Path to YAML file.

    Returns:
        The parsed YAML contents (e.g. a dict for a mapping document).

    Raises:
        FileNotFoundError: if file does not exist.
    """
    if not fileio.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    raw = io_utils.read_file_contents_as_string(file_path)
    # TODO: [LOW] an empty document currently yields None; consider returning
    # an empty dict instead.
    return yaml.safe_load(raw)
write_json(file_path, contents, encoder=None)
Write contents as JSON format to file_path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to JSON file. |
required |
contents |
Any |
Contents of JSON file. |
required |
encoder |
Optional[Type[json.encoder.JSONEncoder]] |
Custom JSON encoder to use when saving json. |
None |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
if directory does not exist. |
Source code in zenml/utils/yaml_utils.py
def write_json(
    file_path: str,
    contents: Any,
    encoder: Optional[Type[json.JSONEncoder]] = None,
) -> None:
    """Serialize `contents` as JSON and write it to `file_path`.

    Args:
        file_path: Path to JSON file.
        contents: Contents of JSON file.
        encoder: Custom JSON encoder to use when saving json.

    Raises:
        FileNotFoundError: if directory does not exist.
    """
    # The parent-directory check is only performed for local paths.
    if not io_utils.is_remote(file_path):
        parent_dir = str(Path(file_path).parent)
        if not fileio.isdir(parent_dir):
            raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

    serialized = json.dumps(contents, cls=encoder)
    io_utils.write_file_contents_as_string(file_path, serialized)
write_yaml(file_path, contents, sort_keys=True)
Write contents as YAML format to file_path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path |
str |
Path to YAML file. |
required |
contents |
Union[Dict[Any, Any], List[Any]] |
Contents of YAML file as dict or list. |
required |
sort_keys |
bool |
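If `True`, keys are sorted alphabetically. If `False`, the order in which the keys were inserted into the dict will be preserved. |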
True |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
if directory does not exist. |
Source code in zenml/utils/yaml_utils.py
def write_yaml(
    file_path: str,
    contents: Union[Dict[Any, Any], List[Any]],
    sort_keys: bool = True,
) -> None:
    """Serialize `contents` as YAML and write it to `file_path`.

    Args:
        file_path: Path to YAML file.
        contents: Contents of YAML file as dict or list.
        sort_keys: If `True`, keys are sorted alphabetically. If `False`,
            the order in which the keys were inserted into the dict will
            be preserved.

    Raises:
        FileNotFoundError: if directory does not exist.
    """
    # The parent-directory check is only performed for local paths.
    if not io_utils.is_remote(file_path):
        parent_dir = str(Path(file_path).parent)
        if not fileio.isdir(parent_dir):
            raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

    serialized = yaml.dump(contents, sort_keys=sort_keys)
    io_utils.write_file_contents_as_string(file_path, serialized)