Artifacts
zenml.artifacts
artifact_config
Artifact config classes that support the Model Control Plane feature.
ArtifactConfig (BaseModel)
Artifact configuration class.
Can be used in step definitions to define various artifact properties.
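Examples:
```python
from typing import Annotated

from zenml import step
from zenml.artifacts.artifact_config import ArtifactConfig
from zenml.enums import ModelStages


@step
def my_step() -> Annotated[
    int,
    ArtifactConfig(
        name="my_artifact",  # override the default artifact name
        version=42,  # set a custom version
        tags=["tag1", "tag2"],  # set custom tags
        model_name="my_model",  # link the artifact to a model...
        model_version=ModelStages.LATEST,  # ...which also requires a version
    ),
]:
    return 1
```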
Attributes:
Name | Type | Description |
---|---|---|
name | Optional[str] | The name of the artifact. |
version | Optional[Union[str, int]] | The version of the artifact. |
tags | Optional[List[str]] | The tags of the artifact. |
model_name | Optional[str] | The name of the model to link the artifact to. |
model_version | Optional[Union[ModelStages, str, int]] | The identifier of a version of the model to link the artifact to. It can be an exact version ("my_version"), an exact version number (42), a stage (ModelStages.PRODUCTION or "production"), or ModelStages.LATEST for the latest version. Defaults to None, which is only valid when model_name is also None; the model version is then taken from the step context. |
is_model_artifact | bool | Whether the artifact is a model artifact. |
is_deployment_artifact | bool | Whether the artifact is a deployment artifact. |
Source code in zenml/artifacts/artifact_config.py
````python
class ArtifactConfig(BaseModel):
    """Artifact configuration class.

    Can be used in step definitions to define various artifact properties.

    Example:
    ```python
    @step
    def my_step() -> Annotated[
        int, ArtifactConfig(
            name="my_artifact",  # override the default artifact name
            version=42,  # set a custom version
            tags=["tag1", "tag2"],  # set custom tags
            model_name="my_model",  # link the artifact to a model
        )
    ]:
        return ...
    ```

    Attributes:
        name: The name of the artifact.
        version: The version of the artifact.
        tags: The tags of the artifact.
        model_name: The name of the model to link artifact to.
        model_version: The identifier of a version of the model to link the artifact
            to. It can be an exact version ("my_version"), exact version number
            (42), stage (ModelStages.PRODUCTION or "production"), or
            (ModelStages.LATEST or None) for the latest version (default).
        is_model_artifact: Whether the artifact is a model artifact.
        is_deployment_artifact: Whether the artifact is a deployment artifact.
    """

    name: Optional[str] = None
    version: Optional[Union[str, int]] = Field(
        default=None, union_mode="smart"
    )
    tags: Optional[List[str]] = None
    run_metadata: Optional[Dict[str, MetadataType]] = None

    model_name: Optional[str] = None
    model_version: Optional[Union[ModelStages, str, int]] = Field(
        default=None, union_mode="smart"
    )

    is_model_artifact: bool = False
    is_deployment_artifact: bool = False

    # TODO: In Pydantic v2, the `model_` is a protected namespaces for all
    #  fields defined under base models. If not handled, this raises a warning.
    #  It is possible to suppress this warning message with the following
    #  configuration, however the ultimate solution is to rename these fields.
    #  Even though they do not cause any problems right now, if we are not
    #  careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())

    @model_validator(mode="after")
    def artifact_config_validator(self) -> "ArtifactConfig":
        """Model validator for the artifact config.

        Raises:
            ValueError: If both model_name and model_version is set incorrectly.

        Returns:
            the validated instance.
        """
        if self.model_name is not None and self.model_version is None:
            raise ValueError(
                f"Creation of new model version from {self.__class__.__name__} "
                "is not allowed. Please either keep `model_name` and "
                "`model_version` both `None` to get the model version from the "
                "step context or specify both at the same time. You can use "
                "`ModelStages.LATEST` as `model_version` when latest model "
                "version is desired."
            )

        return self

    @property
    def _model(self) -> Optional["Model"]:
        """The model linked to this artifact.

        Returns:
            The model or None if the model version cannot be determined.
        """
        try:
            model_ = get_step_context().model
        except (StepContextError, RuntimeError):
            model_ = None
        # Check if another model name was specified
        if (self.model_name is not None) and (
            model_ is None or model_.name != self.model_name
        ):
            # Create a new Model instance with the provided model name and version
            from zenml.model.model import Model

            on_the_fly_config = Model(
                name=self.model_name, version=self.model_version
            )
            return on_the_fly_config

        return model_
````
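The `_model` property picks the linked model in a fixed order: the step context's model is used unless `model_name` names a different model (or no context model exists), in which case a new model is built on the fly from `model_name` and `model_version`. The sketch below restates only that branching; `SimpleModel` and `resolve_model` are hypothetical stand-ins, not ZenML classes, and the context lookup is replaced by a plain argument.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class SimpleModel:
    """Hypothetical stand-in for `zenml.model.model.Model`."""

    name: str
    version: Optional[Union[str, int]] = None


def resolve_model(
    context_model: Optional[SimpleModel],
    model_name: Optional[str],
    model_version: Optional[Union[str, int]],
) -> Optional[SimpleModel]:
    """Mirror the branching of `ArtifactConfig._model` (illustrative only)."""
    # A configured model name that differs from the step-context model
    # (or stands in for a missing one) yields a model created on the fly.
    if model_name is not None and (
        context_model is None or context_model.name != model_name
    ):
        return SimpleModel(name=model_name, version=model_version)
    # Otherwise the step-context model is used, which may be None.
    return context_model


step_model = SimpleModel(name="churn_model", version="v2")
print(resolve_model(step_model, "other_model", "production"))
```

Note that a matching `model_name` returns the context model unchanged, even if `model_version` differs from the context model's version.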
artifact_config_validator(self)
Model validator for the artifact config.
Exceptions:
Type | Description |
---|---|
ValueError | If `model_name` is set but `model_version` is `None`. |
Returns:
Type | Description |
---|---|
ArtifactConfig | The validated instance. |
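The validator rejects exactly one combination: a `model_name` without a `model_version`. A minimal stdlib sketch of that rule follows; `check_model_link` is a hypothetical helper, and the string "latest" is assumed here to be an accepted stage value in the same way the docs accept "production".

```python
from typing import Optional, Union


def check_model_link(
    model_name: Optional[str], model_version: Optional[Union[str, int]]
) -> None:
    """Reject the combination that `ArtifactConfig` rejects (sketch)."""
    if model_name is not None and model_version is None:
        raise ValueError(
            "Set `model_version` (for example the stage 'latest') together "
            "with `model_name`, or leave both unset to use the step context."
        )


# These combinations pass the check.
check_model_link(None, None)  # model taken from the step context
check_model_link("my_model", "latest")  # explicit model and stage
check_model_link("my_model", 42)  # explicit model and version number
```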
external_artifact
External artifact definition.
ExternalArtifact (ExternalArtifactConfiguration)
External artifacts can be used to provide values as input to ZenML steps.
ZenML steps accept either artifacts (outputs of other steps), parameters (raw, JSON-serializable values) or external artifacts. External artifacts can be used to provide any value as input to a step without needing to write an additional step that returns this value.
The external artifact needs to either have a value associated with it that will be uploaded to the artifact store, or reference an artifact that is already registered in ZenML.
There are several ways to reference an existing artifact:
- By providing an artifact ID.
- By providing an artifact name and version. If no version is provided, the latest version of that artifact will be used.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Optional[Any] | The artifact value. | None |
id | Optional[UUID] | The ID of an artifact that should be referenced by this external artifact. | None |
materializer | Optional[MaterializerClassOrSource] | The materializer to use for saving the artifact value to the artifact store. Only used when `value` is provided. | None |
store_artifact_metadata | bool | Whether metadata for the artifact should be stored. Only used when `value` is provided. | True |
store_artifact_visualizations | bool | Whether visualizations for the artifact should be stored. Only used when `value` is provided. | True |
Examples:
```python
import numpy as np

from zenml import pipeline, step
from zenml.artifacts.external_artifact import ExternalArtifact


@step
def my_step(value: np.ndarray) -> None:
    print(value)


my_array = np.array([1, 2, 3])


@pipeline
def my_pipeline():
    # Pydantic models only accept keyword arguments, so pass `value=`
    my_step(value=ExternalArtifact(value=my_array))


if __name__ == "__main__":
    my_pipeline()
```
Source code in zenml/artifacts/external_artifact.py
````python
class ExternalArtifact(ExternalArtifactConfiguration):
    """External artifacts can be used to provide values as input to ZenML steps.

    ZenML steps accept either artifacts (=outputs of other steps), parameters
    (raw, JSON serializable values) or external artifacts. External artifacts
    can be used to provide any value as input to a step without needing to
    write an additional step that returns this value.

    The external artifact needs to have either a value associated with it
    that will be uploaded to the artifact store, or reference an artifact
    that is already registered in ZenML.

    There are several ways to reference an existing artifact:
    - By providing an artifact ID.
    - By providing an artifact name and version. If no version is provided,
        the latest version of that artifact will be used.

    Args:
        value: The artifact value.
        id: The ID of an artifact that should be referenced by this external
            artifact.
        materializer: The materializer to use for saving the artifact value
            to the artifact store. Only used when `value` is provided.
        store_artifact_metadata: Whether metadata for the artifact should
            be stored. Only used when `value` is provided.
        store_artifact_visualizations: Whether visualizations for the
            artifact should be stored. Only used when `value` is provided.

    Example:
    ```
    from zenml import step, pipeline
    from zenml.artifacts.external_artifact import ExternalArtifact
    import numpy as np

    @step
    def my_step(value: np.ndarray) -> None:
      print(value)

    my_array = np.array([1, 2, 3])

    @pipeline
    def my_pipeline():
      my_step(value=ExternalArtifact(my_array))
    ```
    """

    value: Optional[Any] = None
    materializer: Optional[MaterializerClassOrSource] = Field(
        default=None, union_mode="left_to_right"
    )
    store_artifact_metadata: bool = True
    store_artifact_visualizations: bool = True

    @model_validator(mode="after")
    def external_artifact_validator(self) -> "ExternalArtifact":
        """Model validator for the external artifact.

        Raises:
            ValueError: if the value, id and name fields are set incorrectly.

        Returns:
            the validated instance.
        """
        deprecation_msg = (
            "Parameter `{param}` or `ExternalArtifact` will be deprecated "
            "in upcoming releases. Please use `{substitute}` instead."
        )
        for param, substitute in [
            ["id", "Client().get_artifact_version(name_id_or_prefix=<id>)"],
            [
                "name",
                "Client().get_artifact_version(name_id_or_prefix=<name>)",
            ],
            [
                "version",
                "Client().get_artifact_version(name_id_or_prefix=<name>,version=<version>)",
            ],
            [
                "model",
                "Client().get_model_version(<model_name>,<model_version>).get_artifact(name)",
            ],
        ]:
            if getattr(self, param, None):
                logger.warning(
                    deprecation_msg.format(
                        param=param,
                        substitute=substitute,
                    )
                )

        options = [
            getattr(self, field, None) is not None
            for field in ["value", "id", "name"]
        ]
        if sum(options) > 1:
            raise ValueError(
                "Only one of `value`, `id`, or `name` can be provided when "
                "creating an external artifact."
            )
        elif sum(options) == 0:
            raise ValueError(
                "Either `value`, `id`, or `name` must be provided when "
                "creating an external artifact."
            )
        return self

    def upload_by_value(self) -> UUID:
        """Uploads the artifact by value.

        Returns:
            The uploaded artifact ID.
        """
        from zenml.artifacts.utils import save_artifact

        artifact_name = f"external_{uuid4()}"
        uri = os.path.join("external_artifacts", artifact_name)
        logger.info("Uploading external artifact to '%s'.", uri)

        artifact = save_artifact(
            name=artifact_name,
            data=self.value,
            extract_metadata=self.store_artifact_metadata,
            include_visualizations=self.store_artifact_visualizations,
            materializer=self.materializer,
            uri=uri,
            has_custom_name=False,
            manual_save=False,
        )

        # To avoid duplicate uploads, switch to referencing the uploaded
        # artifact by ID
        self.id = artifact.id
        self.value = None

        logger.info("Finished uploading external artifact %s.", self.id)
        return self.id

    @property
    def config(self) -> ExternalArtifactConfiguration:
        """Returns the lightweight config without hard for JSON properties.

        Returns:
            The config object to be evaluated in runtime by step interface.
        """
        return ExternalArtifactConfiguration(
            id=self.id,
            name=self.name,
            version=self.version,
            model=self.model,
        )
````
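config: ExternalArtifactConfiguration (read-only property)
Returns a lightweight configuration that keeps only the reference fields (`id`, `name`, `version` and `model`) and leaves out properties that are hard to serialize to JSON, such as `value` and `materializer`.
Returns:
Type | Description |
---|---|
ExternalArtifactConfiguration | The config object to be evaluated at runtime by the step interface. |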
external_artifact_validator(self)
Model validator for the external artifact. It also logs a deprecation warning when `id`, `name`, `version` or `model` is set.
Exceptions:
Type | Description |
---|---|
ValueError | If none, or more than one, of `value`, `id` and `name` are set. |
Returns:
Type | Description |
---|---|
ExternalArtifact | The validated instance. |
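The validator's counting rule can be sketched without ZenML: exactly one of the three identifying fields must be set, and "set" means "not `None`", so falsy values such as `0` or an empty list still count. `identifying_field` is a hypothetical helper; it takes `artifact_id` instead of `id` only to avoid shadowing the Python builtin.

```python
from typing import Any, Optional
from uuid import UUID


def identifying_field(
    value: Optional[Any] = None,
    artifact_id: Optional[UUID] = None,
    name: Optional[str] = None,
) -> str:
    """Return the single field that identifies the artifact (sketch)."""
    # Like the validator, count fields that are not None, so falsy values
    # such as 0 or an empty list still count as provided.
    provided = [
        field
        for field, val in (("value", value), ("id", artifact_id), ("name", name))
        if val is not None
    ]
    if len(provided) > 1:
        raise ValueError("Only one of `value`, `id`, or `name` can be provided.")
    if not provided:
        raise ValueError("Either `value`, `id`, or `name` must be provided.")
    return provided[0]
```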
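upload_by_value(self)
Uploads the artifact by value. Afterwards, the external artifact references the uploaded artifact by its ID and its `value` is reset to `None`, which avoids duplicate uploads.
Returns:
Type | Description |
---|---|
UUID | The uploaded artifact ID. |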
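The switch from value to ID in `upload_by_value` is what keeps repeated resolution from uploading the same value twice. The sketch below models that idea with a hypothetical in-memory store; `InMemoryStore`, `ExternalValue` and `resolve_id` are illustrative names, not ZenML APIs, and only the `external_<uuid4>` naming under `external_artifacts/` is taken from the source above.

```python
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
from uuid import UUID, uuid4


@dataclass
class InMemoryStore:
    """Hypothetical artifact store: maps artifact IDs to (uri, data)."""

    saved: Dict[UUID, Tuple[str, Any]] = field(default_factory=dict)

    def save(self, uri: str, data: Any) -> UUID:
        artifact_id = uuid4()
        self.saved[artifact_id] = (uri, data)
        return artifact_id


@dataclass
class ExternalValue:
    """Sketch of ExternalArtifact's value-to-ID switch (not ZenML code)."""

    value: Optional[Any] = None
    id: Optional[UUID] = None

    def upload_by_value(self, store: InMemoryStore) -> UUID:
        # Same naming scheme as the source: external_<uuid4> in external_artifacts/
        name = f"external_{uuid4()}"
        uri = os.path.join("external_artifacts", name)
        self.id = store.save(uri, self.value)
        # Drop the value so that later resolution only uses the ID.
        self.value = None
        return self.id

    def resolve_id(self, store: InMemoryStore) -> UUID:
        # Upload only while the artifact is still held by value.
        if self.id is None:
            return self.upload_by_value(store)
        return self.id
```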
external_artifact_config
External artifact definition.
ExternalArtifactConfiguration (BaseModel)
External artifact configuration.
Lightweight class that is passed to steps so that the referenced artifact can be resolved at runtime.
Source code in zenml/artifacts/external_artifact_config.py
```python
class ExternalArtifactConfiguration(BaseModel):
    """External artifact configuration.

    Lightweight class to pass in the steps for runtime inference.
    """

    id: Optional[UUID] = None
    name: Optional[str] = None
    version: Optional[str] = None
    model: Optional[Model] = None

    @model_validator(mode="after")
    def external_artifact_validator(self) -> "ExternalArtifactConfiguration":
        """Model validator for the external artifact configuration.

        Raises:
            ValueError: if both version and model fields are set.

        Returns:
            the validated instance.
        """
        if self.version and self.model:
            raise ValueError(
                "Cannot provide both `version` and `model` when "
                "creating an external artifact."
            )

        return self

    def get_artifact_version_id(self) -> UUID:
        """Get the artifact.

        Returns:
            The artifact ID.

        Raises:
            RuntimeError: If the artifact store of the referenced artifact
                is not the same as the one in the active stack.
            RuntimeError: If neither the ID nor the name of the artifact was
                provided.
        """
        from zenml.client import Client

        client = Client()

        if self.id:
            response = client.get_artifact_version(self.id)
        elif self.name:
            if self.version:
                response = client.get_artifact_version(
                    self.name, version=self.version
                )
            elif self.model:
                response_ = self.model.get_artifact(self.name)
                if not isinstance(response_, ArtifactVersionResponse):
                    raise RuntimeError(
                        f"Failed to pull artifact `{self.name}` from the Model "
                        f"(name=`{self.model.name}`, version="
                        f"`{self.model.version}`). Please validate the "
                        "input and try again."
                    )
                response = response_
            else:
                response = client.get_artifact_version(self.name)
        else:
            raise RuntimeError(
                "Either the ID or name of the artifact must be provided. "
                "If you created this ExternalArtifact from a value, please "
                "ensure that `upload_by_value` was called before trying to "
                "fetch the artifact ID."
            )

        artifact_store_id = client.active_stack.artifact_store.id
        if response.artifact_store_id != artifact_store_id:
            raise RuntimeError(
                f"The artifact {response.name} (ID: {response.id}) "
                "referenced by an external artifact is not stored in the "
                "artifact store of the active stack. This will lead to "
                "issues loading the artifact. Please make sure to only "
                "reference artifact versions stored in your active artifact "
                "store."
            )

        self.id = response.id

        return self.id
```
external_artifact_validator(self)
Model validator for the external artifact configuration.
Exceptions:
Type | Description |
---|---|
ValueError | If both `version` and `model` are set. |
Returns:
Type | Description |
---|---|
ExternalArtifactConfiguration | The validated instance. |
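get_artifact_version_id(self)
Get the ID of the referenced artifact version. The version is looked up by ID; otherwise by name and version, by name within the linked model, or by name alone (latest version), in that order. The resolved ID is stored on the configuration.
Returns:
Type | Description |
---|---|
UUID | The artifact ID. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the artifact store of the referenced artifact is not the same as the one in the active stack. |
RuntimeError | If neither the ID nor the name of the artifact was provided. |
RuntimeError | If the artifact cannot be retrieved from the linked model. |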
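The lookup order of `get_artifact_version_id`, together with the check against the active artifact store, can be sketched over a plain list of records. Everything here is an assumption for illustration: `VersionRecord`, `resolve_version_id` and the `model_artifacts` dictionary (standing in for `Model.get_artifact`) are hypothetical, and "latest" is approximated as the last matching record in the list.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
from uuid import UUID, uuid4


@dataclass
class VersionRecord:
    """Hypothetical, simplified artifact version record."""

    id: UUID
    name: str
    version: str
    artifact_store_id: UUID


def _first(
    records: List[VersionRecord], match: Callable[[VersionRecord], bool]
) -> VersionRecord:
    for record in records:
        if match(record):
            return record
    raise KeyError("No matching artifact version.")


def resolve_version_id(
    records: List[VersionRecord],
    active_store_id: UUID,
    artifact_id: Optional[UUID] = None,
    name: Optional[str] = None,
    version: Optional[str] = None,
    model_artifacts: Optional[Dict[str, VersionRecord]] = None,
) -> UUID:
    """Sketch of the lookup order of `get_artifact_version_id`."""
    if version and model_artifacts is not None:
        # Mirrors the configuration validator.
        raise ValueError("Cannot provide both `version` and `model`.")
    if artifact_id:
        found = _first(records, lambda r: r.id == artifact_id)
    elif name:
        if version:
            found = _first(records, lambda r: r.name == name and r.version == version)
        elif model_artifacts is not None:
            if name not in model_artifacts:
                raise RuntimeError(f"Failed to pull artifact `{name}` from the model.")
            found = model_artifacts[name]
        else:
            # Assumption: the last matching record is the latest version.
            found = _first(list(reversed(records)), lambda r: r.name == name)
    else:
        raise RuntimeError("Either the ID or name of the artifact must be provided.")
    # Only artifacts in the active stack's artifact store can be loaded.
    if found.artifact_store_id != active_store_id:
        raise RuntimeError(f"Artifact `{found.name}` is not in the active artifact store.")
    return found.id


store_id = uuid4()
v1 = VersionRecord(uuid4(), "features", "1", store_id)
v2 = VersionRecord(uuid4(), "features", "2", store_id)
print(resolve_version_id([v1, v2], store_id, name="features") == v2.id)
```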
unmaterialized_artifact
Unmaterialized artifact class.
UnmaterializedArtifact (ArtifactVersionResponse)
Unmaterialized artifact class.
Typing a step input to have this type will cause ZenML to not materialize the artifact. This is useful for steps that need to access the artifact metadata instead of the actual artifact data.
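Usage example:
```python
from zenml import step
from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact


@step
def my_step(input_artifact: UnmaterializedArtifact) -> None:
    # Only artifact metadata such as the URI is available; the data is not loaded
    print(input_artifact.uri)
```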
Source code in zenml/artifacts/unmaterialized_artifact.py
````python
class UnmaterializedArtifact(ArtifactVersionResponse):
    """Unmaterialized artifact class.

    Typing a step input to have this type will cause ZenML to not materialize
    the artifact. This is useful for steps that need to access the artifact
    metadata instead of the actual artifact data.

    Usage example:

    ```python
    from zenml import step
    from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact

    @step
    def my_step(input_artifact: UnmaterializedArtifact):
        print(input_artifact.uri)
    ```
    """
````
model_post_init(self, context, /)
Initializes the model's private attributes and then calls the user-defined `model_post_init` method.
Source code in zenml/artifacts/unmaterialized_artifact.py
```python
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """We need to both initialize private attributes and call the user-defined model_post_init
    method.
    """
    init_private_attributes(self, context)
    original_model_post_init(self, context)
```
utils
Utility functions for handling artifacts.
download_artifact_files_from_response(artifact, path, overwrite=False)
Download the files of the given artifact into a ZIP file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact | ArtifactVersionResponse | The artifact to download. | required |
path | str | The path to which to download the artifact. | required |
overwrite | bool | Whether to overwrite the file if it already exists. | False |
Exceptions:
Type | Description |
---|---|
FileExistsError | If the file already exists and `overwrite` is `False`. |
Exception | If the artifact could not be downloaded to the zip file. |
Source code in zenml/artifacts/utils.py
```python
def download_artifact_files_from_response(
    artifact: "ArtifactVersionResponse",
    path: str,
    overwrite: bool = False,
) -> None:
    """Download the given artifact into a file.

    Args:
        artifact: The artifact to download.
        path: The path to which to download the artifact.
        overwrite: Whether to overwrite the file if it already exists.

    Raises:
        FileExistsError: If the file already exists and `overwrite` is `False`.
        Exception: If the artifact could not be downloaded to the zip file.
    """
    if not overwrite and fileio.exists(path):
        raise FileExistsError(
            f"File '{path}' already exists and `overwrite` is set to `False`."
        )

    artifact_store = _get_artifact_store_from_response_or_from_active_stack(
        artifact=artifact
    )

    if filepaths := artifact_store.listdir(artifact.uri):
        # save a zipfile to 'path' containing all the files
        # in 'filepaths' with compression
        try:
            with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zipf:
                for file in filepaths:
                    # Ensure 'file' is a string for path operations
                    # and ZIP entry naming
                    file_str = (
                        file.decode() if isinstance(file, bytes) else file
                    )
                    file_path = str(Path(artifact.uri) / file_str)
                    with artifact_store.open(
                        name=file_path, mode="rb"
                    ) as store_file, zipf.open(file_str, "w") as zip_entry:
                        # Stream the file in chunks into a single ZIP entry
                        # instead of reading the entire file into memory
                        CHUNK_SIZE = 8192
                        while file_content := store_file.read(CHUNK_SIZE):
                            zip_entry.write(file_content)
        except Exception as e:
            logger.error(
                f"Failed to save artifact '{artifact.id}' to zip file "
                f" '{path}': {e}"
            )
            raise
```
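The chunked copy used when zipping an artifact can be tried locally with the standard library. This sketch assumes a local directory in place of an artifact store and a hypothetical `zip_files_chunked` helper; it opens one ZIP entry per file with `ZipFile.open(name, "w")` and appends every chunk to that same entry, so files larger than one chunk are stored intact.

```python
import os
import zipfile

CHUNK_SIZE = 8192  # same chunk size as the function above


def zip_files_chunked(src_dir: str, zip_path: str, overwrite: bool = False) -> None:
    """Write the top-level files of `src_dir` into a ZIP file, chunk by chunk."""
    if not overwrite and os.path.exists(zip_path):
        raise FileExistsError(
            f"File '{zip_path}' already exists and `overwrite` is set to `False`."
        )
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for file_name in sorted(os.listdir(src_dir)):
            file_path = os.path.join(src_dir, file_name)
            if not os.path.isfile(file_path):
                continue  # this sketch skips sub-directories
            # Open one ZIP entry per file and append every chunk to it.
            with open(file_path, "rb") as src, zipf.open(file_name, "w") as entry:
                while chunk := src.read(CHUNK_SIZE):
                    entry.write(chunk)
```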
get_artifacts_versions_of_pipeline_run(pipeline_run, only_produced=False)
Get all artifact versions produced during a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run | PipelineRunResponse | The pipeline run. | required |
only_produced | bool | Whether to return only the artifact versions produced by the pipeline run, or also cached artifact versions. | False |
Returns:
Type | Description |
---|---|
List[ArtifactVersionResponse] | A list of all artifact versions produced during the pipeline run. |
Source code in zenml/artifacts/utils.py
```python
def get_artifacts_versions_of_pipeline_run(
    pipeline_run: "PipelineRunResponse", only_produced: bool = False
) -> List["ArtifactVersionResponse"]:
    """Get all artifact versions produced during a pipeline run.

    Args:
        pipeline_run: The pipeline run.
        only_produced: If only artifact versions produced by the pipeline run
            should be returned or also cached artifact versions.

    Returns:
        A list of all artifact versions produced during the pipeline run.
    """
    artifact_versions: List["ArtifactVersionResponse"] = []
    for step in pipeline_run.steps.values():
        if not only_produced or step.status == ExecutionStatus.COMPLETED:
            artifact_versions.extend(step.outputs.values())
    return artifact_versions
```
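The `only_produced` filter keeps outputs of steps that actually ran to completion and drops those of steps that were, for example, cached. A self-contained sketch with hypothetical stand-ins (`Status`, `StepRun`, and plain strings in place of artifact version responses):

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List


class Status(Enum):
    """Hypothetical subset of execution statuses used by this sketch."""

    COMPLETED = "completed"
    CACHED = "cached"


@dataclass
class StepRun:
    """Hypothetical step run: a status plus named output versions."""

    status: Status
    outputs: Dict[str, str] = field(default_factory=dict)


def artifact_versions_of_run(
    steps: Dict[str, StepRun], only_produced: bool = False
) -> List[str]:
    """Collect step outputs; with only_produced, keep completed steps only."""
    versions: List[str] = []
    for step in steps.values():
        if not only_produced or step.status == Status.COMPLETED:
            versions.extend(step.outputs.values())
    return versions
```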
get_producer_step_of_artifact(artifact)
Get the step run that produced a given artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact | ArtifactVersionResponse | The artifact. | required |
Returns:
Type | Description |
---|---|
StepRunResponse | The step run that produced the artifact. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the run that created the artifact no longer exists. |
Source code in zenml/artifacts/utils.py
```python
def get_producer_step_of_artifact(
    artifact: "ArtifactVersionResponse",
) -> "StepRunResponse":
    """Get the step run that produced a given artifact.

    Args:
        artifact: The artifact.

    Returns:
        The step run that produced the artifact.

    Raises:
        RuntimeError: If the run that created the artifact no longer exists.
    """
    if not artifact.producer_step_run_id:
        raise RuntimeError(
            f"The run that produced the artifact with id '{artifact.id}' no "
            "longer exists. This can happen if the run was deleted."
        )
    return Client().get_run_step(artifact.producer_step_run_id)
```
load_artifact(name_or_id, version=None)
Load an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name_or_id | Union[str, uuid.UUID] | The name or ID of the artifact to load. | required |
version | Optional[str] | The version of the artifact to load, if `name_or_id` is a name. If not provided, the latest version will be loaded. | None |
Returns:
Type | Description |
---|---|
Any | The loaded artifact. |
Source code in zenml/artifacts/utils.py
```python
def load_artifact(
    name_or_id: Union[str, UUID],
    version: Optional[str] = None,
) -> Any:
    """Load an artifact.

    Args:
        name_or_id: The name or ID of the artifact to load.
        version: The version of the artifact to load, if `name_or_id` is a
            name. If not provided, the latest version will be loaded.

    Returns:
        The loaded artifact.
    """
    artifact = Client().get_artifact_version(name_or_id, version)
    try:
        step_run = get_step_context().step_run
        client = Client()
        client.zen_store.update_run_step(
            step_run_id=step_run.id,
            step_run_update=StepRunUpdate(
                loaded_artifact_versions={artifact.name: artifact.id}
            ),
        )
    except RuntimeError:
        pass  # Cannot link to step run if called outside of a step
    return load_artifact_from_response(artifact)
```
load_artifact_from_response(artifact)
Load the given artifact into memory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact | ArtifactVersionResponse | The artifact to load. | required |
Returns:
Type | Description |
---|---|
Any | The artifact loaded into memory. |
Source code in zenml/artifacts/utils.py
```python
def load_artifact_from_response(artifact: "ArtifactVersionResponse") -> Any:
    """Load the given artifact into memory.

    Args:
        artifact: The artifact to load.

    Returns:
        The artifact loaded into memory.
    """
    artifact_store = _get_artifact_store_from_response_or_from_active_stack(
        artifact=artifact
    )
    return _load_artifact_from_uri(
        materializer=artifact.materializer,
        data_type=artifact.data_type,
        uri=artifact.uri,
        artifact_store=artifact_store,
    )
```
load_artifact_visualization(artifact, index=0, zen_store=None, encode_image=False)
Load a visualization of the given artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact | ArtifactVersionResponse | The artifact to visualize. | required |
index | int | The index of the visualization to load. | 0 |
zen_store | Optional[BaseZenStore] | The ZenStore to use for finding the artifact store. If not provided, the client's ZenStore will be used. | None |
encode_image | bool | Whether to base64 encode image visualizations. | False |
Returns:
Type | Description |
---|---|
LoadedVisualization | The loaded visualization. |
Exceptions:
Type | Description |
---|---|
DoesNotExistException | If the artifact does not have the requested visualization or if the visualization was not found in the artifact store. |
Source code in zenml/artifacts/utils.py
```python
def load_artifact_visualization(
    artifact: "ArtifactVersionResponse",
    index: int = 0,
    zen_store: Optional["BaseZenStore"] = None,
    encode_image: bool = False,
) -> LoadedVisualization:
    """Load a visualization of the given artifact.

    Args:
        artifact: The artifact to visualize.
        index: The index of the visualization to load.
        zen_store: The ZenStore to use for finding the artifact store. If not
            provided, the client's ZenStore will be used.
        encode_image: Whether to base64 encode image visualizations.

    Returns:
        The loaded visualization.

    Raises:
        DoesNotExistException: If the artifact does not have the requested
            visualization or if the visualization was not found in the artifact
            store.
    """
    # Get the visualization to load
    if not artifact.visualizations:
        raise DoesNotExistException(
            f"Artifact '{artifact.id}' has no visualizations."
        )
    if index < 0 or index >= len(artifact.visualizations):
        raise DoesNotExistException(
            f"Artifact '{artifact.id}' only has {len(artifact.visualizations)} "
            f"visualizations, but index {index} was requested."
        )
    visualization = artifact.visualizations[index]

    # Load the visualization from the artifact's artifact store
    if not artifact.artifact_store_id:
        raise DoesNotExistException(
            f"Artifact '{artifact.id}' cannot be visualized because the "
            "underlying artifact store was deleted."
        )
    artifact_store = _load_artifact_store(
        artifact_store_id=artifact.artifact_store_id, zen_store=zen_store
    )
    try:
        mode = "rb" if visualization.type == VisualizationType.IMAGE else "r"
        value = _load_file_from_artifact_store(
            uri=visualization.uri,
            artifact_store=artifact_store,
            mode=mode,
        )

        # Encode image visualizations if requested
        if visualization.type == VisualizationType.IMAGE and encode_image:
            value = base64.b64encode(bytes(value))

        return LoadedVisualization(type=visualization.type, value=value)
    finally:
        artifact_store.cleanup()
```
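The index check and optional base64 encoding can be illustrated on in-memory values. In this sketch, `pick_visualization` is hypothetical, `LookupError` stands in for `DoesNotExistException`, and bytes are assumed to represent image visualizations (read in `"rb"` mode) while strings represent text-based ones (read in `"r"` mode).

```python
import base64
from typing import List, Union


def pick_visualization(
    visualizations: List[Union[str, bytes]],
    index: int = 0,
    encode_image: bool = False,
) -> Union[str, bytes]:
    """Bounds-check `index` and optionally base64-encode binary content (sketch)."""
    if not visualizations:
        raise LookupError("The artifact has no visualizations.")
    if index < 0 or index >= len(visualizations):
        raise LookupError(
            f"The artifact only has {len(visualizations)} visualizations, "
            f"but index {index} was requested."
        )
    value = visualizations[index]
    # Assumption: bytes stand for image visualizations, str for text ones.
    if isinstance(value, bytes) and encode_image:
        value = base64.b64encode(value)
    return value
```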
load_model_from_metadata(model_uri)
Load a ZenML model artifact from its YAML metadata file.
This function reads the YAML metadata file created by the `save_model_metadata` function and uses the data type and materializer recorded there to load the model into memory in the inference environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_uri | str | The URI of the model artifact whose metadata is read. | required |
Returns:
Type | Description |
---|---|
Any | The ML model object loaded into memory. |
Source code in zenml/artifacts/utils.py
```python
def load_model_from_metadata(model_uri: str) -> Any:
    """Load a zenml model artifact from a json file.

    This function is used to load information from a Yaml file that was created
    by the save_model_metadata function. The information in the Yaml file is
    used to load the model into memory in the inference environment.

    Args:
        model_uri: the artifact to extract the metadata from.

    Returns:
        The ML model object loaded into memory.
    """
    # Load the model from its metadata
    artifact_versions_by_uri = Client().list_artifact_versions(uri=model_uri)
    if artifact_versions_by_uri.total == 1:
        artifact_store = (
            _get_artifact_store_from_response_or_from_active_stack(
                artifact_versions_by_uri.items[0]
            )
        )
    else:
        artifact_store = Client().active_stack.artifact_store

    with artifact_store.open(
        os.path.join(model_uri, MODEL_METADATA_YAML_FILE_NAME), "r"
    ) as f:
        metadata = read_yaml(f.name)
    data_type = metadata["datatype"]
    materializer = metadata["materializer"]
    model = _load_artifact_from_uri(
        materializer=materializer,
        data_type=data_type,
        uri=model_uri,
        artifact_store=artifact_store,
    )

    # Switch to eval mode if the model is a torch model
    try:
        import torch.nn as nn

        if isinstance(model, nn.Module):
            model.eval()
    except ImportError:
        pass

    return model
```
log_artifact_metadata(metadata, artifact_name=None, artifact_version=None)
Log artifact metadata.
This function can be used to log metadata for either existing artifact versions or artifact versions that are newly created in the same step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata | Dict[str, MetadataType] | The metadata to log. | required |
artifact_name | Optional[str] | The name of the artifact to log metadata for. Can be omitted when being called inside a step with only one output. | None |
artifact_version | Optional[str] | The version of the artifact to log metadata for. If not provided when called inside a step that produces an artifact named `artifact_name`, the metadata is associated with the corresponding newly created artifact. If not provided when called outside of a step, or in a step that does not produce any artifact named `artifact_name`, the metadata is associated with the latest version of that artifact. | None |
Exceptions:
Type | Description |
---|---|
ValueError | If no artifact name is provided and the function is not called inside a step with a single output, or if neither an artifact nor an output with the given name exists. |
Source code in zenml/artifacts/utils.py
```python
def log_artifact_metadata(
    metadata: Dict[str, "MetadataType"],
    artifact_name: Optional[str] = None,
    artifact_version: Optional[str] = None,
) -> None:
    """Log artifact metadata.

    This function can be used to log metadata for either existing artifact
    versions or artifact versions that are newly created in the same step.

    Args:
        metadata: The metadata to log.
        artifact_name: The name of the artifact to log metadata for. Can
            be omitted when being called inside a step with only one output.
        artifact_version: The version of the artifact to log metadata for. If
            not provided, when being called inside a step that produces an
            artifact named `artifact_name`, the metadata will be associated to
            the corresponding newly created artifact. Or, if not provided when
            being called outside of a step, or in a step that does not produce
            any artifact named `artifact_name`, the metadata will be associated
            to the latest version of that artifact.

    Raises:
        ValueError: If no artifact name is provided and the function is not
            called inside a step with a single output, or, if neither an
            artifact nor an output with the given name exists.
    """
    try:
        step_context = get_step_context()
        in_step_outputs = (artifact_name in step_context._outputs) or (
            not artifact_name and len(step_context._outputs) == 1
        )
    except RuntimeError:
        step_context = None
        in_step_outputs = False

    if not step_context or not in_step_outputs or artifact_version:
        if not artifact_name:
            raise ValueError(
                "Artifact name must be provided unless the function is called "
                "inside a step with a single output."
            )
        client = Client()
        response = client.get_artifact_version(artifact_name, artifact_version)
        client.create_run_metadata(
            metadata=metadata,
            resource_id=response.id,
            resource_type=MetadataResourceTypes.ARTIFACT_VERSION,
        )
    else:
        try:
            step_context.add_output_metadata(
                metadata=metadata, output_name=artifact_name
            )
        except StepContextError as e:
            raise ValueError(e)
```
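The branching at the end of `log_artifact_metadata` decides where the metadata goes. It is attached either to a step output that is still being produced or to an existing artifact version fetched through the client. The decision can be restated as a pure function, shown below as a minimal sketch. `metadata_target` is a hypothetical name. `step_outputs` stands in for the step context's output names, with `None` meaning the call happens outside a step.

```python
from typing import Optional, Set


def metadata_target(
    artifact_name: Optional[str],
    artifact_version: Optional[str],
    step_outputs: Optional[Set[str]],
) -> str:
    """Return where log_artifact_metadata would attach metadata.

    Hypothetical restatement of the branching logic:
    - "step_output": attached to an output of the currently running step.
    - "artifact_version": attached to an existing version via the client
      (the latest one when artifact_version is None).
    """
    in_step_outputs = step_outputs is not None and (
        artifact_name in step_outputs
        or (not artifact_name and len(step_outputs) == 1)
    )
    if step_outputs is None or not in_step_outputs or artifact_version:
        if not artifact_name:
            raise ValueError(
                "Artifact name must be provided unless the function is "
                "called inside a step with a single output."
            )
        return "artifact_version"
    return "step_output"
```

One consequence worth noting: passing `artifact_version` always routes to an existing version. This holds even inside a step that produces an output with the same name.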
save_artifact(data, name, version=None, tags=None, extract_metadata=True, include_visualizations=True, has_custom_name=True, user_metadata=None, materializer=None, uri=None, is_model_artifact=False, is_deployment_artifact=False, manual_save=True)
Upload and publish an artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the artifact. | required |
data | Any | The artifact data. | required |
version | Optional[Union[int, str]] | The version of the artifact. If not provided, a new auto-incremented version will be used. | None |
tags | Optional[List[str]] | Tags to associate with the artifact. | None |
extract_metadata | bool | Whether artifact metadata should be extracted and stored with the artifact version. | True |
include_visualizations | bool | Whether artifact visualizations should be generated. | True |
has_custom_name | bool | Whether the artifact name is custom and should be listed in the dashboard "Artifacts" tab. | True |
user_metadata | Optional[Dict[str, MetadataType]] | User-provided metadata to store with the artifact. Only stored when `extract_metadata` is True. | None |
materializer | Optional[MaterializerClassOrSource] | The materializer to use for saving the artifact to the artifact store. If not provided, one is looked up by the type of `data`. | None |
uri | Optional[str] | The URI within the artifact store to upload the artifact to. If not provided, the artifact will be uploaded to `custom_artifacts/{name}/{uuid}`, where `{uuid}` is a random UUID. | None |
is_model_artifact | bool | Whether the artifact is a model artifact. | False |
is_deployment_artifact | bool | Whether the artifact is a deployment artifact. | False |
manual_save | bool | Whether this function is called manually and should therefore link the artifact to the current step run. | True |
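The `materializer` argument accepts either a materializer class or a source reference to one. When it is omitted, `save_artifact` looks one up in `materializer_registry` by `type(data)`. The sketch below mirrors that three-way choice under stated assumptions. `BaseMaterializer`, `StrMaterializer`, and `REGISTRY` are hypothetical stand-ins. A dotted import path plays the role of ZenML's source loading through `source_utils.load_and_validate_class`.

```python
import importlib
from typing import Any, Dict, Optional, Type, Union


class BaseMaterializer:
    """Hypothetical stand-in for ZenML's BaseMaterializer."""

    def __init__(self, uri: str) -> None:
        self.uri = uri


class StrMaterializer(BaseMaterializer):
    """Hypothetical materializer registered for `str` data."""


# Hypothetical registry keyed by exact data type; ZenML's real registry
# may resolve types more leniently than this plain dict lookup.
REGISTRY: Dict[type, Type[BaseMaterializer]] = {str: StrMaterializer}


def pick_materializer(
    data: Any,
    materializer: Optional[Union[Type[BaseMaterializer], str]] = None,
) -> Type[BaseMaterializer]:
    """Choose a materializer class the way save_artifact does."""
    if isinstance(materializer, type):
        # 1. An explicit class is used as-is.
        return materializer
    if materializer:
        # 2. A source reference is imported and validated.
        module_name, _, attr = materializer.rpartition(".")
        cls = getattr(importlib.import_module(module_name), attr)
        if not (isinstance(cls, type) and issubclass(cls, BaseMaterializer)):
            raise TypeError(f"{materializer!r} is not a materializer class.")
        return cls
    # 3. Otherwise fall back to a lookup by the data's type.
    return REGISTRY[type(data)]
```

An unregistered type with no explicit materializer raises `KeyError` in this sketch. That makes the failure visible at save time rather than at load time.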
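Returns:
Type | Description |
---|---|
ArtifactVersionResponse | The saved artifact version. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the artifact URI is already used by another artifact version (only checked when `manual_save` is True). |
EntityExistsError | If the given artifact version already exists, or if no new version could be created within the maximum number of retries. |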
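The `RuntimeError` above comes from the URI handling at the start of `save_artifact`. Without an explicit `uri`, a fresh directory named after a random UUID is used under `custom_artifacts/<name>/`. Relative URIs are joined onto the artifact store root. For manual saves, a URI already used by another artifact version is rejected. Below is a minimal sketch of those rules, with an in-memory dict `used_by` standing in for the artifact version lookup; `build_artifact_uri` and `check_uri_unused` are hypothetical helpers.

```python
import os
import uuid
from typing import Dict, Optional


def build_artifact_uri(
    store_path: str, name: str, uri: Optional[str] = None
) -> str:
    """Mirror the default-URI and store-root rules of save_artifact."""
    if not uri:
        # Default: a new random directory per save, so repeated saves of
        # the same artifact name get distinct locations.
        uri = os.path.join("custom_artifacts", name, str(uuid.uuid4()))
    # Plain string-prefix check, as in the original: anything not already
    # starting with the store path is treated as relative to it.
    if not uri.startswith(store_path):
        uri = os.path.join(store_path, uri)
    return uri


def check_uri_unused(uri: str, used_by: Dict[str, str]) -> None:
    """Raise if another artifact version already owns this URI."""
    owner = used_by.get(uri)
    if owner is not None:
        raise RuntimeError(f"Cannot save to {uri}: already used by {owner}.")
```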
Source code in zenml/artifacts/utils.py
```python
def save_artifact(
    data: Any,
    name: str,
    version: Optional[Union[int, str]] = None,
    tags: Optional[List[str]] = None,
    extract_metadata: bool = True,
    include_visualizations: bool = True,
    has_custom_name: bool = True,
    user_metadata: Optional[Dict[str, "MetadataType"]] = None,
    materializer: Optional["MaterializerClassOrSource"] = None,
    uri: Optional[str] = None,
    is_model_artifact: bool = False,
    is_deployment_artifact: bool = False,
    manual_save: bool = True,
) -> "ArtifactVersionResponse":
    """Upload and publish an artifact.

    Args:
        name: The name of the artifact.
        data: The artifact data.
        version: The version of the artifact. If not provided, a new
            auto-incremented version will be used.
        tags: Tags to associate with the artifact.
        extract_metadata: If artifact metadata should be extracted and returned.
        include_visualizations: If artifact visualizations should be generated.
        has_custom_name: If the artifact name is custom and should be listed in
            the dashboard "Artifacts" tab.
        user_metadata: User-provided metadata to store with the artifact.
        materializer: The materializer to use for saving the artifact to the
            artifact store.
        uri: The URI within the artifact store to upload the artifact
            to. If not provided, the artifact will be uploaded to
            `custom_artifacts/{name}/{uuid}`, where `{uuid}` is a random UUID.
        is_model_artifact: If the artifact is a model artifact.
        is_deployment_artifact: If the artifact is a deployment artifact.
        manual_save: If this function is called manually and should therefore
            link the artifact to the current step run.

    Returns:
        The saved artifact response.

    Raises:
        RuntimeError: If artifact URI already exists.
        EntityExistsError: If artifact version already exists.
    """
    from zenml.materializers.materializer_registry import (
        materializer_registry,
    )
    from zenml.utils import source_utils

    client = Client()

    # Get or create the artifact
    try:
        artifact = client.list_artifacts(name=name)[0]
        if artifact.has_custom_name != has_custom_name:
            client.update_artifact(
                name_id_or_prefix=artifact.id, has_custom_name=has_custom_name
            )
    except IndexError:
        try:
            artifact = client.zen_store.create_artifact(
                ArtifactRequest(
                    name=name,
                    has_custom_name=has_custom_name,
                    tags=tags,
                )
            )
        except EntityExistsError:
            artifact = client.list_artifacts(name=name)[0]

    # Get the current artifact store
    artifact_store = client.active_stack.artifact_store

    # Build and check the artifact URI
    if not uri:
        uri = os.path.join("custom_artifacts", name, str(uuid4()))
    if not uri.startswith(artifact_store.path):
        uri = os.path.join(artifact_store.path, uri)

    if manual_save and artifact_store.exists(uri):
        # This check is only necessary for manual saves as we already check
        # it when creating the directory for step output artifacts
        other_artifacts = client.list_artifact_versions(uri=uri, size=1)
        if other_artifacts and (other_artifact := other_artifacts[0]):
            raise RuntimeError(
                f"Cannot save new artifact {name} version to URI "
                f"{uri} because the URI is already used by artifact "
                f"{other_artifact.name} (version {other_artifact.version})."
            )
    artifact_store.makedirs(uri)

    # Find and initialize the right materializer class
    if isinstance(materializer, type):
        materializer_class = materializer
    elif materializer:
        materializer_class = source_utils.load_and_validate_class(
            materializer, expected_class=BaseMaterializer
        )
    else:
        materializer_class = materializer_registry[type(data)]
    materializer_object = materializer_class(uri)

    # Force URIs to have forward slashes
    materializer_object.uri = materializer_object.uri.replace("\\", "/")

    # Save the artifact to the artifact store
    data_type = type(data)
    materializer_object.validate_type_compatibility(data_type)
    materializer_object.save(data)

    # Save visualizations of the artifact
    visualizations: List[ArtifactVisualizationRequest] = []
    if include_visualizations:
        try:
            vis_data = materializer_object.save_visualizations(data)
            for vis_uri, vis_type in vis_data.items():
                vis_model = ArtifactVisualizationRequest(
                    type=vis_type,
                    uri=vis_uri,
                )
                visualizations.append(vis_model)
        except Exception as e:
            logger.warning(
                f"Failed to save visualization for output artifact '{name}': "
                f"{e}"
            )

    # Save metadata of the artifact
    artifact_metadata: Dict[str, "MetadataType"] = {}
    if extract_metadata:
        try:
            artifact_metadata = materializer_object.extract_full_metadata(data)
            artifact_metadata.update(user_metadata or {})
        except Exception as e:
            logger.warning(
                f"Failed to extract metadata for output artifact '{name}': {e}"
            )

    # Create the artifact version
    def _create_version() -> Optional[ArtifactVersionResponse]:
        artifact_version = ArtifactVersionRequest(
            artifact_id=artifact.id,
            version=version,
            tags=tags,
            type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
            uri=materializer_object.uri,
            materializer=source_utils.resolve(materializer_object.__class__),
            data_type=source_utils.resolve(data_type),
            user=Client().active_user.id,
            workspace=Client().active_workspace.id,
            artifact_store_id=artifact_store.id,
            visualizations=visualizations,
            has_custom_name=has_custom_name,
        )
        try:
            return client.zen_store.create_artifact_version(
                artifact_version=artifact_version
            )
        except EntityExistsError:
            return None

    response = None
    if not version:
        retries_made = 0
        for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
            # Get new artifact version
            version = _get_new_artifact_version(name)
            if response := _create_version():
                break
            # smoothed exponential back-off, it will go as 0.2, 0.3,
            # 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
            sleep = 0.2 * 1.5**i
            logger.debug(
                f"Failed to create artifact version `{version}` for "
                f"artifact `{name}`. Retrying in {sleep}..."
            )
            time.sleep(sleep)
            retries_made += 1
        if not response:
            raise EntityExistsError(
                f"Failed to create new artifact version for artifact "
                f"`{name}`. Retried {retries_made} times. "
                "This could be driven by exceptionally high concurrency of "
                "pipeline runs. Please, reach out to us on ZenML Slack for support."
            )
    else:
        response = _create_version()
        if not response:
            raise EntityExistsError(
                f"Failed to create artifact version `{version}` for artifact "
                f"`{name}`. Given version already exists."
            )

    if artifact_metadata:
        client.create_run_metadata(
            metadata=artifact_metadata,
            resource_id=response.id,
            resource_type=MetadataResourceTypes.ARTIFACT_VERSION,
        )

    if manual_save:
        try:
            error_message = "step run"
            step_context = get_step_context()
            step_run = step_context.step_run
            client.zen_store.update_run_step(
                step_run_id=step_run.id,
                step_run_update=StepRunUpdate(
                    saved_artifact_versions={name: response.id}
                ),
            )
            error_message = "model"
            model = step_context.model
            if model:
                from zenml.model.utils import link_artifact_to_model

                link_artifact_to_model(
                    artifact_version_id=response.id,
                    model=model,
                    is_model_artifact=is_model_artifact,
                    is_deployment_artifact=is_deployment_artifact,
                )
        except (RuntimeError, StepContextError):
            logger.debug(f"Unable to link saved artifact to {error_message}.")

    return response
```
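When no `version` is given, `save_artifact` repeatedly asks for the next version number and tries to create it. After the `i`-th failed attempt it sleeps `0.2 * 1.5**i` seconds, because a concurrent run may claim the same number first. The loop can be isolated as below. This is a sketch, not ZenML code: `create_with_retries` and `backoff_delays` are hypothetical helpers. The two callables stand in for `_get_new_artifact_version` and `_create_version`, and `RuntimeError` stands in for `EntityExistsError`. `sleep` is injectable so the loop can be exercised without waiting.

```python
import time
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


def backoff_delays(
    max_retries: int, base: float = 0.2, factor: float = 1.5
) -> List[float]:
    """Smoothed exponential back-off: 0.2, 0.3, 0.45, 0.675, ..."""
    return [base * factor**i for i in range(max_retries)]


def create_with_retries(
    next_version: Callable[[], int],
    try_create: Callable[[int], Optional[T]],
    max_retries: int,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Create a new version, retrying with back-off when the number is taken.

    try_create returns None when the version already exists, mirroring how
    _create_version swallows EntityExistsError.
    """
    for delay in backoff_delays(max_retries):
        # Re-read the next free version on every attempt.
        version = next_version()
        result = try_create(version)
        if result is not None:
            return result
        sleep(delay)
    raise RuntimeError(
        f"Failed to create a new version after {max_retries} attempts."
    )
```

Re-reading the version inside the loop is what makes the retry useful. Retrying the same number would fail every time once another run has taken it.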
save_model_metadata(model_artifact)
Save a ZenML model artifact's metadata to a YAML file.
This function extracts information from a ZenML model artifact, such as the model type and materializer, and saves it. This information is the key to loading the model into memory in the inference environment.
It saves two keys: `datatype`, the model type (the path to the model class), and `materializer`, the path to the materializer class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_artifact | ArtifactVersionResponse | The artifact to extract the metadata from. | required |
Returns:
Type | Description |
---|---|
str | The path to the temporary file where the model metadata is saved. |
Source code in zenml/artifacts/utils.py
```python
def save_model_metadata(model_artifact: "ArtifactVersionResponse") -> str:
    """Save a zenml model artifact metadata to a YAML file.

    This function is used to extract and save information from a zenml model
    artifact such as the model type and materializer. The extracted information
    will be the key to loading the model into memory in the inference
    environment.

    datatype: the model type. This is the path to the model class.
    materializer: The path to the materializer class.

    Args:
        model_artifact: the artifact to extract the metadata from.

    Returns:
        The path to the temporary file where the model metadata is saved
    """
    metadata = dict()
    metadata["datatype"] = model_artifact.data_type
    metadata["materializer"] = model_artifact.materializer

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".yaml", delete=False
    ) as f:
        write_yaml(f.name, metadata)
    return f.name
```
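`save_model_metadata` and `load_model_from_metadata` form a round trip through a small YAML file with two keys. The sketch below reproduces that round trip with the standard library only. It assumes both values are plain dotted class paths without special characters, so simple `key: value` lines are valid YAML. `save_metadata_file` and `load_metadata_file` are hypothetical helpers; ZenML itself uses its `write_yaml` and `read_yaml` utilities instead.

```python
import tempfile
from pathlib import Path
from typing import Dict


def save_metadata_file(data_type: str, materializer: str) -> str:
    """Write the two keys to a temporary .yaml file and return its path.

    delete=False keeps the file after the handle closes, as in
    save_model_metadata, so the caller must remove it when done.
    """
    metadata = {"datatype": data_type, "materializer": materializer}
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".yaml", delete=False
    ) as f:
        for key, value in metadata.items():
            # Assumes values need no YAML quoting (plain dotted paths).
            f.write(f"{key}: {value}\n")
    return f.name


def load_metadata_file(path: str) -> Dict[str, str]:
    """Parse the simple 'key: value' lines written by save_metadata_file."""
    metadata: Dict[str, str] = {}
    for line in Path(path).read_text().splitlines():
        if line.strip():
            key, _, value = line.partition(": ")
            metadata[key] = value
    return metadata
```

Because the file outlives the function call, whoever uploads it alongside the model is also responsible for deleting the local temporary copy.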