Materializers
zenml.materializers
special
Initialization of ZenML materializers.
Materializers are used to convert a ZenML artifact into a specific format. They
are most often used to handle the input or output of ZenML steps, and can be
extended by building on the `BaseMaterializer` class.
base_materializer
Metaclass implementation for registering ZenML BaseMaterializer subclasses.
BaseMaterializer
Base Materializer to realize artifact data.
Source code in zenml/materializers/base_materializer.py
class BaseMaterializer(metaclass=BaseMaterializerMeta):
    """Base Materializer to realize artifact data.

    Subclasses declare the Python types they can handle in `ASSOCIATED_TYPES`
    and override `load`/`save` (and optionally `save_visualizations` and
    `extract_metadata`) to read and write artifact data inside `self.uri`.
    """

    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.BASE
    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ()

    # `SKIP_REGISTRATION` can be set to True to not register the class in the
    # materializer registry. This is primarily useful for defining base classes.
    # Subclasses will automatically have this set to False unless they override
    # it themselves.
    SKIP_REGISTRATION: ClassVar[bool] = True

    # When True, the metaclass skips validation and registration entirely.
    _DOCS_BUILDING_MODE: ClassVar[bool] = False

    def __init__(self, uri: str):
        """Initializes a materializer with the given URI.

        Args:
            uri: The URI where the artifact data will be stored.
        """
        self.uri = uri

    # ================
    # Public Interface
    # ================

    def load(self, data_type: Type[Any]) -> Any:
        """Write logic here to load the data of an artifact.

        The base implementation loads nothing and returns `None`.

        Args:
            data_type: What type the artifact data should be loaded as.

        Returns:
            The data of the artifact.
        """
        # read from a location inside self.uri
        return None

    def save(self, data: Any) -> None:
        """Write logic here to save the data of an artifact.

        The base implementation does nothing.

        Args:
            data: The data of the artifact to save.
        """
        # write `data` into self.uri

    def save_visualizations(self, data: Any) -> Dict[str, VisualizationType]:
        """Save visualizations of the given data.

        If this method is not overridden, no visualizations will be saved.

        When overriding this method, make sure to save all visualizations to
        files within `self.uri`.

        Example:
            ```
            artifact_store = Client().active_stack.artifact_store
            visualization_uri = os.path.join(self.uri, "visualization.html")
            with artifact_store.open(visualization_uri, "w") as f:
                f.write("<html><body>data</body></html>")

            visualization_uri_2 = os.path.join(self.uri, "visualization.png")
            data.save_as_png(visualization_uri_2)

            return {
                visualization_uri: VisualizationType.HTML,
                visualization_uri_2: VisualizationType.IMAGE
            }
            ```

        Args:
            data: The data of the artifact to visualize.

        Returns:
            A dictionary of visualization URIs and their types.
        """
        # Optionally, save some visualizations of `data` inside `self.uri`.
        return {}

    def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given data.

        This metadata will be tracked and displayed alongside the artifact.
        The base implementation extracts nothing.

        Example:
            ```
            return {
                "some_attribute_i_want_to_track": self.some_attribute,
                "pi": 3.14,
            }
            ```

        Args:
            data: The data to extract metadata from.

        Returns:
            A dictionary of metadata.
        """
        # Optionally, extract some metadata from `data` for ZenML to store.
        return {}

    # ================
    # Internal Methods
    # ================

    def validate_type_compatibility(self, data_type: Type[Any]) -> None:
        """Checks whether the materializer can read/write the given type.

        Args:
            data_type: The type to check.

        Raises:
            TypeError: If the materializer cannot read/write the given type.
        """
        if not self.can_handle_type(data_type):
            raise TypeError(
                f"Unable to handle type {data_type}. {self.__class__.__name__} "
                f"can only read/write artifacts of the following types: "
                f"{self.ASSOCIATED_TYPES}."
            )

    @classmethod
    def can_handle_type(cls, data_type: Type[Any]) -> bool:
        """Whether the materializer can read/write a certain type.

        A type is handled if it is a subclass of any of `ASSOCIATED_TYPES`.

        Args:
            data_type: The type to check.

        Returns:
            Whether the materializer can read/write the given type.
        """
        return any(
            issubclass(data_type, associated_type)
            for associated_type in cls.ASSOCIATED_TYPES
        )

    def extract_full_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
        """Extract both base and custom metadata from the given data.

        Custom metadata from `extract_metadata` overrides base metadata on
        key collisions.

        Args:
            data: The data to extract metadata from.

        Returns:
            A dictionary of metadata.
        """
        base_metadata = self._extract_base_metadata(data)
        custom_metadata = self.extract_metadata(data)
        return {**base_metadata, **custom_metadata}

    def _extract_base_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given data.

        This metadata will be extracted for all artifacts in addition to the
        metadata extracted by the `extract_metadata` method. Currently this is
        only the storage size of `self.uri`.

        Args:
            data: The data to extract metadata from.

        Returns:
            A dictionary of metadata.
        """
        from zenml.metadata.metadata_types import StorageSize

        storage_size = fileio.size(self.uri)
        # `fileio.size` is not guaranteed to return an int here (hence the
        # check); only integer sizes are recorded.
        if isinstance(storage_size, int):
            return {"storage_size": StorageSize(storage_size)}
        return {}
__init__(self, uri)
special
Initializes a materializer with the given URI.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The URI where the artifact data will be stored. |
required |
Source code in zenml/materializers/base_materializer.py
def __init__(self, uri: str):
    """Remember where this materializer reads and writes its artifact.

    Args:
        uri: The URI where the artifact data will be stored.
    """
    # Subclasses build their file paths relative to this location.
    self.uri: str = uri
can_handle_type(data_type)
classmethod
Whether the materializer can read/write a certain type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type to check. |
required |
Returns:
Type | Description |
---|---|
bool |
Whether the materializer can read/write the given type. |
Source code in zenml/materializers/base_materializer.py
@classmethod
def can_handle_type(cls, data_type: Type[Any]) -> bool:
    """Report whether `data_type` falls under one of the associated types.

    Args:
        data_type: The type to check.

    Returns:
        True as soon as `data_type` is a subclass of any entry in
        `cls.ASSOCIATED_TYPES`, otherwise False.
    """
    for candidate in cls.ASSOCIATED_TYPES:
        if issubclass(data_type, candidate):
            return True
    return False
extract_full_metadata(self, data)
Extract both base and custom metadata from the given data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Any |
The data to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/materializers/base_materializer.py
def extract_full_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
    """Combine the always-collected base metadata with custom metadata.

    Custom entries win on key collisions.

    Args:
        data: The data to extract metadata from.

    Returns:
        A dictionary of metadata.
    """
    combined = dict(self._extract_base_metadata(data))
    combined.update(self.extract_metadata(data))
    return combined
extract_metadata(self, data)
Extract metadata from the given data.
This metadata will be tracked and displayed alongside the artifact.
Examples:
return {
"some_attribute_i_want_to_track": self.some_attribute,
"pi": 3.14,
}
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Any |
The data to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/materializers/base_materializer.py
def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given data.

    Whatever is returned here is tracked and displayed alongside the
    artifact. The base implementation tracks nothing.

    Example:
        ```
        return {
            "some_attribute_i_want_to_track": self.some_attribute,
            "pi": 3.14,
        }
        ```

    Args:
        data: The data to extract metadata from.

    Returns:
        A dictionary of metadata.
    """
    # Override to extract some metadata from `data` for ZenML to store.
    no_metadata: Dict[str, "MetadataType"] = dict()
    return no_metadata
load(self, data_type)
Write logic here to load the data of an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
What type the artifact data should be loaded as. |
required |
Returns:
Type | Description |
---|---|
Any |
The data of the artifact. |
Source code in zenml/materializers/base_materializer.py
def load(self, data_type: Type[Any]) -> Any:
    """Load the artifact data; subclasses provide the real implementation.

    Args:
        data_type: What type the artifact data should be loaded as.

    Returns:
        The data of the artifact (always `None` in this base version).
    """
    # Subclasses read from a location inside `self.uri`.
    loaded: Any = None
    return loaded
save(self, data)
Write logic here to save the data of an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Any |
The data of the artifact to save. |
required |
Source code in zenml/materializers/base_materializer.py
def save(self, data: Any) -> None:
    """Persist the artifact data; this base version is a no-op.

    Subclasses override it to write `data` into a location inside `self.uri`.

    Args:
        data: The data of the artifact to save.
    """
    return None
save_visualizations(self, data)
Save visualizations of the given data.
If this method is not overridden, no visualizations will be saved.
When overriding this method, make sure to save all visualizations to
files within `self.uri`.
Examples:
artifact_store = Client().active_stack.artifact_store
visualization_uri = os.path.join(self.uri, "visualization.html")
with artifact_store.open(visualization_uri, "w") as f:
f.write("<html><body>data</body></html>")
visualization_uri_2 = os.path.join(self.uri, "visualization.png")
data.save_as_png(visualization_uri_2)
return {
visualization_uri: VisualizationType.HTML,
visualization_uri_2: VisualizationType.IMAGE
}
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Any |
The data of the artifact to visualize. |
required |
Returns:
Type | Description |
---|---|
Dict[str, zenml.enums.VisualizationType] |
A dictionary of visualization URIs and their types. |
Source code in zenml/materializers/base_materializer.py
def save_visualizations(self, data: Any) -> Dict[str, VisualizationType]:
    """Save visualizations of the given data.

    If this method is not overridden, no visualizations will be saved.

    When overriding this method, make sure to save all visualizations to
    files within `self.uri`.

    Example:
        ```
        artifact_store = Client().active_stack.artifact_store
        visualization_uri = os.path.join(self.uri, "visualization.html")
        with artifact_store.open(visualization_uri, "w") as f:
            f.write("<html><body>data</body></html>")

        visualization_uri_2 = os.path.join(self.uri, "visualization.png")
        data.save_as_png(visualization_uri_2)

        return {
            visualization_uri: VisualizationType.HTML,
            visualization_uri_2: VisualizationType.IMAGE
        }
        ```

    Args:
        data: The data of the artifact to visualize.

    Returns:
        A dictionary of visualization URIs and their types.
    """
    # Optionally, save some visualizations of `data` inside `self.uri`.
    return {}
validate_type_compatibility(self, data_type)
Checks whether the materializer can read/write the given type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type to check. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If the materializer cannot read/write the given type. |
Source code in zenml/materializers/base_materializer.py
def validate_type_compatibility(self, data_type: Type[Any]) -> None:
    """Ensure this materializer supports `data_type`, raising otherwise.

    Args:
        data_type: The type to check.

    Raises:
        TypeError: If `can_handle_type` rejects `data_type`.
    """
    if self.can_handle_type(data_type):
        return
    materializer_name = self.__class__.__name__
    supported = self.ASSOCIATED_TYPES
    raise TypeError(
        f"Unable to handle type {data_type}. {materializer_name} "
        f"can only read/write artifacts of the following types: "
        f"{supported}."
    )
BaseMaterializerMeta (type)
Metaclass responsible for registering different BaseMaterializer subclasses.
Materializers are used for reading/writing artifacts.
Source code in zenml/materializers/base_materializer.py
class BaseMaterializerMeta(type):
    """Metaclass responsible for registering different BaseMaterializer subclasses.

    Materializers are used for reading/writing artifacts.
    """

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "BaseMaterializerMeta":
        """Creates a Materializer class and registers it at the `MaterializerRegistry`.

        Unless docs-building mode is enabled, a class with `SKIP_REGISTRATION`
        set is returned unregistered, and the flag is reset to False so that
        its subclasses do not inherit it. Every other class is validated and
        then registered for each of its `ASSOCIATED_TYPES`.

        Args:
            name: The name of the class.
            bases: The base classes of the class.
            dct: The dictionary of the class.

        Returns:
            The newly created materializer class.

        Raises:
            MaterializerInterfaceError: If the class was improperly defined.
        """
        cls = cast(
            Type["BaseMaterializer"], super().__new__(mcs, name, bases, dct)
        )
        if cls._DOCS_BUILDING_MODE:
            return cls

        docs_url = "https://docs.zenml.io/user-guide/advanced-guide/artifact-management/handle-custom-data-types"

        # Skip the following validation and registration for base classes.
        if cls.SKIP_REGISTRATION:
            # Reset the flag so subclasses don't have it set automatically.
            cls.SKIP_REGISTRATION = False
            return cls

        # Validate that the class is properly defined.
        if not cls.ASSOCIATED_TYPES:
            raise MaterializerInterfaceError(
                f"Invalid materializer class '{name}'. When creating a "
                f"custom materializer, make sure to specify at least one "
                f"type in its ASSOCIATED_TYPES class variable.",
                url=docs_url,
            )

        # Validate associated artifact type (and normalize it to the enum).
        if cls.ASSOCIATED_ARTIFACT_TYPE:
            try:
                cls.ASSOCIATED_ARTIFACT_TYPE = ArtifactType(
                    cls.ASSOCIATED_ARTIFACT_TYPE
                )
            except ValueError as e:
                raise MaterializerInterfaceError(
                    f"Invalid materializer class '{name}'. When creating a "
                    f"custom materializer, make sure to specify a valid "
                    f"artifact type in its ASSOCIATED_ARTIFACT_TYPE class "
                    f"variable.",
                    url=docs_url,
                ) from e

        # Validate all associated data types before registering any of them,
        # so a bad entry never leaves a partial registration behind.
        for associated_type in cls.ASSOCIATED_TYPES:
            if not inspect.isclass(associated_type):
                raise MaterializerInterfaceError(
                    f"Associated type {associated_type} for materializer "
                    f"{name} is not a class.",
                    url=docs_url,
                )

        # Register the materializer.
        for associated_type in cls.ASSOCIATED_TYPES:
            materializer_registry.register_materializer_type(
                associated_type, cls
            )
        return cls
__new__(mcs, name, bases, dct)
special
staticmethod
Creates a Materializer class and registers it at the `MaterializerRegistry`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the class. |
required |
bases |
Tuple[Type[Any], ...] |
The base classes of the class. |
required |
dct |
Dict[str, Any] |
The dictionary of the class. |
required |
Returns:
Type | Description |
---|---|
BaseMaterializerMeta |
The newly created materializer class. |
Exceptions:
Type | Description |
---|---|
MaterializerInterfaceError |
If the class was improperly defined. |
Source code in zenml/materializers/base_materializer.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseMaterializerMeta":
    """Creates a Materializer class and registers it at the `MaterializerRegistry`.

    Unless docs-building mode is enabled, a class with `SKIP_REGISTRATION`
    set is returned unregistered, and the flag is reset to False so that its
    subclasses do not inherit it. Every other class is validated and then
    registered for each of its `ASSOCIATED_TYPES`.

    Args:
        name: The name of the class.
        bases: The base classes of the class.
        dct: The dictionary of the class.

    Returns:
        The newly created materializer class.

    Raises:
        MaterializerInterfaceError: If the class was improperly defined.
    """
    cls = cast(
        Type["BaseMaterializer"], super().__new__(mcs, name, bases, dct)
    )
    if cls._DOCS_BUILDING_MODE:
        return cls

    docs_url = "https://docs.zenml.io/user-guide/advanced-guide/artifact-management/handle-custom-data-types"

    # Skip the following validation and registration for base classes.
    if cls.SKIP_REGISTRATION:
        # Reset the flag so subclasses don't have it set automatically.
        cls.SKIP_REGISTRATION = False
        return cls

    # Validate that the class is properly defined.
    if not cls.ASSOCIATED_TYPES:
        raise MaterializerInterfaceError(
            f"Invalid materializer class '{name}'. When creating a "
            f"custom materializer, make sure to specify at least one "
            f"type in its ASSOCIATED_TYPES class variable.",
            url=docs_url,
        )

    # Validate associated artifact type (and normalize it to the enum).
    if cls.ASSOCIATED_ARTIFACT_TYPE:
        try:
            cls.ASSOCIATED_ARTIFACT_TYPE = ArtifactType(
                cls.ASSOCIATED_ARTIFACT_TYPE
            )
        except ValueError as e:
            raise MaterializerInterfaceError(
                f"Invalid materializer class '{name}'. When creating a "
                f"custom materializer, make sure to specify a valid "
                f"artifact type in its ASSOCIATED_ARTIFACT_TYPE class "
                f"variable.",
                url=docs_url,
            ) from e

    # Validate all associated data types before registering any of them,
    # so a bad entry never leaves a partial registration behind.
    for associated_type in cls.ASSOCIATED_TYPES:
        if not inspect.isclass(associated_type):
            raise MaterializerInterfaceError(
                f"Associated type {associated_type} for materializer "
                f"{name} is not a class.",
                url=docs_url,
            )

    # Register the materializer.
    for associated_type in cls.ASSOCIATED_TYPES:
        materializer_registry.register_materializer_type(
            associated_type, cls
        )
    return cls
built_in_materializer
Implementation of ZenML's builtin materializer.
BuiltInContainerMaterializer (BaseMaterializer)
Handle built-in container types (dict, list, set, tuple).
Source code in zenml/materializers/built_in_materializer.py
class BuiltInContainerMaterializer(BaseMaterializer):
    """Handle built-in container types (dict, list, set, tuple)."""

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (
        dict,
        list,
        set,
        tuple,
    )

    def __init__(self, uri: str):
        """Define `self.data_path` and `self.metadata_path`.

        Args:
            uri: The URI where the artifact data is stored.
        """
        super().__init__(uri)
        self.data_path = os.path.join(self.uri, DEFAULT_FILENAME)
        self.metadata_path = os.path.join(self.uri, DEFAULT_METADATA_FILENAME)

    def load(self, data_type: Type[Any]) -> Any:
        """Reads a materialized built-in container object.

        If the data was serialized to JSON, deserialize it.
        Otherwise, reconstruct all elements according to the metadata file:
            1. Resolve the data type using `find_type_by_str()` (legacy
               format) or `source_utils.load()` (current format),
            2. Get the materializer via the `materializer_registry` (legacy
               format) or load the recorded materializer class (current
               format),
            3. Initialize the materializer with the desired path,
            4. Use `load()` of that materializer to load the element.

        Args:
            data_type: The type of the data to read.

        Returns:
            The data read.

        Raises:
            RuntimeError: If the data was not found, or if the metadata file
                has an unknown format.
        """
        artifact_store = Client().active_stack.artifact_store

        # Check the data path only once (each check may hit a remote store).
        data_exists = artifact_store.exists(self.data_path)

        # If the data was not serialized, there must be metadata present.
        if not data_exists and not artifact_store.exists(self.metadata_path):
            raise RuntimeError(
                f"Materialization of type {data_type} failed. Expected either "
                f"{self.data_path} or {self.metadata_path} to exist."
            )

        if data_exists:
            # The data was serialized as JSON: deserialize it directly.
            outputs = yaml_utils.read_json(self.data_path)
        else:
            # Otherwise, use the metadata to reconstruct the data as a list.
            metadata = yaml_utils.read_json(self.metadata_path)
            outputs = []
            # Backwards compatibility for zenml <= 0.37.0
            if isinstance(metadata, dict):
                for path_, type_str in zip(
                    metadata["paths"], metadata["types"]
                ):
                    type_ = find_type_by_str(type_str)
                    materializer_class = materializer_registry[type_]
                    materializer = materializer_class(uri=path_)
                    outputs.append(materializer.load(type_))
            # New format for zenml > 0.37.0
            elif isinstance(metadata, list):
                for entry in metadata:
                    path_ = entry["path"]
                    type_ = source_utils.load(entry["type"])
                    materializer_class = source_utils.load(
                        entry["materializer"]
                    )
                    materializer = materializer_class(uri=path_)
                    outputs.append(materializer.load(type_))
            else:
                raise RuntimeError(f"Unknown metadata format: {metadata}.")

        # Cast the data to the correct type. Non-serializable dicts are saved
        # as a `[keys, values]` pair of lists, so rebuild them here.
        if issubclass(data_type, dict) and not isinstance(outputs, dict):
            keys, values = outputs
            return dict(zip(keys, values))
        if issubclass(data_type, tuple) and not isinstance(outputs, tuple):
            return tuple(outputs)
        if issubclass(data_type, set) and not isinstance(outputs, set):
            return set(outputs)
        return outputs

    def save(self, data: Any) -> None:
        """Materialize a built-in container object.

        If the object can be serialized to JSON, serialize it.
        Otherwise, use the `materializer_registry` to find the correct
        materializer for each element and materialize each element into a
        subdirectory.

        Tuples and sets are cast to list before materialization.

        For non-serializable dicts, materialize keys/values as separate lists.

        Args:
            data: The built-in container object to materialize.

        Raises:
            Exception: If any exception occurs, it is raised after cleanup.
        """
        artifact_store = Client().active_stack.artifact_store

        # tuple and set: handle as list.
        if isinstance(data, (tuple, set)):
            data = list(data)

        # If the data is serializable, just write it into a single JSON file.
        if _is_serializable(data):
            yaml_utils.write_json(self.data_path, data)
            return

        # non-serializable dict: Handle as non-serializable list of lists.
        if isinstance(data, dict):
            data = [list(data.keys()), list(data.values())]

        # non-serializable list: Materialize each element into a subfolder.
        # Get path, type, and corresponding materializer for each element.
        metadata: List[Dict[str, str]] = []
        materializers: List[BaseMaterializer] = []
        # Record each element directory right after it is created, so that a
        # failure before its metadata entry is appended still cleans it up.
        created_paths: List[str] = []
        try:
            for i, element in enumerate(data):
                element_path = os.path.join(self.uri, str(i))
                artifact_store.mkdir(element_path)
                created_paths.append(element_path)
                type_ = type(element)
                materializer_class = materializer_registry[type_]
                materializer = materializer_class(uri=element_path)
                materializers.append(materializer)
                metadata.append(
                    {
                        "path": element_path,
                        "type": source_utils.resolve(type_).import_path,
                        "materializer": source_utils.resolve(
                            materializer_class
                        ).import_path,
                    }
                )

            # Write metadata as JSON.
            yaml_utils.write_json(self.metadata_path, metadata)

            # Materialize each element.
            for element, materializer in zip(data, materializers):
                materializer.validate_type_compatibility(type(element))
                materializer.save(element)

        # If an error occurs, delete all created files.
        except Exception:
            # Delete metadata
            if artifact_store.exists(self.metadata_path):
                artifact_store.remove(self.metadata_path)
            # Delete every element directory that was created.
            for path in created_paths:
                artifact_store.rmtree(path)
            raise

    def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given built-in container object.

        Args:
            data: The built-in container object to extract metadata from.

        Returns:
            The extracted metadata as a dictionary (the container length, if
            the object has one).
        """
        if hasattr(data, "__len__"):
            return {"length": len(data)}
        return {}
__init__(self, uri)
special
Define `self.data_path` and `self.metadata_path`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The URI where the artifact data is stored. |
required |
Source code in zenml/materializers/built_in_materializer.py
def __init__(self, uri: str):
    """Set up the JSON data file and the metadata file locations.

    Args:
        uri: The URI where the artifact data is stored.
    """
    super().__init__(uri)
    root = self.uri
    self.data_path, self.metadata_path = (
        os.path.join(root, file_name)
        for file_name in (DEFAULT_FILENAME, DEFAULT_METADATA_FILENAME)
    )
extract_metadata(self, data)
Extract metadata from the given built-in container object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Any |
The built-in container object to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The extracted metadata as a dictionary. |
Source code in zenml/materializers/built_in_materializer.py
def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
    """Report the length of a built-in container object.

    Args:
        data: The built-in container object to extract metadata from.

    Returns:
        `{"length": len(data)}` when `data` has a `__len__`, else `{}`.
    """
    if not hasattr(data, "__len__"):
        return {}
    return {"length": len(data)}
load(self, data_type)
Reads a materialized built-in container object.
If the data was serialized to JSON, deserialize it.
Otherwise, reconstruct all elements according to the metadata file:
1. Resolve the data type using `find_type_by_str()`,
2. Get the materializer via the `materializer_registry`,
3. Initialize the materializer with the desired path,
4. Use `load()` of that materializer to load the element.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Any |
The data read. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the data was not found or the metadata format is unknown. |
Source code in zenml/materializers/built_in_materializer.py
def load(self, data_type: Type[Any]) -> Any:
    """Reads a materialized built-in container object.

    If the data was serialized to JSON, deserialize it.
    Otherwise, reconstruct all elements according to the metadata file:
        1. Resolve the data type using `find_type_by_str()` (legacy format)
           or `source_utils.load()` (current format),
        2. Get the materializer via the `materializer_registry` (legacy
           format) or load the recorded materializer class (current format),
        3. Initialize the materializer with the desired path,
        4. Use `load()` of that materializer to load the element.

    Args:
        data_type: The type of the data to read.

    Returns:
        The data read.

    Raises:
        RuntimeError: If the data was not found, or if the metadata file has
            an unknown format.
    """
    artifact_store = Client().active_stack.artifact_store

    # Check the data path only once (each check may hit a remote store).
    data_exists = artifact_store.exists(self.data_path)

    # If the data was not serialized, there must be metadata present.
    if not data_exists and not artifact_store.exists(self.metadata_path):
        raise RuntimeError(
            f"Materialization of type {data_type} failed. Expected either "
            f"{self.data_path} or {self.metadata_path} to exist."
        )

    if data_exists:
        # The data was serialized as JSON: deserialize it directly.
        outputs = yaml_utils.read_json(self.data_path)
    else:
        # Otherwise, use the metadata to reconstruct the data as a list.
        metadata = yaml_utils.read_json(self.metadata_path)
        outputs = []
        # Backwards compatibility for zenml <= 0.37.0
        if isinstance(metadata, dict):
            for path_, type_str in zip(metadata["paths"], metadata["types"]):
                type_ = find_type_by_str(type_str)
                materializer_class = materializer_registry[type_]
                materializer = materializer_class(uri=path_)
                outputs.append(materializer.load(type_))
        # New format for zenml > 0.37.0
        elif isinstance(metadata, list):
            for entry in metadata:
                path_ = entry["path"]
                type_ = source_utils.load(entry["type"])
                materializer_class = source_utils.load(entry["materializer"])
                materializer = materializer_class(uri=path_)
                outputs.append(materializer.load(type_))
        else:
            raise RuntimeError(f"Unknown metadata format: {metadata}.")

    # Cast the data to the correct type. Non-serializable dicts are saved as
    # a `[keys, values]` pair of lists, so rebuild them here.
    if issubclass(data_type, dict) and not isinstance(outputs, dict):
        keys, values = outputs
        return dict(zip(keys, values))
    if issubclass(data_type, tuple) and not isinstance(outputs, tuple):
        return tuple(outputs)
    if issubclass(data_type, set) and not isinstance(outputs, set):
        return set(outputs)
    return outputs
save(self, data)
Materialize a built-in container object.
If the object can be serialized to JSON, serialize it.
Otherwise, use the `materializer_registry` to find the correct
materializer for each element and materialize each element into a
subdirectory.
Tuples and sets are cast to list before materialization.
For non-serializable dicts, materialize keys/values as separate lists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Any |
The built-in container object to materialize. |
required |
Exceptions:
Type | Description |
---|---|
Exception |
If any exception occurs, it is raised after cleanup. |
Source code in zenml/materializers/built_in_materializer.py
def save(self, data: Any) -> None:
    """Materialize a built-in container object.

    If the object can be serialized to JSON, serialize it.
    Otherwise, use the `materializer_registry` to find the correct
    materializer for each element and materialize each element into a
    subdirectory.

    Tuples and sets are cast to list before materialization.

    For non-serializable dicts, materialize keys/values as separate lists.

    Args:
        data: The built-in container object to materialize.

    Raises:
        Exception: If any exception occurs, it is raised after cleanup.
    """
    artifact_store = Client().active_stack.artifact_store

    # tuple and set: handle as list.
    if isinstance(data, (tuple, set)):
        data = list(data)

    # If the data is serializable, just write it into a single JSON file.
    if _is_serializable(data):
        yaml_utils.write_json(self.data_path, data)
        return

    # non-serializable dict: Handle as non-serializable list of lists.
    if isinstance(data, dict):
        data = [list(data.keys()), list(data.values())]

    # non-serializable list: Materialize each element into a subfolder.
    # Get path, type, and corresponding materializer for each element.
    metadata: List[Dict[str, str]] = []
    materializers: List[BaseMaterializer] = []
    # Record each element directory right after it is created, so that a
    # failure before its metadata entry is appended still cleans it up.
    created_paths: List[str] = []
    try:
        for i, element in enumerate(data):
            element_path = os.path.join(self.uri, str(i))
            artifact_store.mkdir(element_path)
            created_paths.append(element_path)
            type_ = type(element)
            materializer_class = materializer_registry[type_]
            materializer = materializer_class(uri=element_path)
            materializers.append(materializer)
            metadata.append(
                {
                    "path": element_path,
                    "type": source_utils.resolve(type_).import_path,
                    "materializer": source_utils.resolve(
                        materializer_class
                    ).import_path,
                }
            )

        # Write metadata as JSON.
        yaml_utils.write_json(self.metadata_path, metadata)

        # Materialize each element.
        for element, materializer in zip(data, materializers):
            materializer.validate_type_compatibility(type(element))
            materializer.save(element)

    # If an error occurs, delete all created files.
    except Exception:
        # Delete metadata
        if artifact_store.exists(self.metadata_path):
            artifact_store.remove(self.metadata_path)
        # Delete every element directory that was created.
        for path in created_paths:
            artifact_store.rmtree(path)
        raise
BuiltInMaterializer (BaseMaterializer)
Handle JSON-serializable basic types (`bool`, `float`, `int`, `str`).
Source code in zenml/materializers/built_in_materializer.py
class BuiltInMaterializer(BaseMaterializer):
    """Handle JSON-serializable basic types (`bool`, `float`, `int`, `str`)."""

    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = BASIC_TYPES

    def __init__(self, uri: str):
        """Define `self.data_path`.

        Args:
            uri: The URI where the artifact data is stored.
        """
        super().__init__(uri)
        self.data_path = os.path.join(self.uri, DEFAULT_FILENAME)

    def load(
        self, data_type: Union[Type[bool], Type[float], Type[int], Type[str]]
    ) -> Any:
        """Reads basic primitive types from JSON.

        A mismatch between the stored type and `data_type` is only logged at
        debug level; the contents are returned as read.

        Args:
            data_type: The type of the data to read.

        Returns:
            The data read.
        """
        contents = yaml_utils.read_json(self.data_path)
        if type(contents) is not data_type:
            # TODO [ENG-142]: Raise error or try to coerce
            # Lazy %-args: the (possibly long) contents are only formatted
            # when debug logging is enabled.
            logger.debug(
                "Contents %s was type %s but expected %s",
                contents,
                type(contents),
                data_type,
            )
        return contents

    def save(self, data: Union[bool, float, int, str]) -> None:
        """Serialize a basic type to JSON.

        Args:
            data: The data to store.
        """
        yaml_utils.write_json(self.data_path, data)

    def extract_metadata(
        self, data: Union[bool, float, int, str]
    ) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given basic-type value.

        Args:
            data: The basic-type value to extract metadata from.

        Returns:
            The extracted metadata as a dictionary.
        """
        # For boolean and numbers, add the string representation as metadata.
        # We don't do this for strings because they can be arbitrarily long.
        if isinstance(data, (bool, float, int)):
            return {"string_representation": str(data)}
        return {}
__init__(self, uri)
special
Define `self.data_path`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The URI where the artifact data is stored. |
required |
Source code in zenml/materializers/built_in_materializer.py
def __init__(self, uri: str):
    """Set up the path of the JSON file that holds the serialized value.

    Args:
        uri: The URI where the artifact data is stored.
    """
    super().__init__(uri)
    data_file = DEFAULT_FILENAME
    self.data_path = os.path.join(self.uri, data_file)
extract_metadata(self, data)
Extract metadata from the given basic-type value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Union[bool, float, int, str] |
The basic-type value to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The extracted metadata as a dictionary. |
Source code in zenml/materializers/built_in_materializer.py
def extract_metadata(
    self, data: Union[bool, float, int, str]
) -> Dict[str, "MetadataType"]:
    """Extract metadata from a basic primitive value.

    Args:
        data: The primitive value to describe.

    Returns:
        `{"string_representation": str(data)}` for bools and numbers, and
        an empty dictionary for anything else.
    """
    # Strings are deliberately skipped: they can be arbitrarily long, so
    # copying them verbatim into metadata would be wasteful.
    is_scalar_number = isinstance(data, (bool, float, int))
    if not is_scalar_number:
        return {}
    return {"string_representation": str(data)}
load(self, data_type)
Reads basic primitive types from JSON.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Union[Type[bool], Type[float], Type[int], Type[str]] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Any |
The data read. |
Source code in zenml/materializers/built_in_materializer.py
def load(
    self, data_type: Union[Type[bool], Type[float], Type[int], Type[str]]
) -> Any:
    """Read a basic primitive value back from its JSON file.

    Args:
        data_type: The type the stored value is expected to have.

    Returns:
        The value parsed from `self.data_path`. A type mismatch is only
        logged at debug level; the parsed value is returned unchanged.
    """
    loaded = yaml_utils.read_json(self.data_path)
    actual_type = type(loaded)
    if actual_type == data_type:
        return loaded
    # TODO [ENG-142]: Raise error or try to coerce
    logger.debug(
        f"Contents {loaded} was type {actual_type} but expected "
        f"{data_type}"
    )
    return loaded
save(self, data)
Serialize a basic type to JSON.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Union[bool, float, int, str] |
The data to store. |
required |
Source code in zenml/materializers/built_in_materializer.py
def save(self, data: Union[bool, float, int, str]) -> None:
    """Write a basic primitive value to `self.data_path` as JSON.

    Args:
        data: The bool, float, int or str value to persist.
    """
    target_path = self.data_path
    yaml_utils.write_json(target_path, data)
BytesMaterializer (BaseMaterializer)
Handle `bytes` data type, which is not JSON serializable.
Source code in zenml/materializers/built_in_materializer.py
class BytesMaterializer(BaseMaterializer):
    """Materializer for raw `bytes` objects, which cannot be stored as JSON."""

    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (bytes,)

    def __init__(self, uri: str):
        """Initialize the materializer and derive the binary file path.

        Args:
            uri: The URI where the artifact data is stored.
        """
        super().__init__(uri)
        base_uri = self.uri
        self.data_path = os.path.join(base_uri, DEFAULT_BYTES_FILENAME)

    def load(self, data_type: Type[Any]) -> Any:
        """Read the stored bytes back from the artifact store.

        Args:
            data_type: The type of the data to read (not used).

        Returns:
            The raw bytes contained in `self.data_path`.
        """
        store = Client().active_stack.artifact_store
        with store.open(self.data_path, "rb") as handle:
            payload = handle.read()
        return payload

    def save(self, data: Any) -> None:
        """Write a bytes object to the artifact store.

        Args:
            data: The bytes to store.
        """
        store = Client().active_stack.artifact_store
        with store.open(self.data_path, "wb") as handle:
            handle.write(data)
__init__(self, uri)
special
Define `self.data_path`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The URI where the artifact data is stored. |
required |
Source code in zenml/materializers/built_in_materializer.py
def __init__(self, uri: str):
    """Initialize the materializer and derive the binary file path.

    Args:
        uri: The URI where the artifact data is stored.
    """
    super().__init__(uri)
    base_uri = self.uri
    self.data_path = os.path.join(base_uri, DEFAULT_BYTES_FILENAME)
load(self, data_type)
Reads a bytes object from file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Any |
The data read. |
Source code in zenml/materializers/built_in_materializer.py
def load(self, data_type: Type[Any]) -> Any:
    """Read the stored bytes back from the artifact store.

    Args:
        data_type: The type of the data to read (not used).

    Returns:
        The raw bytes contained in `self.data_path`.
    """
    store = Client().active_stack.artifact_store
    with store.open(self.data_path, "rb") as handle:
        payload = handle.read()
    return payload
save(self, data)
Save a bytes object to file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Any |
The data to store. |
required |
Source code in zenml/materializers/built_in_materializer.py
def save(self, data: Any) -> None:
    """Write a bytes object to the artifact store.

    Args:
        data: The bytes to store.
    """
    store = Client().active_stack.artifact_store
    with store.open(self.data_path, "wb") as handle:
        handle.write(data)
find_materializer_registry_type(type_)
For a given type, find the type registered in the registry.
This can be either the type itself, or a superclass of the type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
type_ |
Type[Any] |
The type to find. |
required |
Returns:
Type | Description |
---|---|
Type[Any] |
The type registered in the registry. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the type could not be resolved. |
Source code in zenml/materializers/built_in_materializer.py
def find_materializer_registry_type(type_: Type[Any]) -> Type[Any]:
    """For a given type, find the type registered in the registry.

    This can be either the type itself, or a superclass of the type. When
    several superclasses are registered, the one that comes first in the
    method resolution order of `type_` wins. This matches the lookup done by
    `materializer_registry[type_]`.

    Args:
        type_: The type to find.

    Returns:
        The type registered in the registry.

    Raises:
        RuntimeError: If the type could not be resolved.
    """
    registered_types = materializer_registry.materializer_types

    # Walk the MRO (`type_` itself comes first) so the most specific
    # registered ancestor is chosen, consistent with
    # `MaterializerRegistry.__getitem__`. Iterating the registry dict instead
    # would return whichever ancestor happened to be registered first.
    for candidate in type_.__mro__:
        if candidate in registered_types:
            return candidate

    # Fallback for virtual subclasses (e.g. ABCs with registered subclasses),
    # which `issubclass` recognizes but which do not appear in `__mro__`.
    for registered_type in registered_types:
        if issubclass(type_, registered_type):
            return registered_type

    # The registry falls back to a default materializer for unregistered
    # types instead of raising, so reaching this point is possible.
    raise RuntimeError(
        f"Cannot find a materializer for type '{type_}' in the "
        f"materializer registry."
    )
find_type_by_str(type_str)
Get a Python type, given its string representation.
E.g., "<class 'int'>" should resolve to `int`.
Currently this is implemented by checking all artifact types registered in
the `materializer_registry`. This means only types in the registry
can be found. Any other types will cause a `RuntimeError`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
type_str |
str |
The string representation of a type. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the type could not be resolved. |
Returns:
Type | Description |
---|---|
Type[Any] |
The type whose string representation is `type_str`. |
Source code in zenml/materializers/built_in_materializer.py
def find_type_by_str(type_str: str) -> Type[Any]:
    """Resolve a registered Python type from its string representation.

    For example, `"<class 'int'>"` resolves to `int`. Only types that are
    keys of `materializer_registry.materializer_types` can be resolved.

    Args:
        type_str: The string representation of a type.

    Raises:
        RuntimeError: If no registered type has that string representation.

    Returns:
        The registered type whose `str()` equals `type_str`.
    """
    matches = [
        candidate
        for candidate in materializer_registry.materializer_types.keys()
        if str(candidate) == type_str
    ]
    if not matches:
        raise RuntimeError(f"Cannot resolve type '{type_str}'.")
    # Several registered types may share a string representation; the last
    # registered one wins, as it would in a dict keyed by `str(type)`.
    return matches[-1]
cloudpickle_materializer
Implementation of ZenML's cloudpickle materializer.
CloudpickleMaterializer (BaseMaterializer)
Materializer using cloudpickle.
This materializer can materialize (almost) any object, but does so in a non-reproducible way since artifacts cannot be loaded from other Python versions. It is recommended to use this materializer only as a last resort.
That is also why it has `SKIP_REGISTRATION` set to True and is currently
only used as a fallback materializer inside the materializer registry.
Source code in zenml/materializers/cloudpickle_materializer.py
class CloudpickleMaterializer(BaseMaterializer):
    """Materializer using cloudpickle.

    This materializer can materialize (almost) any object, but does so in a
    non-reproducible way since artifacts cannot be loaded from other Python
    versions. It is recommended to use this materializer only as a last resort.

    That is also why it has `SKIP_REGISTRATION` set to True and is currently
    only used as a fallback materializer inside the materializer registry.
    """

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (object,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
    SKIP_REGISTRATION: ClassVar[bool] = True

    def load(self, data_type: Type[Any]) -> Any:
        """Reads an artifact from a cloudpickle file.

        Args:
            data_type: The data type of the artifact.

        Returns:
            The loaded artifact data.
        """
        # validate python version
        artifact_store = Client().active_stack.artifact_store
        source_python_version = self._load_python_version()
        current_python_version = Environment().python_version()
        if source_python_version != current_python_version:
            logger.warning(
                f"Your artifact was materialized under Python version "
                f"'{source_python_version}' but you are currently using "
                f"'{current_python_version}'. This might cause unexpected "
                "behavior since pickle is not reproducible across Python "
                "versions. Attempting to load anyway..."
            )

        # load data
        filepath = os.path.join(self.uri, DEFAULT_FILENAME)
        # NOTE: unpickling can execute arbitrary code; only load artifacts
        # from artifact stores you trust.
        with artifact_store.open(filepath, "rb") as fid:
            data = cloudpickle.load(fid)
        return data

    def _load_python_version(self) -> str:
        """Loads the Python version that was used to materialize the artifact.

        Returns:
            The Python version that was used to materialize the artifact, or
            "unknown" if no version file exists.
        """
        filepath = os.path.join(self.uri, DEFAULT_PYTHON_VERSION_FILENAME)
        # `self.uri` may point into a remote artifact store (the data file
        # itself is written via `artifact_store.open`). `os.path.exists` only
        # checks the local filesystem and would report "unknown" for every
        # remote artifact, so ask the artifact store instead.
        artifact_store = Client().active_stack.artifact_store
        if artifact_store.exists(filepath):
            return read_file_contents_as_string(filepath)
        return "unknown"

    def save(self, data: Any) -> None:
        """Saves an artifact to a cloudpickle file.

        Args:
            data: The data to save.
        """
        artifact_store = Client().active_stack.artifact_store

        # Log a warning if this materializer was not explicitly specified for
        # the given data type, i.e. it is being used as the implicit fallback.
        if type(self) is CloudpickleMaterializer:
            logger.warning(
                f"No materializer is registered for type `{type(data)}`, so "
                "the default Pickle materializer was used. Pickle is not "
                "production ready and should only be used for prototyping as "
                "the artifacts cannot be loaded when running with a different "
                "Python version. Please consider implementing a custom "
                f"materializer for type `{type(data)}` according to the "
                "instructions at https://docs.zenml.io/user-guide/advanced-guide/artifact-management/handle-custom-data-types"
            )

        # save python version for validation on loading
        self._save_python_version()

        # save data
        filepath = os.path.join(self.uri, DEFAULT_FILENAME)
        with artifact_store.open(filepath, "wb") as fid:
            cloudpickle.dump(data, fid)

    def _save_python_version(self) -> None:
        """Saves the Python version used to materialize the artifact."""
        filepath = os.path.join(self.uri, DEFAULT_PYTHON_VERSION_FILENAME)
        current_python_version = Environment().python_version()
        write_file_contents_as_string(filepath, current_python_version)
load(self, data_type)
Reads an artifact from a cloudpickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The data type of the artifact. |
required |
Returns:
Type | Description |
---|---|
Any |
The loaded artifact data. |
Source code in zenml/materializers/cloudpickle_materializer.py
def load(self, data_type: Type[Any]) -> Any:
    """Deserialize an artifact that was written with cloudpickle.

    Args:
        data_type: The data type of the artifact (not used).

    Returns:
        The unpickled artifact data.
    """
    artifact_store = Client().active_stack.artifact_store
    # Pickles are not portable across interpreter versions, so compare the
    # version recorded at save time with the running one and warn on mismatch.
    stored_version = self._load_python_version()
    running_version = Environment().python_version()
    if stored_version != running_version:
        logger.warning(
            f"Your artifact was materialized under Python version "
            f"'{stored_version}' but you are currently using "
            f"'{running_version}'. This might cause unexpected "
            "behavior since pickle is not reproducible across Python "
            "versions. Attempting to load anyway..."
        )

    data_file = os.path.join(self.uri, DEFAULT_FILENAME)
    # NOTE: unpickling can execute arbitrary code; only load artifacts from
    # artifact stores you trust.
    with artifact_store.open(data_file, "rb") as stream:
        return cloudpickle.load(stream)
save(self, data)
Saves an artifact to a cloudpickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Any |
The data to save. |
required |
Source code in zenml/materializers/cloudpickle_materializer.py
def save(self, data: Any) -> None:
    """Serialize an artifact with cloudpickle and record the Python version.

    Args:
        data: The data to save.
    """
    artifact_store = Client().active_stack.artifact_store

    # Only warn when this exact class is used as the implicit fallback;
    # subclasses that reuse it deliberately stay quiet.
    if type(self) is CloudpickleMaterializer:
        data_type = type(data)
        logger.warning(
            f"No materializer is registered for type `{data_type}`, so "
            "the default Pickle materializer was used. Pickle is not "
            "production ready and should only be used for prototyping as "
            "the artifacts cannot be loaded when running with a different "
            "Python version. Please consider implementing a custom "
            f"materializer for type `{data_type}` according to the "
            "instructions at https://docs.zenml.io/user-guide/advanced-guide/artifact-management/handle-custom-data-types"
        )

    # The version file lets `load` warn about cross-version unpickling.
    self._save_python_version()

    data_file = os.path.join(self.uri, DEFAULT_FILENAME)
    with artifact_store.open(data_file, "wb") as stream:
        cloudpickle.dump(data, stream)
materializer_registry
Implementation of a default materializer registry.
MaterializerRegistry
Matches a Python type to a default materializer.
Source code in zenml/materializers/materializer_registry.py
class MaterializerRegistry:
    """Matches a Python type to a default materializer."""

    def __init__(self) -> None:
        """Initialize the materializer registry."""
        # Optional override for the fallback used when no type matches.
        self.default_materializer: Optional[Type["BaseMaterializer"]] = None
        # Maps a Python type to the materializer class registered for it.
        self.materializer_types: Dict[Type[Any], Type["BaseMaterializer"]] = {}

    def register_materializer_type(
        self, key: Type[Any], type_: Type["BaseMaterializer"]
    ) -> None:
        """Registers a new materializer unless one already exists for `key`.

        Args:
            key: Indicates the type of object.
            type_: A BaseMaterializer subclass.
        """
        if key not in self.materializer_types:
            self.materializer_types[key] = type_
            logger.debug(f"Registered materializer {type_} for {key}")
        else:
            logger.debug(
                f"Found existing materializer class for {key}: "
                f"{self.materializer_types[key]}. Skipping registration of "
                f"{type_}."
            )

    def register_and_overwrite_type(
        self, key: Type[Any], type_: Type["BaseMaterializer"]
    ) -> None:
        """Registers a new materializer and also overwrites a default if set.

        Args:
            key: Indicates the type of object.
            type_: A BaseMaterializer subclass.
        """
        self.materializer_types[key] = type_
        logger.debug(f"Registered materializer {type_} for {key}")

    def __getitem__(self, key: Type[Any]) -> Type["BaseMaterializer"]:
        """Get a single materializer based on the key.

        The classes in `key.__mro__` are checked in order, so the most
        specific registered ancestor wins. If none is registered, the
        default materializer is returned.

        Args:
            key: Indicates the type of object.

        Returns:
            `BaseMaterializer` subclass that was registered for this key.
        """
        for class_ in key.__mro__:
            materializer = self.materializer_types.get(class_)
            if materializer is not None:
                return materializer
        return self.get_default_materializer()

    def get_default_materializer(self) -> Type["BaseMaterializer"]:
        """Get the default materializer that is used if no other is found.

        Returns:
            The configured `default_materializer`, or the cloudpickle
            materializer if none is configured.
        """
        if self.default_materializer is not None:
            return self.default_materializer

        # Imported lazily and only when actually needed (presumably to avoid
        # an import cycle between the materializer modules).
        from zenml.materializers.cloudpickle_materializer import (
            CloudpickleMaterializer,
        )

        return CloudpickleMaterializer

    def get_materializer_types(
        self,
    ) -> Dict[Type[Any], Type["BaseMaterializer"]]:
        """Get all registered materializer types.

        Returns:
            A dictionary of registered materializer types.
        """
        return self.materializer_types

    def is_registered(self, key: Type[Any]) -> bool:
        """Returns if a materializer class is registered for the given type.

        Args:
            key: Indicates the type of object.

        Returns:
            True if a materializer is registered for the given type or one of
            its superclasses, False otherwise.
        """
        return any(issubclass(key, type_) for type_ in self.materializer_types)
__getitem__(self, key)
special
Get a single materializer based on the key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
Type[Any] |
Indicates the type of object. |
required |
Returns:
Type | Description |
---|---|
Type[BaseMaterializer] |
`BaseMaterializer` subclass that was registered for this key. |
Source code in zenml/materializers/materializer_registry.py
def __getitem__(self, key: Type[Any]) -> Type["BaseMaterializer"]:
    """Get a single materializer based on the key.

    The classes in `key.__mro__` are checked in order, so the most specific
    registered ancestor wins. If none is registered, the default materializer
    is returned.

    Args:
        key: Indicates the type of object.

    Returns:
        `BaseMaterializer` subclass that was registered for this key.
    """
    for class_ in key.__mro__:
        materializer = self.materializer_types.get(class_)
        # Explicit None check: a registered class must never be skipped
        # just because it happens to be falsy.
        if materializer is not None:
            return materializer
    return self.get_default_materializer()
__init__(self)
special
Initialize the materializer registry.
Source code in zenml/materializers/materializer_registry.py
def __init__(self) -> None:
    """Set up an empty registry with no custom default materializer."""
    # Maps a Python type to the materializer class registered for it.
    self.materializer_types: Dict[Type[Any], Type["BaseMaterializer"]] = {}
    # When left as None, lookups fall back to the cloudpickle materializer.
    self.default_materializer: Optional[Type["BaseMaterializer"]] = None
get_default_materializer(self)
Get the default materializer that is used if no other is found.
Returns:
Type | Description |
---|---|
Type[BaseMaterializer] |
The default materializer. |
Source code in zenml/materializers/materializer_registry.py
def get_default_materializer(self) -> Type["BaseMaterializer"]:
"""Get the default materializer that is used if no other is found.
Returns:
The default materializer.
"""
from zenml.materializers.cloudpickle_materializer import (
CloudpickleMaterializer,
)
if self.default_materializer:
return self.default_materializer
return CloudpickleMaterializer
get_materializer_types(self)
Get all registered materializer types.
Returns:
Type | Description |
---|---|
Dict[Type[Any], Type[BaseMaterializer]] |
A dictionary of registered materializer types. |
Source code in zenml/materializers/materializer_registry.py
def get_materializer_types(
    self,
) -> Dict[Type[Any], Type["BaseMaterializer"]]:
    """Return the mapping of registered types to materializer classes.

    Returns:
        The registry's own dictionary (not a copy), keyed by Python type.
    """
    registered = self.materializer_types
    return registered
is_registered(self, key)
Returns if a materializer class is registered for the given type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
Type[Any] |
Indicates the type of object. |
required |
Returns:
Type | Description |
---|---|
bool |
True if a materializer is registered for the given type, False otherwise. |
Source code in zenml/materializers/materializer_registry.py
def is_registered(self, key: Type[Any]) -> bool:
    """Check whether `key` or one of its superclasses has a materializer.

    Args:
        key: Indicates the type of object.

    Returns:
        True as soon as any registered type is a superclass of (or equal to)
        `key`, False otherwise.
    """
    for registered_type in self.materializer_types:
        if issubclass(key, registered_type):
            return True
    return False
register_and_overwrite_type(self, key, type_)
Registers a new materializer and also overwrites a default if set.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
Type[Any] |
Indicates the type of object. |
required |
type_ |
Type[BaseMaterializer] |
A BaseMaterializer subclass. |
required |
Source code in zenml/materializers/materializer_registry.py
def register_and_overwrite_type(
    self, key: Type[Any], type_: Type["BaseMaterializer"]
) -> None:
    """Register `type_` for `key`, replacing any existing registration.

    Unlike `register_materializer_type`, this never skips an existing entry.

    Args:
        key: Indicates the type of object.
        type_: A BaseMaterializer subclass.
    """
    registry = self.materializer_types
    registry[key] = type_
    logger.debug(f"Registered materializer {type_} for {key}")
register_materializer_type(self, key, type_)
Registers a new materializer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
Type[Any] |
Indicates the type of object. |
required |
type_ |
Type[BaseMaterializer] |
A BaseMaterializer subclass. |
required |
Source code in zenml/materializers/materializer_registry.py
def register_materializer_type(
    self, key: Type[Any], type_: Type["BaseMaterializer"]
) -> None:
    """Register `type_` for `key` unless `key` already has a materializer.

    The first registration for a key wins; later attempts are only logged.

    Args:
        key: Indicates the type of object.
        type_: A BaseMaterializer subclass.
    """
    if key in self.materializer_types:
        logger.debug(
            f"Found existing materializer class for {key}: "
            f"{self.materializer_types[key]}. Skipping registration of "
            f"{type_}."
        )
        return
    self.materializer_types[key] = type_
    logger.debug(f"Registered materializer {type_} for {key}")
numpy_materializer
Implementation of the ZenML NumPy materializer.
NumpyMaterializer (BaseMaterializer)
Materializer to read data to and from NumPy arrays.
Source code in zenml/materializers/numpy_materializer.py
class NumpyMaterializer(BaseMaterializer):
    """Materializer to read data to and from NumPy arrays."""

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (np.ndarray,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA

    def load(self, data_type: Type[Any]) -> "Any":
        """Reads a numpy array from a `.npy` file.

        Falls back to the legacy `.parquet` + shape-JSON layout written by
        older ZenML versions.

        Args:
            data_type: The type of the data to read.

        Raises:
            ImportError: If a legacy artifact is found but pyarrow is not
                installed.

        Returns:
            The numpy array. NOTE(review): if neither the `.npy` file nor a
            legacy data file exists, this implicitly returns None.
        """
        artifact_store = Client().active_stack.artifact_store
        numpy_file = os.path.join(self.uri, NUMPY_FILENAME)

        if artifact_store.exists(numpy_file):
            with artifact_store.open(numpy_file, "rb") as f:
                # NOTE: allow_pickle=True lets np.load unpickle object arrays,
                # which can execute arbitrary code; only load trusted data.
                return np.load(f, allow_pickle=True)
        elif artifact_store.exists(os.path.join(self.uri, DATA_FILENAME)):
            logger.warning(
                "A legacy artifact was found. "
                "This artifact was created with an older version of "
                "ZenML. You can still use it, but it will be "
                "converted to the new format on the next materialization."
            )
            try:
                # Import old materializer dependencies
                import pyarrow as pa  # type: ignore
                import pyarrow.parquet as pq  # type: ignore
            except ImportError as e:
                # The message parts are concatenated (no commas) so the error
                # shows a readable sentence instead of a tuple of strings.
                raise ImportError(
                    "You have an old version of a `NumpyMaterializer` "
                    "data artifact stored in the artifact store "
                    "as a `.parquet` file, which requires `pyarrow` for "
                    "reading. You can install `pyarrow` by running "
                    "`pip install pyarrow`."
                ) from e

            from zenml.utils import yaml_utils

            # Read numpy array from parquet file
            shape_dict = yaml_utils.read_json(
                os.path.join(self.uri, SHAPE_FILENAME)
            )
            shape_tuple = tuple(shape_dict.values())
            with artifact_store.open(
                os.path.join(self.uri, DATA_FILENAME), "rb"
            ) as f:
                input_stream = pa.input_stream(f)
                data = pq.read_table(input_stream)
            vals = getattr(data.to_pandas(), DATA_VAR).values
            return np.reshape(vals, shape_tuple)

    def save(self, arr: "NDArray[Any]") -> None:
        """Writes a np.ndarray to the artifact store as a `.npy` file.

        Args:
            arr: The numpy array to write.
        """
        artifact_store = Client().active_stack.artifact_store
        with artifact_store.open(
            os.path.join(self.uri, NUMPY_FILENAME), "wb"
        ) as f:
            np.save(f, arr)

    def save_visualizations(
        self, arr: "NDArray[Any]"
    ) -> Dict[str, VisualizationType]:
        """Saves visualizations for a numpy array.

        If the array is numeric and 1D, a histogram is saved. If it is
        numeric and 3D with 3 or 4 channels in the last axis, an image is
        saved. Other arrays get no visualization.

        Args:
            arr: The numpy array to visualize.

        Returns:
            A dictionary of visualization URIs and their types.
        """
        if not np.issubdtype(arr.dtype, np.number):
            return {}

        try:
            # Save histogram for 1D arrays
            if len(arr.shape) == 1:
                histogram_path = os.path.join(self.uri, "histogram.png")
                histogram_path = histogram_path.replace("\\", "/")
                self._save_histogram(histogram_path, arr)
                return {histogram_path: VisualizationType.IMAGE}

            # Save as image for 3D arrays with 3 or 4 channels
            if len(arr.shape) == 3 and arr.shape[2] in [3, 4]:
                image_path = os.path.join(self.uri, "image.png")
                image_path = image_path.replace("\\", "/")
                self._save_image(image_path, arr)
                return {image_path: VisualizationType.IMAGE}

        except ImportError:
            logger.info(
                "Skipping visualization of numpy array because matplotlib "
                "is not installed. To install matplotlib, run "
                "`pip install matplotlib`."
            )

        return {}

    def _save_histogram(self, output_path: str, arr: "NDArray[Any]") -> None:
        """Saves a histogram of a numpy array.

        Args:
            output_path: The path to save the histogram to.
            arr: The numpy array of which to save the histogram.
        """
        import matplotlib.pyplot as plt

        artifact_store = Client().active_stack.artifact_store
        plt.hist(arr)
        with artifact_store.open(output_path, "wb") as f:
            plt.savefig(f)
        plt.close()

    def _save_image(self, output_path: str, arr: "NDArray[Any]") -> None:
        """Saves a numpy array as an image.

        Args:
            output_path: The path to save the image to.
            arr: The numpy array to save.
        """
        from matplotlib.image import imsave

        artifact_store = Client().active_stack.artifact_store
        with artifact_store.open(output_path, "wb") as f:
            imsave(f, arr)

    def extract_metadata(
        self, arr: "NDArray[Any]"
    ) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given numpy array.

        Args:
            arr: The numpy array to extract metadata from.

        Returns:
            The extracted metadata as a dictionary.
        """
        if np.issubdtype(arr.dtype, np.number):
            return self._extract_numeric_metadata(arr)
        # `np.str_` instead of `np.unicode_`: the latter alias was removed
        # in NumPy 2.0 and raises AttributeError there.
        elif np.issubdtype(arr.dtype, np.str_) or np.issubdtype(
            arr.dtype, np.object_
        ):
            return self._extract_text_metadata(arr)
        else:
            return {}

    def _extract_numeric_metadata(
        self, arr: "NDArray[Any]"
    ) -> Dict[str, "MetadataType"]:
        """Extracts numeric metadata from a numpy array.

        Args:
            arr: The numpy array to extract metadata from.

        Returns:
            A dictionary of metadata. Empty arrays only report shape and dtype.
        """
        numpy_metadata: Dict[str, "MetadataType"] = {
            "shape": tuple(arr.shape),
            "dtype": DType(arr.dtype.type),
        }
        if arr.size == 0:
            # `np.min`/`np.max` raise on zero-size arrays, so skip statistics.
            return numpy_metadata
        numpy_metadata["mean"] = np.mean(arr).item()
        numpy_metadata["std"] = np.std(arr).item()
        numpy_metadata["min"] = np.min(arr).item()
        numpy_metadata["max"] = np.max(arr).item()
        return numpy_metadata

    def _extract_text_metadata(
        self, arr: "NDArray[Any]"
    ) -> Dict[str, "MetadataType"]:
        """Extracts text metadata from a numpy array.

        Args:
            arr: The numpy array to extract metadata from.

        Returns:
            A dictionary of metadata. The most-common-word entries are only
            present if the array contains at least one word.
        """
        # Flatten so multi-dimensional arrays work, and stringify each element
        # so object arrays holding non-str values do not break `join`.
        text = " ".join(str(item) for item in arr.ravel())
        words = text.split()
        word_counts = Counter(words)
        text_metadata: Dict[str, "MetadataType"] = {
            "shape": tuple(arr.shape),
            "dtype": DType(arr.dtype.type),
            "unique_words": len(word_counts),
            "total_words": len(words),
        }
        if word_counts:
            most_common_word, most_common_count = word_counts.most_common(1)[0]
            text_metadata["most_common_word"] = most_common_word
            text_metadata["most_common_count"] = most_common_count
        return text_metadata
extract_metadata(self, arr)
Extract metadata from the given numpy array.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arr |
NDArray[Any] |
The numpy array to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The extracted metadata as a dictionary. |
Source code in zenml/materializers/numpy_materializer.py
def extract_metadata(
    self, arr: "NDArray[Any]"
) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given numpy array.

    Numeric arrays get numeric statistics, string and object arrays get text
    statistics, and any other dtype (e.g. bool) yields no metadata.

    Args:
        arr: The numpy array to extract metadata from.

    Returns:
        The extracted metadata as a dictionary.
    """
    if np.issubdtype(arr.dtype, np.number):
        return self._extract_numeric_metadata(arr)
    # `np.str_` instead of `np.unicode_`: the latter alias was removed in
    # NumPy 2.0 and raises AttributeError there.
    elif np.issubdtype(arr.dtype, np.str_) or np.issubdtype(
        arr.dtype, np.object_
    ):
        return self._extract_text_metadata(arr)
    else:
        return {}
load(self, data_type)
Reads a numpy array from a `.npy` file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Exceptions:
Type | Description |
---|---|
ImportError |
If pyarrow is not installed. |
Returns:
Type | Description |
---|---|
Any |
The numpy array. |
Source code in zenml/materializers/numpy_materializer.py
def load(self, data_type: Type[Any]) -> "Any":
    """Reads a numpy array from a `.npy` file.

    Falls back to the legacy `.parquet` + shape-JSON layout written by older
    ZenML versions.

    Args:
        data_type: The type of the data to read.

    Raises:
        ImportError: If a legacy artifact is found but pyarrow is not
            installed.

    Returns:
        The numpy array. NOTE(review): if neither the `.npy` file nor a
        legacy data file exists, this implicitly returns None.
    """
    artifact_store = Client().active_stack.artifact_store
    numpy_file = os.path.join(self.uri, NUMPY_FILENAME)

    if artifact_store.exists(numpy_file):
        with artifact_store.open(numpy_file, "rb") as f:
            # NOTE: allow_pickle=True lets np.load unpickle object arrays,
            # which can execute arbitrary code; only load trusted data.
            return np.load(f, allow_pickle=True)
    elif artifact_store.exists(os.path.join(self.uri, DATA_FILENAME)):
        logger.warning(
            "A legacy artifact was found. "
            "This artifact was created with an older version of "
            "ZenML. You can still use it, but it will be "
            "converted to the new format on the next materialization."
        )
        try:
            # Import old materializer dependencies
            import pyarrow as pa  # type: ignore
            import pyarrow.parquet as pq  # type: ignore
        except ImportError as e:
            # The message parts are concatenated (no commas) so the error
            # shows a readable sentence instead of a tuple of strings.
            raise ImportError(
                "You have an old version of a `NumpyMaterializer` "
                "data artifact stored in the artifact store "
                "as a `.parquet` file, which requires `pyarrow` for "
                "reading. You can install `pyarrow` by running "
                "`pip install pyarrow`."
            ) from e

        from zenml.utils import yaml_utils

        # Read numpy array from parquet file
        shape_dict = yaml_utils.read_json(
            os.path.join(self.uri, SHAPE_FILENAME)
        )
        shape_tuple = tuple(shape_dict.values())
        with artifact_store.open(
            os.path.join(self.uri, DATA_FILENAME), "rb"
        ) as f:
            input_stream = pa.input_stream(f)
            data = pq.read_table(input_stream)
        vals = getattr(data.to_pandas(), DATA_VAR).values
        return np.reshape(vals, shape_tuple)
save(self, arr)
Writes a `np.ndarray` to the artifact store as a `.npy` file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arr |
NDArray[Any] |
The numpy array to write. |
required |
Source code in zenml/materializers/numpy_materializer.py
def save(self, arr: "NDArray[Any]") -> None:
    """Persist a numpy array as a `.npy` file inside `self.uri`.

    Args:
        arr: The numpy array to write.
    """
    store = Client().active_stack.artifact_store
    npy_path = os.path.join(self.uri, NUMPY_FILENAME)
    with store.open(npy_path, "wb") as handle:
        np.save(handle, arr)
save_visualizations(self, arr)
Saves visualizations for a numpy array.
If the array is 1D, a histogram is saved. If the array is 3D with 3 or 4 channels, an image is saved.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arr |
NDArray[Any] |
The numpy array to visualize. |
required |
Returns:
Type | Description |
---|---|
Dict[str, zenml.enums.VisualizationType] |
A dictionary of visualization URIs and their types. |
Source code in zenml/materializers/numpy_materializer.py
def save_visualizations(
    self, arr: "NDArray[Any]"
) -> Dict[str, VisualizationType]:
    """Saves visualizations for a numpy array.

    Only numeric arrays are visualized. A 1D array gets a histogram, and a 3D
    array whose last axis has 3 or 4 channels is saved as an image. Every
    other shape (including 2D) produces no visualization.

    Args:
        arr: The numpy array to visualize.

    Returns:
        A dictionary of visualization URIs and their types. It is empty when
        nothing was saved, including when matplotlib is not installed.
    """
    if not np.issubdtype(arr.dtype, np.number):
        return {}

    try:
        # Save histogram for 1D arrays
        if len(arr.shape) == 1:
            histogram_path = os.path.join(self.uri, "histogram.png")
            # Normalize Windows separators so the URI is consistent.
            histogram_path = histogram_path.replace("\\", "/")
            self._save_histogram(histogram_path, arr)
            return {histogram_path: VisualizationType.IMAGE}

        # Save as image for 3D arrays with 3 or 4 channels
        if len(arr.shape) == 3 and arr.shape[2] in [3, 4]:
            image_path = os.path.join(self.uri, "image.png")
            image_path = image_path.replace("\\", "/")
            self._save_image(image_path, arr)
            return {image_path: VisualizationType.IMAGE}

    except ImportError:
        # matplotlib is imported lazily inside the `_save_*` helpers.
        logger.info(
            "Skipping visualization of numpy array because matplotlib "
            "is not installed. To install matplotlib, run "
            "`pip install matplotlib`."
        )

    return {}
pandas_materializer
Materializer for Pandas.
PandasMaterializer (BaseMaterializer)
Materializer to read data to and from pandas.
Source code in zenml/materializers/pandas_materializer.py
class PandasMaterializer(BaseMaterializer):
"""Materializer to read data to and from pandas."""
ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (
pd.DataFrame,
pd.Series,
)
ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
def __init__(self, uri: str):
    """Set up the file paths and detect whether pyarrow is available.

    Args:
        uri: The URI where the artifact data is stored.
    """
    super().__init__(uri)
    try:
        import pyarrow  # type: ignore # noqa
    except ImportError:
        self.pyarrow_exists = False
        logger.warning(
            "By default, the `PandasMaterializer` stores data as a "
            "`.csv` file. If you want to store data more efficiently, "
            "you can install `pyarrow` by running "
            "'`pip install pyarrow`'. This will allow `PandasMaterializer` "
            "to automatically store the data as a `.parquet` file instead."
        )
    else:
        self.pyarrow_exists = True
    # Both candidate locations are always defined; which one is used depends
    # on `pyarrow_exists` and on which file is present.
    self.parquet_path = os.path.join(self.uri, PARQUET_FILENAME)
    self.csv_path = os.path.join(self.uri, CSV_FILENAME)
def load(self, data_type: Type[Any]) -> Union[pd.DataFrame, pd.Series]:
"""Reads `pd.DataFrame` or `pd.Series` from a `.parquet` or `.csv` file.
Args:
data_type: The type of the data to read.
Raises:
ImportError: If pyarrow or fastparquet is not installed.
Returns:
The pandas dataframe or series.
"""
artifact_store = Client().active_stack.artifact_store
if artifact_store.exists(self.parquet_path):
if self.pyarrow_exists:
with artifact_store.open(self.parquet_path, mode="rb") as f:
df = pd.read_parquet(f)
else:
raise ImportError(
"You have an old version of a `PandasMaterializer` "
"data artifact stored in the artifact store "
"as a `.parquet` file, which requires `pyarrow` "
"for reading, You can install `pyarrow` by running "
"'`pip install pyarrow fastparquet`'."
)
else:
with artifact_store.open(self.csv_path, mode="rb") as f:
df = pd.read_csv(f, index_col=0, parse_dates=True)
# validate the type of the data.
def is_dataframe_or_series(
df: Union[pd.DataFrame, pd.Series],
) -> Union[pd.DataFrame, pd.Series]:
"""Checks if the data is a `pd.DataFrame` or `pd.Series`.
Args:
df: The data to check.
Returns:
The data if it is a `pd.DataFrame` or `pd.Series`.
"""
if issubclass(data_type, pd.Series):
# Taking the first column if its a series as the assumption
# is that there will only be one
assert len(df.columns) == 1
df = df[df.columns[0]]
return df
else:
return df
return is_dataframe_or_series(df)
def save(self, df: Union[pd.DataFrame, pd.Series]) -> None:
"""Writes a pandas dataframe or series to the specified filename.
Args:
df: The pandas dataframe or series to write.
"""
artifact_store = Client().active_stack.artifact_store
if isinstance(df, pd.Series):
df = df.to_frame(name="series")
if self.pyarrow_exists:
with artifact_store.open(self.parquet_path, mode="wb") as f:
df.to_parquet(f, compression=COMPRESSION_TYPE)
else:
with artifact_store.open(self.csv_path, mode="wb") as f:
df.to_csv(f, index=True)
def save_visualizations(
self, df: Union[pd.DataFrame, pd.Series]
) -> Dict[str, VisualizationType]:
"""Save visualizations of the given pandas dataframe or series.
Args:
df: The pandas dataframe or series to visualize.
Returns:
A dictionary of visualization URIs and their types.
"""
artifact_store = Client().active_stack.artifact_store
describe_uri = os.path.join(self.uri, "describe.csv")
describe_uri = describe_uri.replace("\\", "/")
with artifact_store.open(describe_uri, mode="wb") as f:
df.describe().to_csv(f)
return {describe_uri: VisualizationType.CSV}
def extract_metadata(
self, df: Union[pd.DataFrame, pd.Series]
) -> Dict[str, "MetadataType"]:
"""Extract metadata from the given pandas dataframe or series.
Args:
df: The pandas dataframe or series to extract metadata from.
Returns:
The extracted metadata as a dictionary.
"""
pandas_metadata: Dict[str, "MetadataType"] = {"shape": df.shape}
if isinstance(df, pd.Series):
pandas_metadata["dtype"] = DType(df.dtype.type)
pandas_metadata["mean"] = float(df.mean().item())
pandas_metadata["std"] = float(df.std().item())
pandas_metadata["min"] = float(df.min().item())
pandas_metadata["max"] = float(df.max().item())
else:
pandas_metadata["dtype"] = {
str(key): DType(value.type) for key, value in df.dtypes.items()
}
for stat_name, stat in {
"mean": df.mean,
"std": df.std,
"min": df.min,
"max": df.max,
}.items():
pandas_metadata[stat_name] = {
str(key): float(value)
for key, value in stat(numeric_only=True).to_dict().items()
}
return pandas_metadata
__init__(self, uri)
special
Define `self.parquet_path` and `self.csv_path`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The URI where the artifact data is stored. |
required |
Source code in zenml/materializers/pandas_materializer.py
def __init__(self, uri: str):
    """Prepare the storage locations of a pandas artifact.

    Sets `self.parquet_path` and `self.csv_path` below the artifact URI and
    sets `self.pyarrow_exists` to whether `pyarrow` is importable (logging a
    warning about the CSV fallback when it is not).

    Args:
        uri: The URI where the artifact data is stored.
    """
    super().__init__(uri)
    self.parquet_path = os.path.join(self.uri, PARQUET_FILENAME)
    self.csv_path = os.path.join(self.uri, CSV_FILENAME)

    try:
        import pyarrow  # type: ignore # noqa
    except ImportError:
        self.pyarrow_exists = False
        logger.warning(
            "By default, the `PandasMaterializer` stores data as a "
            "`.csv` file. If you want to store data more efficiently, "
            "you can install `pyarrow` by running "
            "'`pip install pyarrow`'. This will allow `PandasMaterializer` "
            "to automatically store the data as a `.parquet` file instead."
        )
    else:
        self.pyarrow_exists = True
extract_metadata(self, df)
Extract metadata from the given pandas dataframe or series.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
df |
Union[pandas.core.frame.DataFrame, pandas.core.series.Series] |
The pandas dataframe or series to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The extracted metadata as a dictionary. |
Source code in zenml/materializers/pandas_materializer.py
def extract_metadata(
    self, df: Union[pd.DataFrame, pd.Series]
) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given pandas dataframe or series.

    Always records the shape and dtype(s). Mean, standard deviation,
    minimum and maximum are recorded only for numeric data: per numeric
    column for a dataframe, or as single values for a numeric series.

    Args:
        df: The pandas dataframe or series to extract metadata from.

    Returns:
        The extracted metadata as a dictionary.
    """
    pandas_metadata: Dict[str, "MetadataType"] = {"shape": df.shape}

    if isinstance(df, pd.Series):
        pandas_metadata["dtype"] = DType(df.dtype.type)
        # Non-numeric series (e.g. strings, datetimes) cannot be reduced to a
        # float; skip them, mirroring `numeric_only=True` in the other branch.
        if pd.api.types.is_numeric_dtype(df.dtype):
            pandas_metadata["mean"] = float(df.mean())
            pandas_metadata["std"] = float(df.std())
            pandas_metadata["min"] = float(df.min())
            pandas_metadata["max"] = float(df.max())
    else:
        pandas_metadata["dtype"] = {
            str(key): DType(value.type) for key, value in df.dtypes.items()
        }
        for stat_name, stat in {
            "mean": df.mean,
            "std": df.std,
            "min": df.min,
            "max": df.max,
        }.items():
            pandas_metadata[stat_name] = {
                str(key): float(value)
                for key, value in stat(numeric_only=True).to_dict().items()
            }

    return pandas_metadata
load(self, data_type)
Reads `pd.DataFrame` or `pd.Series` from a `.parquet` or `.csv` file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Exceptions:
Type | Description |
---|---|
ImportError |
If a `.parquet` file is stored but pyarrow is not installed. |
Returns:
Type | Description |
---|---|
Union[pandas.core.frame.DataFrame, pandas.core.series.Series] |
The pandas dataframe or series. |
Source code in zenml/materializers/pandas_materializer.py
def load(self, data_type: Type[Any]) -> Union[pd.DataFrame, pd.Series]:
    """Reads `pd.DataFrame` or `pd.Series` from a `.parquet` or `.csv` file.

    The `.parquet` file is preferred if it exists; otherwise the `.csv` file
    is read.

    Args:
        data_type: The type of the data to read. If it is a `pd.Series`
            subclass, the single stored column is returned as a series.

    Raises:
        ImportError: If a `.parquet` file is stored but `pyarrow` is not
            installed.
        ValueError: If a `pd.Series` is requested but the stored data does
            not have exactly one column.

    Returns:
        The pandas dataframe or series.
    """
    artifact_store = Client().active_stack.artifact_store
    if artifact_store.exists(self.parquet_path):
        if not self.pyarrow_exists:
            raise ImportError(
                "You have an old version of a `PandasMaterializer` "
                "data artifact stored in the artifact store "
                "as a `.parquet` file, which requires `pyarrow` "
                "for reading, You can install `pyarrow` by running "
                "'`pip install pyarrow fastparquet`'."
            )
        with artifact_store.open(self.parquet_path, mode="rb") as f:
            df = pd.read_parquet(f)
    else:
        with artifact_store.open(self.csv_path, mode="rb") as f:
            df = pd.read_csv(f, index_col=0, parse_dates=True)

    if issubclass(data_type, pd.Series):
        # A series is stored as a one-column dataframe, so unwrap that
        # column. An explicit check is used instead of `assert` so the
        # validation is not stripped when running with `python -O`.
        if len(df.columns) != 1:
            raise ValueError(
                "Expected exactly one column when loading a `pd.Series`, "
                f"but found {len(df.columns)}."
            )
        df = df[df.columns[0]]
    return df
save(self, df)
Writes a pandas dataframe or series to the specified filename.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
df |
Union[pandas.core.frame.DataFrame, pandas.core.series.Series] |
The pandas dataframe or series to write. |
required |
Source code in zenml/materializers/pandas_materializer.py
def save(self, df: Union[pd.DataFrame, pd.Series]) -> None:
    """Persist a pandas dataframe or series in the artifact store.

    A series is first wrapped in a one-column dataframe whose column is named
    `"series"`. The frame is written as parquet (using `COMPRESSION_TYPE`)
    when `self.pyarrow_exists` is true, and as CSV including the index
    otherwise.

    Args:
        df: The pandas dataframe or series to write.
    """
    store = Client().active_stack.artifact_store
    frame = df.to_frame(name="series") if isinstance(df, pd.Series) else df

    if not self.pyarrow_exists:
        with store.open(self.csv_path, mode="wb") as handle:
            frame.to_csv(handle, index=True)
        return

    with store.open(self.parquet_path, mode="wb") as handle:
        frame.to_parquet(handle, compression=COMPRESSION_TYPE)
save_visualizations(self, df)
Save visualizations of the given pandas dataframe or series.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
df |
Union[pandas.core.frame.DataFrame, pandas.core.series.Series] |
The pandas dataframe or series to visualize. |
required |
Returns:
Type | Description |
---|---|
Dict[str, zenml.enums.VisualizationType] |
A dictionary of visualization URIs and their types. |
Source code in zenml/materializers/pandas_materializer.py
def save_visualizations(
    self, df: Union[pd.DataFrame, pd.Series]
) -> Dict[str, VisualizationType]:
    """Write the `describe()` summary of `df` as a CSV visualization.

    Args:
        df: The pandas dataframe or series to visualize.

    Returns:
        A one-entry mapping from the forward-slash `describe.csv` URI inside
        the artifact directory to `VisualizationType.CSV`.
    """
    store = Client().active_stack.artifact_store
    summary_uri = os.path.join(self.uri, "describe.csv").replace("\\", "/")
    summary = df.describe()
    with store.open(summary_uri, mode="wb") as handle:
        summary.to_csv(handle)
    return {summary_uri: VisualizationType.CSV}
pydantic_materializer
Implementation of ZenML's pydantic materializer.
PydanticMaterializer (BaseMaterializer)
Handle Pydantic BaseModel objects.
Source code in zenml/materializers/pydantic_materializer.py
class PydanticMaterializer(BaseMaterializer):
    """Handle Pydantic BaseModel objects.

    A model is serialized with `BaseModel.json()` and stored through
    `yaml_utils.write_json` in `DEFAULT_FILENAME` inside the artifact URI.
    """

    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (BaseModel,)

    def load(self, data_type: Type[BaseModel]) -> Any:
        """Rebuild a model instance from its stored JSON.

        Args:
            data_type: The model class used to parse the stored JSON.

        Returns:
            The parsed model instance.
        """
        json_file = os.path.join(self.uri, DEFAULT_FILENAME)
        return data_type.parse_raw(yaml_utils.read_json(json_file))

    def save(self, data: BaseModel) -> None:
        """Store a model instance as JSON.

        Args:
            data: The model instance to store.
        """
        json_file = os.path.join(self.uri, DEFAULT_FILENAME)
        payload = data.json()
        yaml_utils.write_json(json_file, payload)

    def extract_metadata(self, data: BaseModel) -> Dict[str, "MetadataType"]:
        """Report the model's schema as metadata.

        Args:
            data: The model instance to describe.

        Returns:
            A dictionary with a single `"schema"` entry.
        """
        model_schema = data.schema()
        return {"schema": model_schema}
extract_metadata(self, data)
Extract metadata from the given BaseModel object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
BaseModel |
The BaseModel object to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The extracted metadata as a dictionary. |
Source code in zenml/materializers/pydantic_materializer.py
def extract_metadata(self, data: BaseModel) -> Dict[str, "MetadataType"]:
    """Report the model's schema as metadata.

    Args:
        data: The model instance to describe.

    Returns:
        A dictionary with a single `"schema"` entry holding `data.schema()`.
    """
    model_schema = data.schema()
    return {"schema": model_schema}
load(self, data_type)
Reads BaseModel from JSON.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[pydantic.main.BaseModel] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Any |
The data read. |
Source code in zenml/materializers/pydantic_materializer.py
def load(self, data_type: Type[BaseModel]) -> Any:
    """Rebuild a model instance from the stored JSON file.

    Args:
        data_type: The model class used to parse the stored JSON.

    Returns:
        The parsed model instance.
    """
    json_file = os.path.join(self.uri, DEFAULT_FILENAME)
    return data_type.parse_raw(yaml_utils.read_json(json_file))
save(self, data)
Serialize a BaseModel to JSON.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
BaseModel |
The data to store. |
required |
Source code in zenml/materializers/pydantic_materializer.py
def save(self, data: BaseModel) -> None:
    """Write a model instance to the artifact directory as JSON.

    Args:
        data: The model instance to store.
    """
    json_file = os.path.join(self.uri, DEFAULT_FILENAME)
    payload = data.json()
    yaml_utils.write_json(json_file, payload)
service_materializer
Implementation of a materializer to read and write ZenML service instances.
ServiceMaterializer (BaseMaterializer)
Materializer to read/write service instances.
Source code in zenml/materializers/service_materializer.py
class ServiceMaterializer(BaseMaterializer):
    """Materializer to read/write service instances.

    The service is stored as indented JSON in `SERVICE_CONFIG_FILENAME`
    inside the artifact URI.
    """

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (BaseService,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.SERVICE

    def load(self, data_type: Type[Any]) -> BaseService:
        """Recreate a service from its stored JSON description.

        The JSON holds the serialized service configuration and its last
        known status, and is turned back into an instance by
        `ServiceRegistry().load_service_from_json`.

        Args:
            data_type: The type of the data to read.

        Returns:
            A ZenML service instance.
        """
        store = Client().active_stack.artifact_store
        config_file = os.path.join(self.uri, SERVICE_CONFIG_FILENAME)
        with store.open(config_file, "r") as handle:
            return ServiceRegistry().load_service_from_json(handle.read())

    def save(self, service: BaseService) -> None:
        """Store a service's configuration and last known status as JSON.

        Args:
            service: A ZenML service instance.
        """
        store = Client().active_stack.artifact_store
        config_file = os.path.join(self.uri, SERVICE_CONFIG_FILENAME)
        with store.open(config_file, "w") as handle:
            handle.write(service.json(indent=4))

    def extract_metadata(
        self, service: BaseService
    ) -> Dict[str, "MetadataType"]:
        """Report the service endpoint URI, if there is one.

        Args:
            service: The service to extract metadata from.

        Returns:
            `{"uri": Uri(...)}` when the service has an endpoint with a
            status URI, otherwise an empty dictionary.
        """
        from zenml.metadata.metadata_types import Uri

        if not service.endpoint or not service.endpoint.status.uri:
            return {}
        return {"uri": Uri(service.endpoint.status.uri)}
extract_metadata(self, service)
Extract metadata from the given service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service |
BaseService |
The service to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The extracted metadata as a dictionary. |
Source code in zenml/materializers/service_materializer.py
def extract_metadata(
    self, service: BaseService
) -> Dict[str, "MetadataType"]:
    """Report the service endpoint URI, if there is one.

    Args:
        service: The service to extract metadata from.

    Returns:
        `{"uri": Uri(...)}` when the service has an endpoint with a status
        URI, otherwise an empty dictionary.
    """
    from zenml.metadata.metadata_types import Uri

    if not service.endpoint or not service.endpoint.status.uri:
        return {}
    return {"uri": Uri(service.endpoint.status.uri)}
load(self, data_type)
Creates and returns a service.
This service is instantiated from the serialized service configuration and last known status information saved as artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
BaseService |
A ZenML service instance. |
Source code in zenml/materializers/service_materializer.py
def load(self, data_type: Type[Any]) -> BaseService:
    """Recreate a service from its stored JSON description.

    The JSON holds the serialized service configuration and its last known
    status, and is turned back into an instance by
    `ServiceRegistry().load_service_from_json`.

    Args:
        data_type: The type of the data to read.

    Returns:
        A ZenML service instance.
    """
    store = Client().active_stack.artifact_store
    config_file = os.path.join(self.uri, SERVICE_CONFIG_FILENAME)
    with store.open(config_file, "r") as handle:
        return ServiceRegistry().load_service_from_json(handle.read())
save(self, service)
Writes a ZenML service.
The configuration and last known status of the input service instance are serialized and saved as an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service |
BaseService |
A ZenML service instance. |
required |
Source code in zenml/materializers/service_materializer.py
def save(self, service: BaseService) -> None:
    """Store a service's configuration and last known status as JSON.

    The JSON is produced by `service.json(indent=4)` and written to
    `SERVICE_CONFIG_FILENAME` inside the artifact URI.

    Args:
        service: A ZenML service instance.
    """
    store = Client().active_stack.artifact_store
    config_file = os.path.join(self.uri, SERVICE_CONFIG_FILENAME)
    with store.open(config_file, "w") as handle:
        handle.write(service.json(indent=4))
structured_string_materializer
Implementation of the structured string (CSV, HTML, and Markdown) materializer.
StructuredStringMaterializer (BaseMaterializer)
Materializer for CSV, HTML, or Markdown strings.
Source code in zenml/materializers/structured_string_materializer.py
class StructuredStringMaterializer(BaseMaterializer):
    """Materializer for CSV, HTML, or Markdown strings.

    Each supported string type is written verbatim to its own file inside the
    artifact URI, and that same file is reported as the artifact's
    visualization.
    """

    ASSOCIATED_TYPES = (CSVString, HTMLString, MarkdownString)
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA_ANALYSIS

    def load(self, data_type: Type[STRUCTURED_STRINGS]) -> STRUCTURED_STRINGS:
        """Loads the data from the CSV, HTML, or Markdown file.

        Args:
            data_type: The type of the data to read; it selects the file that
                is read and wraps the file contents.

        Returns:
            The loaded data.
        """
        artifact_store = Client().active_stack.artifact_store
        with artifact_store.open(self._get_filepath(data_type), "r") as f:
            return data_type(f.read())

    def save(self, data: STRUCTURED_STRINGS) -> None:
        """Save data as a CSV, HTML, or Markdown file.

        Args:
            data: The data to save; its type selects the target file.
        """
        artifact_store = Client().active_stack.artifact_store
        with artifact_store.open(self._get_filepath(type(data)), "w") as f:
            f.write(data)

    def save_visualizations(
        self, data: STRUCTURED_STRINGS
    ) -> Dict[str, VisualizationType]:
        """Save visualizations for the given data.

        No new file is written: the file used by `save` for this data type is
        reported as the visualization.

        Args:
            data: The data to save visualizations for.

        Returns:
            A dictionary of visualization URIs and their types.
        """
        data_type = type(data)
        # Use forward slashes regardless of the local OS separator.
        filepath = self._get_filepath(data_type).replace("\\", "/")
        visualization_type = self._get_visualization_type(data_type)
        return {filepath: visualization_type}

    def _get_filepath(self, data_type: Type[STRUCTURED_STRINGS]) -> str:
        """Get the file path for the given data type.

        `issubclass` is used so that subclasses of the supported string types
        are handled as well.

        Args:
            data_type: The type of the data.

        Returns:
            The file path for the given data type.

        Raises:
            ValueError: If the data type is not supported.
        """
        if issubclass(data_type, CSVString):
            filename = CSV_FILENAME
        elif issubclass(data_type, HTMLString):
            filename = HTML_FILENAME
        elif issubclass(data_type, MarkdownString):
            filename = MARKDOWN_FILENAME
        else:
            raise ValueError(
                f"Data type {data_type} is not supported by this materializer."
            )
        return os.path.join(self.uri, filename)

    def _get_visualization_type(
        self, data_type: Type[STRUCTURED_STRINGS]
    ) -> VisualizationType:
        """Get the visualization type for the given data type.

        Args:
            data_type: The type of the data.

        Returns:
            The visualization type for the given data type.

        Raises:
            ValueError: If the data type is not supported.
        """
        if issubclass(data_type, CSVString):
            return VisualizationType.CSV
        elif issubclass(data_type, HTMLString):
            return VisualizationType.HTML
        elif issubclass(data_type, MarkdownString):
            return VisualizationType.MARKDOWN
        else:
            raise ValueError(
                f"Data type {data_type} is not supported by this materializer."
            )
load(self, data_type)
Loads the data from the CSV, HTML, or Markdown file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Union[zenml.types.CSVString, zenml.types.HTMLString, zenml.types.MarkdownString]] |
The type of the data to read. |
required |
Returns:
Type | Description |
---|---|
Union[zenml.types.CSVString, zenml.types.HTMLString, zenml.types.MarkdownString] |
The loaded data. |
Source code in zenml/materializers/structured_string_materializer.py
def load(self, data_type: Type[STRUCTURED_STRINGS]) -> STRUCTURED_STRINGS:
    """Loads the data from the CSV, HTML, or Markdown file.

    Args:
        data_type: The type of the data to read; it selects the file that is
            read (via `self._get_filepath`) and wraps the file contents.

    Returns:
        The loaded data.
    """
    artifact_store = Client().active_stack.artifact_store
    with artifact_store.open(self._get_filepath(data_type), "r") as f:
        return data_type(f.read())
save(self, data)
Save data as a CSV, HTML, or Markdown file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Union[zenml.types.CSVString, zenml.types.HTMLString, zenml.types.MarkdownString] |
The data to save as a CSV, HTML, or Markdown file. |
required |
Source code in zenml/materializers/structured_string_materializer.py
def save(self, data: STRUCTURED_STRINGS) -> None:
    """Save data as a CSV, HTML, or Markdown file.

    Args:
        data: The data to save; its type selects the target file via
            `self._get_filepath`.
    """
    artifact_store = Client().active_stack.artifact_store
    with artifact_store.open(self._get_filepath(type(data)), "w") as f:
        f.write(data)
save_visualizations(self, data)
Save visualizations for the given data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
Union[zenml.types.CSVString, zenml.types.HTMLString, zenml.types.MarkdownString] |
The data to save visualizations for. |
required |
Returns:
Type | Description |
---|---|
Dict[str, zenml.enums.VisualizationType] |
A dictionary of visualization URIs and their types. |
Source code in zenml/materializers/structured_string_materializer.py
def save_visualizations(
    self, data: STRUCTURED_STRINGS
) -> Dict[str, VisualizationType]:
    """Report the stored string file as this artifact's visualization.

    No new file is written: the path that `self._get_filepath` yields for
    `type(data)` is returned with forward slashes, paired with the matching
    visualization type.

    Args:
        data: The data to save visualizations for.

    Returns:
        A dictionary of visualization URIs and their types.
    """
    kind = type(data)
    uri = self._get_filepath(kind).replace("\\", "/")
    return {uri: self._get_visualization_type(kind)}