Numpy
zenml.integrations.numpy
special
Initialization of the Numpy integration.
NumpyIntegration (Integration)
Definition of Numpy integration for ZenML.
Source code in zenml/integrations/numpy/__init__.py
class NumpyIntegration(Integration):
    """Definition of Numpy integration for ZenML."""

    NAME = NUMPY
    REQUIREMENTS = ["numpy<2.0.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.numpy import materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/numpy/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.numpy import materializers  # noqa
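Activation is a lightweight operation: it only imports the materializers module so that NumpyMaterializer (documented below) becomes available for np.ndarray artifacts. A minimal sketch of triggering this explicitly, although ZenML normally activates installed integrations automatically:

from zenml.integrations.numpy import NumpyIntegration

# Importing the materializers module is the only side effect of activation;
# afterwards, np.ndarray step outputs are handled by NumpyMaterializer.
NumpyIntegration.activate()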
materializers
special
Initialization of the Numpy materializer.
numpy_materializer
Implementation of the ZenML NumPy materializer.
NumpyMaterializer (BaseMaterializer)
Materializer to read and write numpy arrays.
Source code in zenml/integrations/numpy/materializers/numpy_materializer.py
class NumpyMaterializer(BaseMaterializer):
    """Materializer to read and write numpy arrays."""

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (np.ndarray,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA

    def load(self, data_type: Type[Any]) -> "Any":
        """Reads a numpy array from a `.npy` file.

        Args:
            data_type: The type of the data to read.

        Raises:
            ImportError: If pyarrow is not installed.

        Returns:
            The numpy array.
        """
        numpy_file = os.path.join(self.uri, NUMPY_FILENAME)

        if self.artifact_store.exists(numpy_file):
            with self.artifact_store.open(numpy_file, "rb") as f:
                return np.load(f, allow_pickle=True)
        elif self.artifact_store.exists(os.path.join(self.uri, DATA_FILENAME)):
            logger.warning(
                "A legacy artifact was found. "
                "This artifact was created with an older version of "
                "ZenML. You can still use it, but it will be "
                "converted to the new format on the next materialization."
            )
            try:
                # Import old materializer dependencies
                import pyarrow as pa  # type: ignore
                import pyarrow.parquet as pq  # type: ignore

                from zenml.utils import yaml_utils

                # Read numpy array from parquet file
                shape_dict = yaml_utils.read_json(
                    os.path.join(self.uri, SHAPE_FILENAME)
                )
                shape_tuple = tuple(shape_dict.values())
                with self.artifact_store.open(
                    os.path.join(self.uri, DATA_FILENAME), "rb"
                ) as f:
                    input_stream = pa.input_stream(f)
                    data = pq.read_table(input_stream)
                vals = getattr(data.to_pandas(), DATA_VAR).values
                return np.reshape(vals, shape_tuple)
            except ImportError:
                raise ImportError(
                    "You have an old version of a `NumpyMaterializer` "
                    "data artifact stored in the artifact store "
                    "as a `.parquet` file, which requires `pyarrow` for reading. "
                    "You can install `pyarrow` by running `pip install pyarrow`."
                )

    def save(self, arr: "NDArray[Any]") -> None:
        """Writes a np.ndarray to the artifact store as a `.npy` file.

        Args:
            arr: The numpy array to write.
        """
        with self.artifact_store.open(
            os.path.join(self.uri, NUMPY_FILENAME), "wb"
        ) as f:
            np.save(f, arr)

    def save_visualizations(
        self, arr: "NDArray[Any]"
    ) -> Dict[str, VisualizationType]:
        """Saves visualizations for a numpy array.

        If the array is 1D, a histogram is saved. If the array is 3D with
        3 or 4 channels, an image is saved.

        Args:
            arr: The numpy array to visualize.

        Returns:
            A dictionary of visualization URIs and their types.
        """
        if not np.issubdtype(arr.dtype, np.number):
            return {}

        try:
            # Save histogram for 1D arrays
            if len(arr.shape) == 1:
                histogram_path = os.path.join(self.uri, "histogram.png")
                histogram_path = histogram_path.replace("\\", "/")
                self._save_histogram(histogram_path, arr)
                return {histogram_path: VisualizationType.IMAGE}

            # Save as image for 3D arrays with 3 or 4 channels
            if len(arr.shape) == 3 and arr.shape[2] in [3, 4]:
                image_path = os.path.join(self.uri, "image.png")
                image_path = image_path.replace("\\", "/")
                self._save_image(image_path, arr)
                return {image_path: VisualizationType.IMAGE}
        except ImportError:
            logger.info(
                "Skipping visualization of numpy array because matplotlib "
                "is not installed. To install matplotlib, run "
                "`pip install matplotlib`."
            )

        return {}

    def _save_histogram(self, output_path: str, arr: "NDArray[Any]") -> None:
        """Saves a histogram of a numpy array.

        Args:
            output_path: The path to save the histogram to.
            arr: The numpy array of which to save the histogram.
        """
        import matplotlib.pyplot as plt

        plt.hist(arr)
        with self.artifact_store.open(output_path, "wb") as f:
            plt.savefig(f)
        plt.close()

    def _save_image(self, output_path: str, arr: "NDArray[Any]") -> None:
        """Saves a numpy array as an image.

        Args:
            output_path: The path to save the image to.
            arr: The numpy array to save.
        """
        from matplotlib.image import imsave

        with self.artifact_store.open(output_path, "wb") as f:
            imsave(f, arr)

    def extract_metadata(
        self, arr: "NDArray[Any]"
    ) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given numpy array.

        Args:
            arr: The numpy array to extract metadata from.

        Returns:
            The extracted metadata as a dictionary.
        """
        if np.issubdtype(arr.dtype, np.number):
            return self._extract_numeric_metadata(arr)
        elif np.issubdtype(arr.dtype, np.unicode_) or np.issubdtype(
            arr.dtype, np.object_
        ):
            return self._extract_text_metadata(arr)
        else:
            return {}

    def _extract_numeric_metadata(
        self, arr: "NDArray[Any]"
    ) -> Dict[str, "MetadataType"]:
        """Extracts numeric metadata from a numpy array.

        Args:
            arr: The numpy array to extract metadata from.

        Returns:
            A dictionary of metadata.
        """
        min_val = np.min(arr).item()
        max_val = np.max(arr).item()

        numpy_metadata: Dict[str, "MetadataType"] = {
            "shape": tuple(arr.shape),
            "dtype": DType(arr.dtype.type),
            "mean": np.mean(arr).item(),
            "std": np.std(arr).item(),
            "min": min_val,
            "max": max_val,
        }
        return numpy_metadata

    def _extract_text_metadata(
        self, arr: "NDArray[Any]"
    ) -> Dict[str, "MetadataType"]:
        """Extracts text metadata from a numpy array.

        Args:
            arr: The numpy array to extract metadata from.

        Returns:
            A dictionary of metadata.
        """
        text = " ".join(arr)
        words = text.split()
        word_counts = Counter(words)
        unique_words = len(word_counts)
        total_words = len(words)
        most_common_word, most_common_count = word_counts.most_common(1)[0]

        text_metadata: Dict[str, "MetadataType"] = {
            "shape": tuple(arr.shape),
            "dtype": DType(arr.dtype.type),
            "unique_words": unique_words,
            "total_words": total_words,
            "most_common_word": most_common_word,
            "most_common_count": most_common_count,
        }
        return text_metadata
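Because np.ndarray is listed in ASSOCIATED_TYPES, this materializer is picked up automatically whenever a step returns or receives a numpy array. A usage sketch, assuming a recent ZenML version where step and pipeline are importable from the top-level zenml package (adjust the imports for older releases):

import numpy as np

from zenml import pipeline, step


@step
def make_array() -> np.ndarray:
    # Stored by NumpyMaterializer as a `.npy` file in the artifact store.
    return np.random.rand(100)


@step
def average(arr: np.ndarray) -> float:
    # Loaded back from the artifact store by NumpyMaterializer.
    return float(arr.mean())


@pipeline
def numpy_pipeline():
    average(make_array())


if __name__ == "__main__":
    numpy_pipeline()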
extract_metadata(self, arr)
Extract metadata from the given numpy array.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
arr | NDArray[Any] | The numpy array to extract metadata from. | required |

Returns:

Type | Description |
---|---|
Dict[str, MetadataType] | The extracted metadata as a dictionary. |
Source code in zenml/integrations/numpy/materializers/numpy_materializer.py
def extract_metadata(
    self, arr: "NDArray[Any]"
) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given numpy array.

    Args:
        arr: The numpy array to extract metadata from.

    Returns:
        The extracted metadata as a dictionary.
    """
    if np.issubdtype(arr.dtype, np.number):
        return self._extract_numeric_metadata(arr)
    elif np.issubdtype(arr.dtype, np.unicode_) or np.issubdtype(
        arr.dtype, np.object_
    ):
        return self._extract_text_metadata(arr)
    else:
        return {}
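To illustrate the resulting dictionaries, here is a small sketch that mirrors the two branches above; the keys match those built in _extract_numeric_metadata and _extract_text_metadata (the real implementation additionally wraps the dtype in ZenML's DType type):

from collections import Counter

import numpy as np

# Numeric branch: shape, dtype, mean, std, min, max
numeric = np.array([1.0, 2.0, 3.0, 4.0])
numeric_metadata = {
    "shape": tuple(numeric.shape),
    "dtype": numeric.dtype.type,
    "mean": np.mean(numeric).item(),
    "std": np.std(numeric).item(),
    "min": np.min(numeric).item(),
    "max": np.max(numeric).item(),
}

# Text branch: word statistics over the joined array contents
text = np.array(["the quick brown fox", "the lazy dog"])
words = " ".join(text).split()
counts = Counter(words)
most_common_word, most_common_count = counts.most_common(1)[0]
text_metadata = {
    "shape": tuple(text.shape),
    "unique_words": len(counts),   # 6
    "total_words": len(words),     # 7
    "most_common_word": most_common_word,    # "the"
    "most_common_count": most_common_count,  # 2
}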
load(self, data_type)
Reads a numpy array from a `.npy` file.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
data_type | Type[Any] | The type of the data to read. | required |

Exceptions:

Type | Description |
---|---|
ImportError | If pyarrow is not installed. |

Returns:

Type | Description |
---|---|
Any | The numpy array. |
Source code in zenml/integrations/numpy/materializers/numpy_materializer.py
def load(self, data_type: Type[Any]) -> "Any":
    """Reads a numpy array from a `.npy` file.

    Args:
        data_type: The type of the data to read.

    Raises:
        ImportError: If pyarrow is not installed.

    Returns:
        The numpy array.
    """
    numpy_file = os.path.join(self.uri, NUMPY_FILENAME)

    if self.artifact_store.exists(numpy_file):
        with self.artifact_store.open(numpy_file, "rb") as f:
            return np.load(f, allow_pickle=True)
    elif self.artifact_store.exists(os.path.join(self.uri, DATA_FILENAME)):
        logger.warning(
            "A legacy artifact was found. "
            "This artifact was created with an older version of "
            "ZenML. You can still use it, but it will be "
            "converted to the new format on the next materialization."
        )
        try:
            # Import old materializer dependencies
            import pyarrow as pa  # type: ignore
            import pyarrow.parquet as pq  # type: ignore

            from zenml.utils import yaml_utils

            # Read numpy array from parquet file
            shape_dict = yaml_utils.read_json(
                os.path.join(self.uri, SHAPE_FILENAME)
            )
            shape_tuple = tuple(shape_dict.values())
            with self.artifact_store.open(
                os.path.join(self.uri, DATA_FILENAME), "rb"
            ) as f:
                input_stream = pa.input_stream(f)
                data = pq.read_table(input_stream)
            vals = getattr(data.to_pandas(), DATA_VAR).values
            return np.reshape(vals, shape_tuple)
        except ImportError:
            raise ImportError(
                "You have an old version of a `NumpyMaterializer` "
                "data artifact stored in the artifact store "
                "as a `.parquet` file, which requires `pyarrow` for reading. "
                "You can install `pyarrow` by running `pip install pyarrow`."
            )
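Since the current storage format is a plain `.npy` file, an artifact written by this materializer can also be opened with NumPy directly if you have access to its URI. A minimal sketch against the local filesystem; the path below is hypothetical and stands in for os.path.join(self.uri, NUMPY_FILENAME), and the real materializer streams through self.artifact_store.open() rather than the built-in open():

import numpy as np

path = "/tmp/example_artifact.npy"  # hypothetical stand-in for the artifact URI

# What `save()` does with the artifact store handle:
with open(path, "wb") as f:
    np.save(f, np.arange(10))

# What `load()` does for current-format artifacts:
with open(path, "rb") as f:
    arr = np.load(f, allow_pickle=True)

assert np.array_equal(arr, np.arange(10))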
save(self, arr)
Writes a np.ndarray to the artifact store as a `.npy` file.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
arr | NDArray[Any] | The numpy array to write. | required |
Source code in zenml/integrations/numpy/materializers/numpy_materializer.py
def save(self, arr: "NDArray[Any]") -> None:
    """Writes a np.ndarray to the artifact store as a `.npy` file.

    Args:
        arr: The numpy array to write.
    """
    with self.artifact_store.open(
        os.path.join(self.uri, NUMPY_FILENAME), "wb"
    ) as f:
        np.save(f, arr)
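A note on the design: np.save and np.load accept any binary file-like object, which is why the materializer can write through self.artifact_store.open(..., "wb") and work unchanged against local disk, S3, GCS, or any other artifact store backend. A self-contained sketch using an in-memory buffer:

import io

import numpy as np

buffer = io.BytesIO()  # stands in for any artifact store file handle
np.save(buffer, np.ones((2, 3)))

buffer.seek(0)
restored = np.load(buffer)
assert restored.shape == (2, 3)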
save_visualizations(self, arr)
Saves visualizations for a numpy array.
If the array is 1D, a histogram is saved. If the array is 3D with 3 or 4 channels, an image is saved.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
arr | NDArray[Any] | The numpy array to visualize. | required |

Returns:

Type | Description |
---|---|
Dict[str, zenml.enums.VisualizationType] | A dictionary of visualization URIs and their types. |
Source code in zenml/integrations/numpy/materializers/numpy_materializer.py
def save_visualizations(
    self, arr: "NDArray[Any]"
) -> Dict[str, VisualizationType]:
    """Saves visualizations for a numpy array.

    If the array is 1D, a histogram is saved. If the array is 3D with
    3 or 4 channels, an image is saved.

    Args:
        arr: The numpy array to visualize.

    Returns:
        A dictionary of visualization URIs and their types.
    """
    if not np.issubdtype(arr.dtype, np.number):
        return {}

    try:
        # Save histogram for 1D arrays
        if len(arr.shape) == 1:
            histogram_path = os.path.join(self.uri, "histogram.png")
            histogram_path = histogram_path.replace("\\", "/")
            self._save_histogram(histogram_path, arr)
            return {histogram_path: VisualizationType.IMAGE}

        # Save as image for 3D arrays with 3 or 4 channels
        if len(arr.shape) == 3 and arr.shape[2] in [3, 4]:
            image_path = os.path.join(self.uri, "image.png")
            image_path = image_path.replace("\\", "/")
            self._save_image(image_path, arr)
            return {image_path: VisualizationType.IMAGE}
    except ImportError:
        logger.info(
            "Skipping visualization of numpy array because matplotlib "
            "is not installed. To install matplotlib, run "
            "`pip install matplotlib`."
        )

    return {}
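For reference, the decision logic above can be summarized in a small, illustration-only helper that predicts which visualization file (if any) would be produced for a given array:

import numpy as np


def expected_visualization(arr: np.ndarray) -> str:
    """Mirrors the branching in `save_visualizations` (illustration only)."""
    if not np.issubdtype(arr.dtype, np.number):
        return "none (non-numeric array)"
    if arr.ndim == 1:
        return "histogram.png"
    if arr.ndim == 3 and arr.shape[2] in (3, 4):
        return "image.png"
    return "none (unsupported shape)"


print(expected_visualization(np.random.rand(100)))        # histogram.png
print(expected_visualization(np.random.rand(64, 64, 3)))  # image.png
print(expected_visualization(np.array(["a", "b"])))       # none (non-numeric array)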