Model
zenml.model
special
Concepts related to the Model Control Plane feature.
model_version
User-facing ModelVersion interface to pass into a pipeline or step.
ModelVersion (BaseModel)
pydantic-model
ModelVersion class to pass into pipeline or step to set it into a model context.
Attributes: name: The name of the model. license: The license under which the model is created. description: The description of the model. audience: The target audience of the model. use_cases: The use cases of the model. limitations: The known limitations of the model. trade_offs: The tradeoffs of the model. ethics: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number or stage; it is optional and points the model context to a specific version/stage. If skipped, a new model version will be created. save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if available in the active stack.
Source code in zenml/model/model_version.py
class ModelVersion(BaseModel):
"""ModelVersion class to pass into pipeline or step to set it into a model context.

Attributes:
name: The name of the model.
license: The license under which the model is created.
description: The description of the model.
audience: The target audience of the model.
use_cases: The use cases of the model.
limitations: The known limitations of the model.
trade_offs: The tradeoffs of the model.
ethics: The ethical implications of the model.
tags: Tags associated with the model.
version: The model version name, number or stage is optional and points model context
to a specific version/stage. If skipped new model version will be created.
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
if available in active stack.
suppress_class_validation_warnings: When True, the root validator skips its
informational log messages. The root validator always sets it to True
before returning the values.
was_created_in_this_run: Set to True by `_get_or_create_model_version`
when it creates a new model version.
"""
name: str
license: Optional[str] = None
description: Optional[str] = None
audience: Optional[str] = None
use_cases: Optional[str] = None
limitations: Optional[str] = None
trade_offs: Optional[str] = None
ethics: Optional[str] = None
tags: Optional[List[str]] = None
version: Optional[Union[ModelStages, int, str]] = None
save_models_to_registry: bool = True
suppress_class_validation_warnings: bool = False
was_created_in_this_run: bool = False
# Private caches; they start as None and are filled in by
# `_get_or_create_model`, `_get_model_version` and
# `_get_or_create_model_version`.
_model_id: UUID = PrivateAttr(None)
_id: UUID = PrivateAttr(None)
_number: int = PrivateAttr(None)
#########################
# Public methods #
#########################
@property
def id(self) -> Optional[UUID]:
    """Get version id from the Model Control Plane.

    The id is resolved on first access and read from the private cache
    afterwards.

    Returns:
        ID of the model version or None, if model version
        doesn't exist and can only be read given current
        config (you used stage name or number as
        a version name).
    """
    if self._id is None:
        try:
            self._get_or_create_model_version()
        except RuntimeError:
            # Best effort: the failure is logged and the (still unset)
            # cached value, i.e. None, is returned to the caller.
            logger.info(
                f"Model version `{self.version}` doesn't exist "
                "and cannot be fetched from the Model Control Plane."
            )
    return self._id
@property
def model_id(self) -> UUID:
    """Get model id from the Model Control Plane.

    Returns:
        The UUID of the model containing this model version.
    """
    if self._model_id is not None:
        return self._model_id
    # Not cached yet: `_get_or_create_model` stores the id on `self`.
    self._get_or_create_model()
    return self._model_id
@property
def number(self) -> Optional[int]:
    """Get version number from the Model Control Plane.

    The number is resolved on first access and read from the private
    cache afterwards.

    Returns:
        Number of the model version or None, if model version
        doesn't exist and can only be read given current
        config (you used stage name or number as
        a version name).
    """
    if self._number is None:
        try:
            self._get_or_create_model_version()
        except RuntimeError:
            # Best effort: the failure is logged and the (still unset)
            # cached value, i.e. None, is returned to the caller.
            logger.info(
                f"Model version `{self.version}` doesn't exist "
                "and cannot be fetched from the Model Control Plane."
            )
    return self._number
@property
def stage(self) -> Optional[ModelStages]:
    """Get version stage from the Model Control Plane.

    Returns:
        Stage of the model version or None, if model version
        doesn't exist and can only be read given current
        config (you used stage name or number as
        a version name).
    """
    try:
        raw_stage = self._get_or_create_model_version().stage
        return ModelStages(raw_stage) if raw_stage else None
    except RuntimeError:
        logger.info(
            f"Model version `{self.version}` doesn't exist "
            "and cannot be fetched from the Model Control Plane."
        )
        return None
def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
    """Load artifact from the Model Control Plane.

    Args:
        name: Name of the artifact to load.
        version: Version of the artifact to load.

    Returns:
        The loaded artifact.

    Raises:
        ValueError: if the model version is not linked to any artifact with
            the given name and version.
    """
    from zenml.artifacts.utils import load_artifact
    from zenml.models import ArtifactVersionResponse

    linked = self.get_artifact(name=name, version=version)
    # Only a concrete artifact version response can be materialized.
    if isinstance(linked, ArtifactVersionResponse):
        return load_artifact(linked.id, str(linked.version))
    raise ValueError(
        f"Version {self.version} of model {self.name} does not have "
        f"an artifact with name {name} and version {version}."
    )
def _try_get_as_external_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ExternalArtifact"]:
    """Build an `ExternalArtifact` for this model version if a pipeline context exists.

    Args:
        name: The name of the artifact.
        version: The version of the artifact.

    Returns:
        An `ExternalArtifact` referencing this model version when
        `get_pipeline_context()` succeeds, None when it raises
        `RuntimeError`.
    """
    from zenml import ExternalArtifact, get_pipeline_context

    try:
        get_pipeline_context()
    except RuntimeError:
        return None
    return ExternalArtifact(name=name, version=version, model_version=self)
def get_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the artifact linked to this model version.

    Args:
        name: The name of the artifact to retrieve.
        version: The version of the artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    model_version = self._get_or_create_model_version()
    return model_version.get_artifact(name=name, version=version)
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the model artifact linked to this model version.

    Args:
        name: The name of the model artifact to retrieve.
        version: The version of the model artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the model artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    model_version = self._get_or_create_model_version()
    return model_version.get_model_artifact(name=name, version=version)
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the data artifact linked to this model version.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the data artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    model_version = self._get_or_create_model_version()
    return model_version.get_data_artifact(name=name, version=version)
def get_deployment_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the deployment artifact linked to this model version.

    Args:
        name: The name of the deployment artifact to retrieve.
        version: The version of the deployment artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the deployment artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    model_version = self._get_or_create_model_version()
    return model_version.get_deployment_artifact(name=name, version=version)
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Get pipeline run linked to this version.

    Args:
        name: The name of the pipeline run to retrieve.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    model_version = self._get_or_create_model_version()
    return model_version.get_pipeline_run(name=name)
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> None:
    """Sets this Model Version to a desired stage.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in target stage or raise.
    """
    model_version = self._get_or_create_model_version()
    model_version.set_stage(stage=stage, force=force)
def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for current model version.

    Args:
        metadata: The metadata to log.
    """
    from zenml.client import Client

    model_version = self._get_or_create_model_version()
    client = Client()
    client.create_run_metadata(
        metadata=metadata,
        resource_id=model_version.id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )
@property
def metadata(self) -> Dict[str, "MetadataType"]:
    """Get model version metadata.

    Returns:
        The model version metadata.

    Raises:
        RuntimeError: If the model version metadata cannot be fetched.
    """
    model_version = self._get_or_create_model_version(hydrate=True)
    if model_version.run_metadata is None:
        raise RuntimeError(
            "Failed to fetch metadata of this model version."
        )
    # Flatten each metadata entry to its plain value, keyed by name.
    flattened = {}
    for key, entry in model_version.run_metadata.items():
        flattened[key] = entry.value
    return flattened
#########################
# Internal methods #
#########################
class Config:
    """Pydantic configuration of the enclosing model."""

    # Affects how values for Union-typed fields are matched; see the
    # pydantic model-config documentation for the exact rules.
    smart_union = True
def __eq__(self, other: object) -> bool:
    """Check two ModelVersions for equality.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, ModelVersion):
        return NotImplemented
    if self.name != other.name:
        return False
    # Names match from here on, so only the version needs comparing.
    if self.version == other.version:
        return True
    # Differing version specifiers may still point at the same model
    # version, so compare the ids of the resolved versions.
    own_id = self._get_or_create_model_version().id
    other_id = other._get_or_create_model_version().id
    return own_id == other_id
@root_validator(pre=True)
def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Validate all in one.

    Logs how a stage-like or numeric `version` will be interpreted, unless
    the messages are suppressed, and then marks the values as suppressed.

    Args:
        values: Dict of values.

    Returns:
        Dict of validated values.
    """
    silenced = values.get("suppress_class_validation_warnings", False)
    version = values.get("version", None)

    matches_stage = version in [stage.value for stage in ModelStages]
    if matches_stage and not silenced:
        logger.info(
            f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage."
        )

    looks_numeric = str(version).isnumeric()
    if looks_numeric and not silenced:
        logger.info(
            f"`version` `{version}` is numeric and will be fetched using version number."
        )

    values["suppress_class_validation_warnings"] = True
    return values
def _validate_config_in_runtime(self) -> None:
    """Validate that config doesn't conflict with runtime environment.

    The check consists of resolving (or creating) the model version; the
    response itself is discarded.
    """
    _ = self._get_or_create_model_version()
def _get_or_create_model(self) -> "ModelResponse":
    """This method should get or create a model from Model Control Plane.

    New model is created implicitly, if missing, otherwise fetched. The id
    of the resulting model is cached on this object.

    Returns:
        The model based on configuration.
    """
    from zenml.client import Client
    from zenml.models import ModelRequest

    zenml_client = Client()
    try:
        model = zenml_client.zen_store.get_model(
            model_name_or_id=self.name
        )
    except KeyError:
        model_request = ModelRequest(
            name=self.name,
            license=self.license,
            description=self.description,
            audience=self.audience,
            use_cases=self.use_cases,
            limitations=self.limitations,
            trade_offs=self.trade_offs,
            ethics=self.ethics,
            tags=self.tags,
            user=zenml_client.active_user.id,
            workspace=zenml_client.active_workspace.id,
        )
        model_request = ModelRequest.parse_obj(model_request)
        try:
            model = zenml_client.zen_store.create_model(
                model=model_request
            )
            logger.info(f"New model `{self.name}` was created implicitly.")
        except EntityExistsError:
            # Backup logic: the model was created by someone else between
            # the get and create calls, so fetch that one. Any other
            # creation error propagates unchanged instead of being masked
            # by a follow-up lookup.
            model = zenml_client.zen_store.get_model(
                model_name_or_id=self.name
            )

    self._model_id = model.id
    return model
def _get_model_version(self) -> "ModelVersionResponse":
    """This method gets a model version from Model Control Plane.

    The id of the fetched version is cached if no id was cached before.

    Returns:
        The model version based on configuration.
    """
    from zenml.client import Client

    fetched = Client().get_model_version(
        model_name_or_id=self.name,
        model_version_name_or_number_or_id=self.version,
    )
    if not self._id:
        self._id = fetched.id
    return fetched
def _get_or_create_model_version(
self, hydrate: bool = False
) -> "ModelVersionResponse":
"""This method should get or create a model and a model version from Model Control Plane.

A new model is created implicitly if missing, otherwise existing model is fetched. Model
name is controlled by the `name` parameter.

Model Version returned by this method is resolved based on model version:
- If `version` is None, a new model version is created, if not created by other steps in same run.
- If `version` is not None a model version will be fetched based on the version:
- If `version` is set to an integer or digit string, the model version with the matching number will be fetched.
- If `version` is set to a string, the model version with the matching version will be fetched.
- If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

On success the id, model id and number of the resolved version are cached
on this object; on creation `version` and `was_created_in_this_run` are
updated as well.

Args:
hydrate: Whether to return a hydrated version of the model version.

Returns:
The model version based on configuration.

Raises:
RuntimeError: if the model version needs to be created, but provided name is reserved
"""
# NOTE(review): `hydrate` is accepted but never referenced in this body;
# confirm whether it should be forwarded to the lookup.
from zenml.client import Client
from zenml.models import ModelVersionRequest
model = self._get_or_create_model()
zenml_client = Client()
# The creation request is assembled up front, but it is only sent from
# the `except KeyError` branch further down.
model_version_request = ModelVersionRequest(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=self.version,
description=self.description,
model=model.id,
tags=self.tags,
)
mv_request = ModelVersionRequest.parse_obj(model_version_request)
try:
if not self.version:
try:
from zenml import get_step_context
context = get_step_context()
except RuntimeError:
# Not inside a step: nothing to reuse, fall through to creation.
pass
else:
# if inside a step context we loop over all
# model version configuration to find, if the
# model version for current model was already
# created in the current run, not to create
# new model versions
pipeline_mv = context.pipeline_run.config.model_version
if (
pipeline_mv
and pipeline_mv.was_created_in_this_run
and pipeline_mv.name == self.name
and pipeline_mv.version is not None
):
self.version = pipeline_mv.version
else:
# The pipeline-level config did not match; take the first
# step-level config that does.
for step in context.pipeline_run.steps.values():
step_mv = step.config.model_version
if (
step_mv
and step_mv.was_created_in_this_run
and step_mv.name == self.name
and step_mv.version is not None
):
self.version = step_mv.version
break
if self.version:
model_version = self._get_model_version()
else:
# No version to look up: jump to the creation branch below.
raise KeyError
except KeyError:
# Reached when the lookup raised KeyError or no version was set.
# Stage names and numeric names are rejected before creating.
if (
self.version
and str(self.version).lower() in ModelStages.values()
):
raise RuntimeError(
f"Cannot create a model version named {str(self.version)} as "
"it matches one of the possible model version stages. If you "
"are aiming to fetch model version by stage, check if the "
"model version in given stage exists. It might be missing, if "
"the pipeline promoting model version to this stage failed,"
" as an example. You can explore model versions using "
f"`zenml model version list {self.name}` CLI command."
)
if str(self.version).isnumeric():
raise RuntimeError(
f"Cannot create a model version named {str(self.version)} as "
"numeric model version names are reserved. If you "
"are aiming to fetch model version by number, check if the "
"model version with given number exists. It might be missing, if "
"the pipeline creating model version failed,"
" as an example. You can explore model versions using "
f"`zenml model version list {self.name}` CLI command."
)
model_version = zenml_client.zen_store.create_model_version(
model_version=mv_request
)
# Adopt the name of the created version as this object's version.
self.version = model_version.name
self.was_created_in_this_run = True
logger.info(f"New model version `{self.version}` was created.")
# Cache the identifiers read by the `id`, `model_id` and `number` properties.
self._id = model_version.id
self._model_id = model_version.model.id
self._number = model_version.number
return model_version
def _merge(self, model_version: "ModelVersion") -> None:
    """Fill gaps in this object's descriptive fields from another model version.

    Every descriptive field that is falsy on this object is replaced by the
    value from `model_version`. Tags are combined: duplicates are dropped
    and the first-seen order is kept (own tags first), so the result is
    deterministic.

    Args:
        model_version: The model version to take missing values and
            additional tags from.
    """
    self.license = self.license or model_version.license
    self.description = self.description or model_version.description
    self.audience = self.audience or model_version.audience
    self.use_cases = self.use_cases or model_version.use_cases
    self.limitations = self.limitations or model_version.limitations
    self.trade_offs = self.trade_offs or model_version.trade_offs
    self.ethics = self.ethics or model_version.ethics
    if model_version.tags is not None:
        # dict.fromkeys de-duplicates while preserving insertion order,
        # unlike a set union whose iteration order is arbitrary.
        self.tags = list(
            dict.fromkeys([*(self.tags or []), *model_version.tags])
        )
def __hash__(self) -> int:
    """Get hash of the `ModelVersion`.

    The hash is derived from the name and the version only.

    Returns:
        Hash function results
    """
    key_parts = (self.name, self.version)
    return hash("::".join(map(str, key_parts)))
id: UUID
property
readonly
Get version id from the Model Control Plane.
Returns:
Type | Description |
---|---|
UUID |
ID of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
metadata: Dict[str, MetadataType]
property
readonly
Get model version metadata.
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The model version metadata. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the model version metadata cannot be fetched. |
model_id: UUID
property
readonly
Get model id from the Model Control Plane.
Returns:
Type | Description |
---|---|
UUID |
The UUID of the model containing this model version. |
number: int
property
readonly
Get version number from the Model Control Plane.
Returns:
Type | Description |
---|---|
int |
Number of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
stage: Optional[zenml.enums.ModelStages]
property
readonly
Get version stage from the Model Control Plane.
Returns:
Type | Description |
---|---|
Optional[zenml.enums.ModelStages] |
Stage of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
Config
Config class.
Source code in zenml/model/model_version.py
class Config:
    """Pydantic configuration of the enclosing model."""

    # Affects how values for Union-typed fields are matched; see the
    # pydantic model-config documentation for the exact rules.
    smart_union = True
__eq__(self, other)
special
Check two ModelVersions for equality.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other |
object |
object to compare with |
required |
Returns:
Type | Description |
---|---|
bool |
True, if equal, False otherwise. |
Source code in zenml/model/model_version.py
def __eq__(self, other: object) -> bool:
    """Check two ModelVersions for equality.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, ModelVersion):
        return NotImplemented
    if self.name != other.name:
        return False
    # Names match from here on, so only the version needs comparing.
    if self.version == other.version:
        return True
    # Differing version specifiers may still point at the same model
    # version, so compare the ids of the resolved versions.
    own_id = self._get_or_create_model_version().id
    other_id = other._get_or_create_model_version().id
    return own_id == other_id
__hash__(self)
special
Get hash of the `ModelVersion`.
Returns:
Type | Description |
---|---|
int |
Hash function results |
Source code in zenml/model/model_version.py
def __hash__(self) -> int:
    """Get hash of the `ModelVersion`.

    The hash is derived from the name and the version only.

    Returns:
        Hash function results
    """
    key_parts = (self.name, self.version)
    return hash("::".join(map(str, key_parts)))
get_artifact(self, name, version=None)
Get the artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the artifact to retrieve. |
required |
version |
Optional[str] |
The version of the artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[Union[ArtifactVersionResponse, ExternalArtifact]] |
Inside a pipeline context: an ExternalArtifact object acting as a lazy loader. Outside of a pipeline context: the specific version of the artifact, or None. |
Source code in zenml/model/model_version.py
def get_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the artifact linked to this model version.

    Args:
        name: The name of the artifact to retrieve.
        version: The version of the artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    model_version = self._get_or_create_model_version()
    return model_version.get_artifact(name=name, version=version)
get_data_artifact(self, name, version=None)
Get the data artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the data artifact to retrieve. |
required |
version |
Optional[str] |
The version of the data artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[Union[ArtifactVersionResponse, ExternalArtifact]] |
Inside a pipeline context: an ExternalArtifact object acting as a lazy loader. Outside of a pipeline context: the specific version of the data artifact, or None. |
Source code in zenml/model/model_version.py
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the data artifact linked to this model version.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the data artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    model_version = self._get_or_create_model_version()
    return model_version.get_data_artifact(name=name, version=version)
get_deployment_artifact(self, name, version=None)
Get the deployment artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the deployment artifact to retrieve. |
required |
version |
Optional[str] |
The version of the deployment artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[Union[ArtifactVersionResponse, ExternalArtifact]] |
Inside a pipeline context: an ExternalArtifact object acting as a lazy loader. Outside of a pipeline context: the specific version of the deployment artifact, or None. |
Source code in zenml/model/model_version.py
def get_deployment_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the deployment artifact linked to this model version.

    Args:
        name: The name of the deployment artifact to retrieve.
        version: The version of the deployment artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the deployment artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    model_version = self._get_or_create_model_version()
    return model_version.get_deployment_artifact(name=name, version=version)
get_model_artifact(self, name, version=None)
Get the model artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the model artifact to retrieve. |
required |
version |
Optional[str] |
The version of the model artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[Union[ArtifactVersionResponse, ExternalArtifact]] |
Inside a pipeline context: an ExternalArtifact object acting as a lazy loader. Outside of a pipeline context: the specific version of the model artifact, or None. |
Source code in zenml/model/model_version.py
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the model artifact linked to this model version.

    Args:
        name: The name of the model artifact to retrieve.
        version: The version of the model artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the model artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    model_version = self._get_or_create_model_version()
    return model_version.get_model_artifact(name=name, version=version)
get_pipeline_run(self, name)
Get pipeline run linked to this version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the pipeline run to retrieve. |
required |
Returns:
Type | Description |
---|---|
PipelineRunResponse |
PipelineRun as PipelineRunResponse |
Source code in zenml/model/model_version.py
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Get pipeline run linked to this version.

    Args:
        name: The name of the pipeline run to retrieve.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    model_version = self._get_or_create_model_version()
    return model_version.get_pipeline_run(name=name)
load_artifact(self, name, version=None)
Load artifact from the Model Control Plane.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the artifact to load. |
required |
version |
Optional[str] |
Version of the artifact to load. |
None |
Returns:
Type | Description |
---|---|
Any |
The loaded artifact. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the model version is not linked to any artifact with the given name and version. |
Source code in zenml/model/model_version.py
def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
    """Load artifact from the Model Control Plane.

    Args:
        name: Name of the artifact to load.
        version: Version of the artifact to load.

    Returns:
        The loaded artifact.

    Raises:
        ValueError: if the model version is not linked to any artifact with
            the given name and version.
    """
    from zenml.artifacts.utils import load_artifact
    from zenml.models import ArtifactVersionResponse

    linked = self.get_artifact(name=name, version=version)
    # Only a concrete artifact version response can be materialized.
    if isinstance(linked, ArtifactVersionResponse):
        return load_artifact(linked.id, str(linked.version))
    raise ValueError(
        f"Version {self.version} of model {self.name} does not have "
        f"an artifact with name {name} and version {version}."
    )
log_metadata(self, metadata)
Log model version metadata.
This function can be used to log metadata for current model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to log. |
required |
Source code in zenml/model/model_version.py
def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for current model version.

    Args:
        metadata: The metadata to log.
    """
    from zenml.client import Client

    model_version = self._get_or_create_model_version()
    client = Client()
    client.create_run_metadata(
        metadata=metadata,
        resource_id=model_version.id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )
set_stage(self, stage, force=False)
Sets this Model Version to a desired stage.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage |
Union[str, zenml.enums.ModelStages] |
the target stage for model version. |
required |
force |
bool |
whether to force archiving of current model version in target stage or raise. |
False |
Source code in zenml/model/model_version.py
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> None:
    """Sets this Model Version to a desired stage.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in target stage or raise.
    """
    model_version = self._get_or_create_model_version()
    model_version.set_stage(stage=stage, force=force)
utils
Utility functions for linking step outputs to model versions.
link_artifact_config_to_model_version(artifact_config, artifact_version_id, model_version=None)
Link an artifact config to its model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_config |
ArtifactConfig |
The artifact config to link. |
required |
artifact_version_id |
UUID |
The ID of the artifact to link. |
required |
model_version |
Optional[ModelVersion] |
The model version from the step or pipeline context. |
None |
Source code in zenml/model/utils.py
def link_artifact_config_to_model_version(
    artifact_config: ArtifactConfig,
    artifact_version_id: UUID,
    model_version: Optional["ModelVersion"] = None,
) -> None:
    """Link an artifact config to its model version.

    Nothing is linked when neither the artifact config nor the caller
    supplies a model version.

    Args:
        artifact_config: The artifact config to link.
        artifact_version_id: The ID of the artifact to link.
        model_version: The model version from the step or pipeline context.
    """
    client = Client()

    # If the artifact config specifies a model itself then always use that
    if artifact_config.model_name is not None:
        from zenml.model.model_version import ModelVersion

        model_version = ModelVersion(
            name=artifact_config.model_name,
            version=artifact_config.model_version,
        )

    if not model_version:
        return

    # Reuse the response of the get-or-create call instead of fetching the
    # same model version a second time.
    model_version_response = model_version._get_or_create_model_version()
    request = ModelVersionArtifactRequest(
        user=client.active_user.id,
        workspace=client.active_workspace.id,
        artifact_version=artifact_version_id,
        model=model_version_response.model.id,
        model_version=model_version_response.id,
        is_model_artifact=artifact_config.is_model_artifact,
        is_deployment_artifact=artifact_config.is_deployment_artifact,
    )
    client.zen_store.create_model_version_artifact_link(request)
link_step_artifacts_to_model(artifact_version_ids)
Links the output artifacts of a step to the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version_ids |
Dict[str, uuid.UUID] |
The IDs of the published output artifacts. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If called outside of a step. |
Source code in zenml/model/utils.py
def link_step_artifacts_to_model(
    artifact_version_ids: Dict[str, UUID],
) -> None:
    """Links the output artifacts of a step to the model.

    Outputs without an artifact config are linked implicitly when the step
    has a model version in its context.

    Args:
        artifact_version_ids: The IDs of the published output artifacts.

    Raises:
        RuntimeError: If called outside of a step.
    """
    try:
        step_context = get_step_context()
    except StepContextError:
        raise RuntimeError(
            "`link_step_artifacts_to_model` can only be called from within a "
            "step."
        )

    try:
        model_version = step_context.model_version
    except StepContextError:
        model_version = None
        logger.debug("No model context found, unable to auto-link artifacts.")

    for artifact_name in artifact_version_ids:
        output = step_context._get_output(artifact_name)
        config = output.artifact_config

        # Implicit linking
        if config is None and model_version is not None:
            config = ArtifactConfig(name=artifact_name)
            logger.info(
                f"Implicitly linking artifact `{artifact_name}` to model "
                f"`{model_version.name}` version `{model_version.version}`."
            )

        if not config:
            continue
        link_artifact_config_to_model_version(
            artifact_config=config,
            artifact_version_id=artifact_version_ids[artifact_name],
            model_version=model_version,
        )
log_model_version_metadata(metadata, model_name=None, model_version=None)
Log model version metadata.
This function can be used to log metadata for existing model versions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to log. |
required |
model_name |
Optional[str] |
The name of the model to log metadata for. Can
be omitted when being called inside a step with configured
`model_version` in decorator. |
None |
model_version |
Union[zenml.enums.ModelStages, int, str] |
The version of the model to log metadata for. Can
be omitted when being called inside a step with configured
`model_version` in decorator. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If no model name/version is provided and the function is not
called inside a step with configured `model_version` in decorator. |
Source code in zenml/model/utils.py
def log_model_version_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.
    An explicitly passed model name and version take precedence; the model
    version of the current step is used only when they are not both given.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.

    Raises:
        ValueError: If no model name/version is provided and the function is not
            called inside a step with configured `model_version` in decorator.
    """
    mv = None
    if model_name and model_version:
        # An explicit target must not be overridden by the step context.
        from zenml import ModelVersion

        mv = ModelVersion(name=model_name, version=model_version)
    else:
        try:
            mv = get_step_context().model_version
        except RuntimeError:
            # No usable step context; handled by the check below.
            mv = None

    if mv is None:
        raise ValueError(
            "Model name and version must be provided unless the function is "
            "called inside a step with configured `model_version` in decorator."
        )
    mv.log_metadata(metadata)