Model
zenml.model
special
Concepts related to the Model Control Plane feature.
lazy_load
Model Version Data Lazy Loader definition.
ModelVersionDataLazyLoader (BaseModel)
Model Version Data Lazy Loader helper class.
It helps internal code fetch the proper artifact, model version metadata, or artifact metadata from the model version at step runtime.
Source code in zenml/model/lazy_load.py
class ModelVersionDataLazyLoader(BaseModel):
    """Model Version Data Lazy Loader helper class.

    It helps internal code fetch the proper artifact, model version metadata
    or artifact metadata from the model version at step runtime.
    """

    model_name: str
    model_version: Optional[str] = None
    artifact_name: Optional[str] = None
    artifact_version: Optional[str] = None
    metadata_name: Optional[str] = None

    # TODO: In Pydantic v2, the `model_` is a protected namespaces for all
    # fields defined under base models. If not handled, this raises a warning.
    # It is possible to suppress this warning message with the following
    # configuration, however the ultimate solution is to rename these fields.
    # Even though they do not cause any problems right now, if we are not
    # careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        When no `model_version` is given and a pipeline context is available
        (`get_pipeline_context()` does not raise `RuntimeError`), the context
        must carry a model whose name equals `model_name`.

        Args:
            data: Dict of values.

        Returns:
            Dict of validated values.

        Raises:
            ValueError: If `model_version` is not set while a pipeline context
                is available and that context has no model, or its model name
                differs from `model_name`.
        """
        if data.get("model_version", None) is None:
            try:
                context = get_pipeline_context()
            except RuntimeError:
                # No active pipeline context: nothing to cross-check.
                context = None
            # `data.get` instead of `data[...]`: a missing `model_name` should
            # surface as a validation error, not as a raw KeyError escaping
            # the validator.
            if context is not None and (
                not context.model
                or context.model.name != data.get("model_name")
            ):
                raise ValueError(
                    "`version` must be set if you use the `Model` class "
                    "directly in the pipeline body, otherwise, you can use "
                    "`get_pipeline_context().model` to lazy load the current "
                    "Model Version from the pipeline context."
                )
        # NOTE(review): this model declares no
        # `suppress_class_validation_warnings` field; with pydantic's default
        # `extra` handling the key is ignored - confirm it is still needed.
        data["suppress_class_validation_warnings"] = True
        return data

    def _get_model_response(
        self, pipeline_run: "PipelineRunResponse"
    ) -> "ModelVersionResponse":
        """Resolve the model version this loader points at.

        Args:
            pipeline_run: The pipeline run whose model version is used when
                `model_version` is not set.

        Returns:
            The resolved model version.

        Raises:
            RuntimeError: If `model_version` is unset and the run has no model
                version or one belonging to a different model, or if the
                requested model version cannot be found.
        """
        # if the version/number is None -> return the model in context
        if self.model_version is None:
            if mv := pipeline_run.model_version:
                if mv.model.name != self.model_name:
                    raise RuntimeError(
                        "Lazy loading of the model failed, since given name "
                        f"`{self.model_name}` does not match the model name "
                        f"in the pipeline context: `{mv.model.name}`."
                    )
                return mv
            else:
                raise RuntimeError(
                    "Lazy loading of the model failed, since the model version "
                    "is not set in the pipeline context."
                )
        # else return the model version by version
        else:
            from zenml.client import Client

            try:
                return Client().get_model_version(
                    model_name_or_id=self.model_name,
                    model_version_name_or_number_or_id=self.model_version,
                )
            except KeyError as e:
                raise RuntimeError(
                    "Lazy loading of the model version failed: "
                    f"no model `{self.model_name}` with version "
                    f"`{self.model_version}` could be found."
                ) from e
model
User-facing model interface to pass into a pipeline or step.
Model (BaseModel)
Model class to pass into pipeline or step to set it into a model context.
name: The name of the model. license: The license under which the model is created. description: The description of the model. audience: The target audience of the model. use_cases: The use cases of the model. limitations: The known limitations of the model. trade_offs: The tradeoffs of the model. ethics: The ethical implications of the model. tags: Tags associated with the model. version: The version name, version number or stage (optional); it points the model context to a specific version/stage. If skipped, a new version will be created. save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if available in the active stack.
Source code in zenml/model/model.py
class Model(BaseModel):
    """Model class to pass into pipeline or step to set it into a model context.

    Attributes:
        name: The name of the model.
        license: The license under which the model is created.
        description: The description of the model.
        audience: The target audience of the model.
        use_cases: The use cases of the model.
        limitations: The known limitations of the model.
        trade_offs: The tradeoffs of the model.
        ethics: The ethical implications of the model.
        tags: Tags associated with the model.
        version: The version name, version number or stage is optional and
            points model context to a specific version/stage. If skipped, a
            new version will be created.
        save_models_to_registry: Whether to save all ModelArtifacts to Model
            Registry, if available in active stack.
        model_version_id: ID of the resolved model version. Internal use only:
            the validator rejects it unless
            `suppress_class_validation_warnings` is also set.
        suppress_class_validation_warnings: Internal flag that silences the
            validator's informational messages and permits
            `model_version_id` to be passed in.
    """

    name: str
    license: Optional[str] = None
    description: Optional[str] = None
    audience: Optional[str] = None
    use_cases: Optional[str] = None
    limitations: Optional[str] = None
    trade_offs: Optional[str] = None
    ethics: Optional[str] = None
    tags: Optional[List[str]] = None
    version: Optional[Union[ModelStages, int, str]] = Field(
        default=None, union_mode="smart"
    )
    save_models_to_registry: bool = True

    # technical attributes
    model_version_id: Optional[UUID] = None
    suppress_class_validation_warnings: bool = False
    # Caches filled in while resolving the model / model version; they start
    # out as None, hence the Optional annotations.
    _model_id: Optional[UUID] = PrivateAttr(None)
    _number: Optional[int] = PrivateAttr(None)
    _created_model_version: bool = PrivateAttr(False)

    # TODO: In Pydantic v2, the `model_` is a protected namespaces for all
    # fields defined under base models. If not handled, this raises a warning.
    # It is possible to suppress this warning message with the following
    # configuration, however the ultimate solution is to rename these fields.
    # Even though they do not cause any problems right now, if we are not
    # careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())

    #########################
    #    Public methods     #
    #########################
    @property
    def id(self) -> UUID:
        """Get version id from the Model Control Plane.

        Resolution goes through `_get_or_create_model_version`, so a new model
        version may be created on first access when `version` does not point
        at an existing one.

        Returns:
            ID of the model version.

        Raises:
            RuntimeError: if model version doesn't exist and
                cannot be fetched from the Model Control Plane.
        """
        if self.model_version_id is None:
            try:
                mv = self._get_or_create_model_version()
                self.model_version_id = mv.id
            except RuntimeError as e:
                raise RuntimeError(
                    f"Version `{self.version}` of `{self.name}` model doesn't "
                    "exist and cannot be fetched from the Model Control Plane."
                ) from e
        return self.model_version_id

    @property
    def model_id(self) -> UUID:
        """Get model id from the Model Control Plane.

        The model is created implicitly if it does not exist yet.

        Returns:
            The UUID of the model containing this model version.
        """
        if self._model_id is None:
            # `_get_or_create_model` also caches the id on `self._model_id`;
            # assigning it here keeps the return type non-optional.
            self._model_id = self._get_or_create_model().id
        return self._model_id

    @property
    def number(self) -> int:
        """Get version number from the Model Control Plane.

        Resolution goes through `_get_or_create_model_version`, so a new model
        version may be created on first access.

        Returns:
            Number of the model version.

        Raises:
            KeyError: if model version doesn't exist and
                cannot be fetched from the Model Control Plane.
        """
        if self._number is None:
            try:
                mv = self._get_or_create_model_version()
                self._number = mv.number
            except RuntimeError as e:
                raise KeyError(
                    f"Version `{self.version}` of `{self.name}` model doesn't "
                    "exist and cannot be fetched from the Model Control Plane."
                ) from e
        return self._number

    @property
    def stage(self) -> Optional[ModelStages]:
        """Get version stage from the Model Control Plane.

        Resolution goes through `_get_or_create_model_version`, so a new model
        version may be created on first access.

        Returns:
            Stage of the model version or None, if the model version has no
            stage or doesn't exist and cannot be fetched/created given the
            current config (e.g. a stage name or number was used as version).
        """
        try:
            stage = self._get_or_create_model_version().stage
            if stage:
                return ModelStages(stage)
        except RuntimeError:
            logger.info(
                f"Version `{self.version}` of `{self.name}` model doesn't "
                "exist and cannot be fetched from the Model Control Plane."
            )
        return None

    def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
        """Load artifact from the Model Control Plane.

        Args:
            name: Name of the artifact to load.
            version: Version of the artifact to load.

        Returns:
            The loaded artifact.

        Raises:
            ValueError: if the model version is not linked to any artifact with
                the given name and version.
        """
        from zenml.artifacts.utils import load_artifact
        from zenml.models import ArtifactVersionResponse

        artifact = self.get_artifact(name=name, version=version)

        if not isinstance(artifact, ArtifactVersionResponse):
            raise ValueError(
                f"Version {self.version} of model {self.name} does not have "
                f"an artifact with name {name} and version {version}."
            )

        return load_artifact(artifact.id, str(artifact.version))

    def get_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the artifact linked to this model version.

        Args:
            name: The name of the artifact to retrieve.
            version: The version of the artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the artifact or placeholder in the design time
                of the pipeline.
        """
        if lazy := self._lazy_artifact_get(name, version):
            return lazy

        return self._get_or_create_model_version().get_artifact(
            name=name,
            version=version,
        )

    def get_model_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the model artifact linked to this model version.

        Args:
            name: The name of the model artifact to retrieve.
            version: The version of the model artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the model artifact or placeholder in the design
                time of the pipeline.
        """
        if lazy := self._lazy_artifact_get(name, version):
            return lazy

        return self._get_or_create_model_version().get_model_artifact(
            name=name,
            version=version,
        )

    def get_data_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the data artifact linked to this model version.

        Args:
            name: The name of the data artifact to retrieve.
            version: The version of the data artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the data artifact or placeholder in the design
                time of the pipeline.
        """
        if lazy := self._lazy_artifact_get(name, version):
            return lazy

        return self._get_or_create_model_version().get_data_artifact(
            name=name,
            version=version,
        )

    def get_deployment_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the deployment artifact linked to this model version.

        Args:
            name: The name of the deployment artifact to retrieve.
            version: The version of the deployment artifact to retrieve (None
                for latest/non-versioned)

        Returns:
            Specific version of the deployment artifact or placeholder in the
                design time of the pipeline.
        """
        if lazy := self._lazy_artifact_get(name, version):
            return lazy

        return self._get_or_create_model_version().get_deployment_artifact(
            name=name,
            version=version,
        )

    def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
        """Get pipeline run linked to this version.

        Args:
            name: The name of the pipeline run to retrieve.

        Returns:
            PipelineRun as PipelineRunResponse
        """
        return self._get_or_create_model_version().get_pipeline_run(name=name)

    def set_stage(
        self, stage: Union[str, ModelStages], force: bool = False
    ) -> None:
        """Sets this Model to a desired stage.

        Args:
            stage: the target stage for model version.
            force: whether to force archiving of current model version in
                target stage or raise.
        """
        self._get_or_create_model_version().set_stage(stage=stage, force=force)

    def log_metadata(
        self,
        metadata: Dict[str, "MetadataType"],
    ) -> None:
        """Log model version metadata.

        This function can be used to log metadata for current model version.

        Args:
            metadata: The metadata to log.
        """
        from zenml.client import Client

        response = self._get_or_create_model_version()
        Client().create_run_metadata(
            metadata=metadata,
            resource_id=response.id,
            resource_type=MetadataResourceTypes.MODEL_VERSION,
        )

    @property
    def run_metadata(self) -> Dict[str, "MetadataType"]:
        """Get model version run metadata.

        While a pipeline context is available, a `RunMetadataLazyGetter`
        placeholder is returned instead of the metadata itself.

        Returns:
            The model version run metadata.

        Raises:
            RuntimeError: If the model version run metadata cannot be fetched.
        """
        from zenml.metadata.lazy_load import RunMetadataLazyGetter

        # Only the context lookup is guarded: a RuntimeError raised while
        # building the lazy getter must not be mistaken for "no context".
        try:
            get_pipeline_context()
        except RuntimeError:
            pass
        else:
            # avoid exposing too much of internal details by keeping the return type
            return RunMetadataLazyGetter(  # type: ignore[return-value]
                self.name,
                self._lazy_version,
            )

        response = self._get_or_create_model_version(hydrate=True)
        if response.run_metadata is None:
            raise RuntimeError(
                "Failed to fetch metadata of this model version."
            )
        return response.run_metadata

    def delete_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        only_link: bool = True,
        delete_metadata: bool = True,
        delete_from_artifact_store: bool = False,
    ) -> None:
        """Delete the artifact linked to this model version.

        Args:
            name: The name of the artifact to delete.
            version: The version of the artifact to delete (None for
                latest/non-versioned)
            only_link: Whether to only delete the link to the artifact.
            delete_metadata: Whether to delete the metadata of the artifact.
            delete_from_artifact_store: Whether to delete the artifact from the
                artifact store.
        """
        from zenml.client import Client
        from zenml.models import ArtifactVersionResponse

        artifact_version = self.get_artifact(name, version)
        if isinstance(artifact_version, ArtifactVersionResponse):
            client = Client()
            client.delete_model_version_artifact_link(
                model_version_id=self.id,
                artifact_version_id=artifact_version.id,
            )
            if not only_link:
                client.delete_artifact_version(
                    name_id_or_prefix=artifact_version.id,
                    delete_metadata=delete_metadata,
                    delete_from_artifact_store=delete_from_artifact_store,
                )

    def delete_all_artifacts(
        self,
        only_link: bool = True,
        delete_from_artifact_store: bool = False,
    ) -> None:
        """Delete all artifacts linked to this model version.

        Args:
            only_link: Whether to only delete the link to the artifact.
            delete_from_artifact_store: Whether to delete the artifact from
                the artifact store.
        """
        from itertools import chain

        from zenml.client import Client

        client = Client()

        if not only_link and delete_from_artifact_store:
            mv = self._get_model_version()
            # Walk every category separately instead of merging them with
            # `dict.update`. Merging mutated `mv.data_artifacts` in place and,
            # because the dicts are keyed by artifact name, silently dropped
            # artifacts whose name appears in more than one category, so they
            # were never removed from the artifact store. Keying by id also
            # guarantees each artifact version is deleted at most once.
            artifact_versions = {
                artifact_version.id: artifact_version
                for versions_by_name in chain(
                    mv.data_artifacts.values(),
                    mv.model_artifacts.values(),
                    mv.deployment_artifacts.values(),
                )
                for artifact_version in versions_by_name.values()
            }
            for artifact_version in artifact_versions.values():
                client._delete_artifact_from_artifact_store(
                    artifact_version=artifact_version
                )

        client.delete_all_model_version_artifact_links(self.id, only_link)

    def _lazy_artifact_get(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Build a lazy artifact placeholder if a pipeline context is active.

        Args:
            name: The name of the artifact.
            version: The version of the artifact (None for
                latest/non-versioned).

        Returns:
            A `LazyArtifactVersionResponse` if `get_pipeline_context()`
            succeeds, otherwise None.
        """
        from zenml.models.v2.core.artifact_version import (
            LazyArtifactVersionResponse,
        )

        # Only the context lookup is guarded; errors from building the
        # placeholder itself are not silently swallowed.
        try:
            get_pipeline_context()
        except RuntimeError:
            return None
        return LazyArtifactVersionResponse(
            lazy_load_name=name,
            lazy_load_version=version,
            lazy_load_model_name=self.name,
            lazy_load_model_version=self._lazy_version,
        )

    def __eq__(self, other: object) -> bool:
        """Check two Models for equality.

        Args:
            other: object to compare with

        Returns:
            True, if equal, False otherwise.
        """
        if not isinstance(other, Model):
            return NotImplemented
        if self.name != other.name:
            return False
        if self.version == other.version:
            return True
        # Versions are spelled differently (e.g. stage vs. number): resolve
        # both in the Model Control Plane and compare ids.
        # NOTE(review): resolution may create a model version as a side
        # effect when `version` is None on one side.
        self_mv = self._get_or_create_model_version()
        other_mv = other._get_or_create_model_version()
        return self_mv.id == other_mv.id

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        Args:
            data: Dict of values.

        Returns:
            Dict of validated values.

        Raises:
            ValueError: If `model_version_id` is provided without
                `suppress_class_validation_warnings` being set, i.e. outside
                of internal calls.
        """
        suppress_class_validation_warnings = data.get(
            "suppress_class_validation_warnings",
            False,
        )
        if not suppress_class_validation_warnings and data.get(
            "model_version_id", None
        ):
            raise ValueError(
                "`model_version_id` field is for internal use only"
            )

        version = data.get("version", None)

        if (
            version in [stage.value for stage in ModelStages]
            and not suppress_class_validation_warnings
        ):
            logger.info(
                f"Version `{version}` matches one of the possible "
                "`ModelStages` and will be fetched using stage."
            )
        if str(version).isnumeric() and not suppress_class_validation_warnings:
            logger.info(
                f"`version` `{version}` is numeric and will be fetched "
                "using version number."
            )
        # Messages are emitted once: later re-validations stay quiet.
        data["suppress_class_validation_warnings"] = True
        return data

    def _get_or_create_model(self) -> "ModelResponse":
        """This method should get or create a model from Model Control Plane.

        New model is created implicitly, if missing, otherwise fetched.
        The resolved model id is cached on `self._model_id`.

        Returns:
            The model based on configuration.
        """
        from zenml.client import Client
        from zenml.models import ModelRequest

        zenml_client = Client()
        if self.model_version_id:
            mv = zenml_client.get_model_version(
                model_version_name_or_number_or_id=self.model_version_id,
            )
            model = mv.model
        else:
            try:
                model = zenml_client.zen_store.get_model(
                    model_name_or_id=self.name
                )
            except KeyError:
                model_request = ModelRequest(
                    name=self.name,
                    license=self.license,
                    description=self.description,
                    audience=self.audience,
                    use_cases=self.use_cases,
                    limitations=self.limitations,
                    trade_offs=self.trade_offs,
                    ethics=self.ethics,
                    user=zenml_client.active_user.id,
                    workspace=zenml_client.active_workspace.id,
                    save_models_to_registry=self.save_models_to_registry,
                )
                model_request = ModelRequest.model_validate(model_request)
                try:
                    model = zenml_client.zen_store.create_model(
                        model=model_request
                    )
                    logger.info(
                        f"New model `{self.name}` was created implicitly."
                    )
                except EntityExistsError:
                    # Created concurrently by someone else: fetch it instead.
                    model = zenml_client.zen_store.get_model(
                        model_name_or_id=self.name
                    )

        self._model_id = model.id
        return model

    def _get_model_version(
        self, hydrate: bool = True
    ) -> "ModelVersionResponse":
        """This method gets a model version from Model Control Plane.

        A warning is logged when the configured description or tags differ
        from the stored model version.

        Args:
            hydrate: Flag deciding whether to hydrate the output model(s)
                by including metadata fields in the response.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client

        zenml_client = Client()
        if self.model_version_id:
            mv = zenml_client.get_model_version(
                model_version_name_or_number_or_id=self.model_version_id,
                hydrate=hydrate,
            )
        else:
            mv = zenml_client.get_model_version(
                model_name_or_id=self.name,
                model_version_name_or_number_or_id=self.version,
                hydrate=hydrate,
            )
            self.model_version_id = mv.id

        difference: Dict[str, Any] = {}
        if mv.metadata:
            if self.description and mv.description != self.description:
                difference["description"] = {
                    "config": self.description,
                    "db": mv.description,
                }
        if self.tags:
            configured_tags = set(self.tags)
            db_tags = {t.name for t in mv.tags}
            if db_tags != configured_tags:
                difference["tags added"] = list(configured_tags - db_tags)
                difference["tags removed"] = list(db_tags - configured_tags)
        if difference:
            logger.warning(
                "Provided model version configuration does not match existing model "
                f"version `{self.name}::{self.version}` with the following "
                f"changes: {difference}. If you want to update the model version "
                "configuration, please use the `zenml model version update` command."
            )

        return mv

    def _get_or_create_model_version(
        self, hydrate: bool = False
    ) -> "ModelVersionResponse":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model
        is fetched. Model name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model version:
        - If `version` is None, a new model version is created, if not created
            by other steps in same run.
        - If `version` is not None a model version will be fetched based on the
            version:
            - If `version` is set to an integer or digit string, the model
                version with the matching number will be fetched.
            - If `version` is set to a string, the model version with the
                matching version will be fetched.
            - If `version` is set to a `ModelStage`, the model version with the
                matching stage will be fetched.

        Args:
            hydrate: Whether to return a hydrated version of the model version.

        Returns:
            The model version based on configuration.

        Raises:
            RuntimeError: if the model version needs to be created, but
                provided name is reserved.
            RuntimeError: if the model version cannot be created.
        """
        from zenml.client import Client
        from zenml.models import ModelVersionRequest

        model = self._get_or_create_model()

        # backup logic, if the Model class is used directly from the code
        if isinstance(self.version, str):
            self.version = format_name_template(self.version)

        try:
            if self.version or self.model_version_id:
                # NOTE(review): `hydrate` is not forwarded here, so
                # `_get_model_version` always uses its default
                # (`hydrate=True`) - confirm this is intended.
                model_version = self._get_model_version()
            else:
                # Nothing to look up: fall through to creation below.
                raise KeyError
        except KeyError:
            if (
                self.version
                and str(self.version).lower() in ModelStages.values()
            ):
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "it matches one of the possible model version stages. If you "
                    "are aiming to fetch model version by stage, check if the "
                    "model version in given stage exists. It might be missing, if "
                    "the pipeline promoting model version to this stage failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list -n {self.name}` CLI command."
                )
            if str(self.version).isnumeric():
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "numeric model version names are reserved. If you "
                    "are aiming to fetch model version by number, check if the "
                    "model version with given number exists. It might be missing, if "
                    "the pipeline creating model version failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list -n {self.name}` CLI command."
                )
            client = Client()
            model_version_request = ModelVersionRequest(
                user=client.active_user.id,
                workspace=client.active_workspace.id,
                name=str(self.version) if self.version else None,
                description=self.description,
                model=model.id,
                tags=self.tags,
            )
            model_version = client.zen_store.create_model_version(
                model_version=model_version_request
            )
            self._created_model_version = True
            logger.info(
                "Created new model version `%s` for model `%s`.",
                model_version.name,
                self.name,
            )

        # Cache the resolved identity so later calls skip the lookup.
        self.version = model_version.name
        self.model_version_id = model_version.id
        self._model_id = model_version.model.id
        self._number = model_version.number
        return model_version

    def __hash__(self) -> int:
        """Get hash of the `Model`.

        NOTE(review): `__eq__` can treat two models with differently spelled
        versions as equal (same resolved id) while their hashes differ.
        Hashing by name only would restore the contract but would make
        same-name models collide and trigger remote lookups in `__eq__`.

        Returns:
            Hash function results
        """
        return hash(
            "::".join(
                (
                    str(v)
                    for v in (
                        self.name,
                        self.version,
                    )
                )
            )
        )

    def _prepare_model_version_before_step_launch(
        self,
        pipeline_run: "PipelineRunResponse",
        step_run: Optional["StepRunResponse"],
        return_logs: bool,
    ) -> Tuple[str, "PipelineRunResponse", Optional["StepRunResponse"]]:
        """Prepares model version inside pipeline run.

        Args:
            pipeline_run: pipeline run
            step_run: step run (passed only if model version is defined in a step explicitly)
            return_logs: whether to return logs or not

        Returns:
            A tuple of the logs related to the Dashboard URL to show later
            (empty unless `return_logs` is set and a model version had to be
            resolved), the possibly updated pipeline run and the possibly
            updated step run.
        """
        from zenml.client import Client
        from zenml.models import PipelineRunUpdate, StepRunUpdate

        logs = ""
        # copy Model instance to prevent corrupting configs of the
        # subsequent runs, if they share the same config object
        self_copy = self.model_copy()

        # in case request is within the step and no self-configuration is provided
        # try reuse what's in the pipeline run first
        if step_run is None and pipeline_run.model_version is not None:
            self_copy.version = pipeline_run.model_version.name
            self_copy.model_version_id = pipeline_run.model_version.id
        # otherwise try to fill the templated name, if needed
        elif isinstance(self_copy.version, str):
            if pipeline_run.start_time:
                start_time = pipeline_run.start_time
            else:
                start_time = datetime.datetime.now(datetime.timezone.utc)
            self_copy.version = format_name_template(
                self_copy.version,
                date=start_time.strftime("%Y_%m_%d"),
                time=start_time.strftime("%H_%M_%S_%f"),
            )

        # if exact model not yet defined - try to get/create and update it
        # back to the run accordingly
        if self_copy.model_version_id is None:
            model_version_response = self_copy._get_or_create_model_version()

            client = Client()
            # update the configured model version id in runs accordingly
            if step_run:
                step_run = client.zen_store.update_run_step(
                    step_run_id=step_run.id,
                    step_run_update=StepRunUpdate(
                        model_version_id=model_version_response.id
                    ),
                )
            else:
                pipeline_run = client.zen_store.update_run(
                    run_id=pipeline_run.id,
                    run_update=PipelineRunUpdate(
                        model_version_id=model_version_response.id
                    ),
                )

            if return_logs:
                from zenml.utils.cloud_utils import try_get_model_version_url

                if logs_to_show := try_get_model_version_url(
                    model_version_response
                ):
                    logs = logs_to_show
                else:
                    logs = (
                        "Models can be viewed in the dashboard using ZenML Pro. Sign up "
                        "for a free trial at https://www.zenml.io/pro/"
                    )
        self.model_version_id = self_copy.model_version_id
        return logs, pipeline_run, step_run

    @property
    def _lazy_version(self) -> Optional[str]:
        """Get version name for lazy loader.

        This getter ensures that new model version
        creation is never triggered here.

        Returns:
            Version name or None if it was not set
        """
        if self._number is not None:
            return str(self._number)
        elif self.version is not None:
            if isinstance(self.version, ModelStages):
                return self.version.value
            return str(self.version)
        return None
id: UUID
property
readonly
Get version id from the Model Control Plane.
Returns:
Type | Description |
---|---|
UUID |
ID of the model version, fetched from (or, if needed, created in) the Model Control Plane on first access. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if model version doesn't exist and cannot be fetched from the Model Control Plane. |
model_id: UUID
property
readonly
Get model id from the Model Control Plane.
Returns:
Type | Description |
---|---|
UUID |
The UUID of the model containing this model version. |
number: int
property
readonly
Get version number from the Model Control Plane.
Returns:
Type | Description |
---|---|
int |
Number of the model version, fetched from (or, if needed, created in) the Model Control Plane on first access. |
Exceptions:
Type | Description |
---|---|
KeyError |
if model version doesn't exist and cannot be fetched from the Model Control Plane. |
run_metadata: Dict[str, MetadataType]
property
readonly
Get model version run metadata.
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The model version run metadata. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the model version run metadata cannot be fetched. |
stage: Optional[zenml.enums.ModelStages]
property
readonly
Get version stage from the Model Control Plane.
Returns:
Type | Description |
---|---|
Optional[zenml.enums.ModelStages] |
Stage of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
__eq__(self, other)
special
Check two Models for equality.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other |
object |
object to compare with |
required |
Returns:
Type | Description |
---|---|
bool |
True, if equal, False otherwise. |
Source code in zenml/model/model.py
def __eq__(self, other: object) -> bool:
    """Compare this `Model` with another object.

    Two models are equal when they share a name and either declare the same
    `version` or resolve to the same model version id in the Model Control
    Plane.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, Model):
        return NotImplemented
    if self.name != other.name:
        return False
    if self.version == other.version:
        return True
    # Versions are spelled differently (e.g. stage vs. number), so resolve
    # both sides and compare the resulting ids.
    own_version = self._get_or_create_model_version()
    their_version = other._get_or_create_model_version()
    return own_version.id == their_version.id
__hash__(self)
special
Get hash of the `Model`.
Returns:
Type | Description |
---|---|
int |
Hash function results |
Source code in zenml/model/model.py
def __hash__(self) -> int:
    """Hash a `Model` by its name and configured version.

    Returns:
        Hash function results
    """
    # `!s` forces `str()` on both parts, matching a "::" join of str values.
    key = f"{self.name!s}::{self.version!s}"
    return hash(key)
delete_all_artifacts(self, only_link=True, delete_from_artifact_store=False)
Delete all artifacts linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
only_link |
bool |
Whether to only delete the link to the artifact. |
True |
delete_from_artifact_store |
bool |
Whether to delete the artifact from the artifact store. |
False |
Source code in zenml/model/model.py
def delete_all_artifacts(
    self,
    only_link: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete all artifacts linked to this model version.

    Args:
        only_link: Whether to only delete the link to the artifact.
        delete_from_artifact_store: Whether to delete the artifact from
            the artifact store.
    """
    from itertools import chain

    from zenml.client import Client

    client = Client()

    if not only_link and delete_from_artifact_store:
        mv = self._get_model_version()
        # Walk every category separately instead of merging them with
        # `dict.update`. Merging mutated `mv.data_artifacts` in place and,
        # because the dicts are keyed by artifact name, silently dropped
        # artifacts whose name appears in more than one category, so they
        # were never removed from the artifact store. Keying by id also
        # guarantees each artifact version is deleted at most once.
        artifact_versions = {
            artifact_version.id: artifact_version
            for versions_by_name in chain(
                mv.data_artifacts.values(),
                mv.model_artifacts.values(),
                mv.deployment_artifacts.values(),
            )
            for artifact_version in versions_by_name.values()
        }
        for artifact_version in artifact_versions.values():
            client._delete_artifact_from_artifact_store(
                artifact_version=artifact_version
            )

    client.delete_all_model_version_artifact_links(self.id, only_link)
delete_artifact(self, name, version=None, only_link=True, delete_metadata=True, delete_from_artifact_store=False)
Delete the artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the artifact to delete. |
required |
version |
Optional[str] |
The version of the artifact to delete (None for latest/non-versioned) |
None |
only_link |
bool |
Whether to only delete the link to the artifact. |
True |
delete_metadata |
bool |
Whether to delete the metadata of the artifact. |
True |
delete_from_artifact_store |
bool |
Whether to delete the artifact from the artifact store. |
False |
Source code in zenml/model/model.py
def delete_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    only_link: bool = True,
    delete_metadata: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Unlink (and optionally delete) an artifact of this model version.

    Args:
        name: The name of the artifact to delete.
        version: The version of the artifact to delete (None for
            latest/non-versioned)
        only_link: Whether to only delete the link to the artifact.
        delete_metadata: Whether to delete the metadata of the artifact.
        delete_from_artifact_store: Whether to delete the artifact from the
            artifact store.
    """
    from zenml.client import Client
    from zenml.models import ArtifactVersionResponse

    target = self.get_artifact(name, version)
    if not isinstance(target, ArtifactVersionResponse):
        # No concrete artifact version was found: nothing to delete.
        return

    client = Client()
    client.delete_model_version_artifact_link(
        model_version_id=self.id,
        artifact_version_id=target.id,
    )
    if only_link:
        return

    client.delete_artifact_version(
        name_id_or_prefix=target.id,
        delete_metadata=delete_metadata,
        delete_from_artifact_store=delete_from_artifact_store,
    )
get_artifact(self, name, version=None)
Get the artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the artifact to retrieve. |
required |
version |
Optional[str] |
The version of the artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactVersionResponse] |
Specific version of the artifact or placeholder in the design time of the pipeline. |
Source code in zenml/model/model.py
def get_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Look up an artifact linked to this model version.

    A truthy result from `_lazy_artifact_get` (the design-time placeholder)
    is returned as-is. Otherwise the lookup is delegated to the model
    version returned by `_get_or_create_model_version`.

    Args:
        name: The name of the artifact to retrieve.
        version: The version of the artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the artifact or placeholder in the design time
        of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_artifact(name=name, version=version)
get_data_artifact(self, name, version=None)
Get the data artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the data artifact to retrieve. |
required |
version |
Optional[str] |
The version of the data artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactVersionResponse] |
Specific version of the data artifact or placeholder in the design time of the pipeline. |
Source code in zenml/model/model.py
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Look up a data artifact linked to this model version.

    A truthy result from `_lazy_artifact_get` (the design-time placeholder)
    is returned as-is. Otherwise the lookup is delegated to the model
    version returned by `_get_or_create_model_version`.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the data artifact or placeholder in the design
        time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_data_artifact(name=name, version=version)
get_deployment_artifact(self, name, version=None)
Get the deployment artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the deployment artifact to retrieve. |
required |
version |
Optional[str] |
The version of the deployment artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactVersionResponse] |
Specific version of the deployment artifact or placeholder in the design time of the pipeline. |
Source code in zenml/model/model.py
def get_deployment_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Look up a deployment artifact linked to this model version.

    A truthy result from `_lazy_artifact_get` (the design-time placeholder)
    is returned as-is. Otherwise the lookup is delegated to the model
    version returned by `_get_or_create_model_version`.

    Args:
        name: The name of the deployment artifact to retrieve.
        version: The version of the deployment artifact to retrieve (None
            for latest/non-versioned)

    Returns:
        Specific version of the deployment artifact or placeholder in the
        design time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_deployment_artifact(name=name, version=version)
get_model_artifact(self, name, version=None)
Get the model artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the model artifact to retrieve. |
required |
version |
Optional[str] |
The version of the model artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactVersionResponse] |
Specific version of the model artifact or placeholder in the design time of the pipeline. |
Source code in zenml/model/model.py
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Look up a model artifact linked to this model version.

    A truthy result from `_lazy_artifact_get` (the design-time placeholder)
    is returned as-is. Otherwise the lookup is delegated to the model
    version returned by `_get_or_create_model_version`.

    Args:
        name: The name of the model artifact to retrieve.
        version: The version of the model artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the model artifact or placeholder in the design
        time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_model_artifact(name=name, version=version)
get_pipeline_run(self, name)
Get pipeline run linked to this version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the pipeline run to retrieve. |
required |
Returns:
Type | Description |
---|---|
PipelineRunResponse |
PipelineRun as PipelineRunResponse |
Source code in zenml/model/model.py
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Fetch a pipeline run linked to this model version.

    The lookup is delegated to the model version returned by
    `_get_or_create_model_version`.

    Args:
        name: The name of the pipeline run to retrieve.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    model_version = self._get_or_create_model_version()
    return model_version.get_pipeline_run(name=name)
load_artifact(self, name, version=None)
Load artifact from the Model Control Plane.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the artifact to load. |
required |
version |
Optional[str] |
Version of the artifact to load. |
None |
Returns:
Type | Description |
---|---|
Any |
The loaded artifact. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the model version is not linked to any artifact with the given name and version. |
Source code in zenml/model/model.py
def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
    """Load the data of an artifact linked to this model version.

    Args:
        name: Name of the artifact to load.
        version: Version of the artifact to load.

    Returns:
        The loaded artifact.

    Raises:
        ValueError: if the model version is not linked to any artifact with
            the given name and version.
    """
    from zenml.artifacts.utils import load_artifact
    from zenml.models import ArtifactVersionResponse

    found = self.get_artifact(name=name, version=version)
    # Only a concrete artifact version (not a placeholder/None) can be loaded.
    if isinstance(found, ArtifactVersionResponse):
        return load_artifact(found.id, str(found.version))

    raise ValueError(
        f"Version {self.version} of model {self.name} does not have "
        f"an artifact with name {name} and version {version}."
    )
log_metadata(self, metadata)
Log model version metadata.
This function can be used to log metadata for current model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to log. |
required |
Source code in zenml/model/model.py
def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Attach metadata to the model version backing this model.

    The model version is fetched (or created) via
    `_get_or_create_model_version`, and the metadata is recorded against it
    as a `MODEL_VERSION` resource.

    Args:
        metadata: The metadata to log.
    """
    from zenml.client import Client

    # Resolve the model version before constructing the client, keeping the
    # original order of these calls.
    model_version = self._get_or_create_model_version()
    Client().create_run_metadata(
        metadata=metadata,
        resource_id=model_version.id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )
model_post_init(self, context, /)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self |
BaseModel |
The BaseModel instance. |
required |
context |
Any |
The context. |
required |
Source code in zenml/model/model.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes
    when calling it. It is a no-op if `__pydantic_private__` is already set
    (non-None). Otherwise it populates it with every declared private
    attribute that has a default (i.e. not `PydanticUndefined`).

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is not None:
        return

    defaults = {
        attr_name: attr_default
        for attr_name, attr in self.__private_attributes__.items()
        if (attr_default := attr.get_default()) is not PydanticUndefined
    }
    # Bypass the model's own __setattr__, which would otherwise intercept
    # assignment of this internal slot.
    object_setattr(self, '__pydantic_private__', defaults)
set_stage(self, stage, force=False)
Sets this Model to a desired stage.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage |
Union[str, zenml.enums.ModelStages] |
the target stage for model version. |
required |
force |
bool |
whether to force archiving of current model version in target stage or raise. |
False |
Source code in zenml/model/model.py
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> None:
    """Move the model version backing this model to a desired stage.

    The model version is fetched (or created) via
    `_get_or_create_model_version` and the call is delegated to it.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in
            target stage or raise.
    """
    model_version = self._get_or_create_model_version()
    model_version.set_stage(stage=stage, force=force)
utils
Utility functions for linking step outputs to model versions.
link_artifact_to_model(artifact_version, model=None, is_model_artifact=False, is_deployment_artifact=False)
Link the artifact to the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version |
ArtifactVersionResponse |
The artifact version to link. |
required |
model |
Optional[Model] |
The model to link to. |
None |
is_model_artifact |
bool |
Whether the artifact is a model artifact. |
False |
is_deployment_artifact |
bool |
Whether the artifact is a deployment artifact. |
False |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If called outside of a step. |
Source code in zenml/model/utils.py
def link_artifact_to_model(
    artifact_version: ArtifactVersionResponse,
    model: Optional["Model"] = None,
    is_model_artifact: bool = False,
    is_deployment_artifact: bool = False,
) -> None:
    """Link the artifact to the model.

    If no `model` is given, the model configured for the currently running
    step is used.

    Args:
        artifact_version: The artifact version to link.
        model: The model to link to.
        is_model_artifact: Whether the artifact is a model artifact.
        is_deployment_artifact: Whether the artifact is a deployment artifact.

    Raises:
        RuntimeError: If no `model` is passed and none can be resolved from
            the step context (e.g. when called outside of a step).
    """
    if not model:
        try:
            model = get_step_context().model
        except (StepContextError, RuntimeError):
            # `get_step_context()` fails outside a step (handled as a
            # RuntimeError elsewhere in this module) and resolving the
            # configured model may raise StepContextError; either way there
            # is no model to fall back to.
            model = None
        if model is None:
            raise RuntimeError(
                "`link_artifact_to_model` called without `model` parameter "
                "and configured model context cannot be identified. Consider "
                "passing the `model` explicitly or configuring it in "
                "@step or @pipeline decorator."
            )

    model_version = model._get_or_create_model_version()
    artifact_config = ArtifactConfig(
        is_model_artifact=is_model_artifact,
        is_deployment_artifact=is_deployment_artifact,
    )
    link_artifact_version_to_model_version(
        artifact_version=artifact_version,
        model_version=model_version,
        artifact_config=artifact_config,
    )
link_artifact_version_to_model_version(artifact_version, model_version, artifact_config=None)
Link an artifact version to a model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version |
ArtifactVersionResponse |
The artifact version to link. |
required |
model_version |
ModelVersionResponse |
The model version to link. |
required |
artifact_config |
Optional[zenml.artifacts.artifact_config.ArtifactConfig] |
Output artifact configuration. |
None |
Source code in zenml/model/utils.py
def link_artifact_version_to_model_version(
    artifact_version: ArtifactVersionResponse,
    model_version: ModelVersionResponse,
    artifact_config: Optional[ArtifactConfig] = None,
) -> None:
    """Create a link between an artifact version and a model version.

    The link is created in the active user's active workspace through the
    client's zen store. Without an `artifact_config`, the link is flagged as
    neither a model artifact nor a deployment artifact.

    Args:
        artifact_version: The artifact version to link.
        model_version: The model version to link.
        artifact_config: Output artifact configuration.
    """
    is_model_artifact, is_deployment_artifact = (
        (artifact_config.is_model_artifact, artifact_config.is_deployment_artifact)
        if artifact_config
        else (False, False)
    )

    client = Client()
    # Look up the store method before building the request, preserving the
    # order in which client attributes are accessed.
    create_link = client.zen_store.create_model_version_artifact_link
    create_link(
        ModelVersionArtifactRequest(
            user=client.active_user.id,
            workspace=client.active_workspace.id,
            artifact_version=artifact_version.id,
            model=model_version.model.id,
            model_version=model_version.id,
            is_model_artifact=is_model_artifact,
            is_deployment_artifact=is_deployment_artifact,
        )
    )
link_service_to_model(service_id, model=None, model_version_id=None)
Links a service to a model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_id |
UUID |
The ID of the service to link to the model. |
required |
model |
Optional[Model] |
The model to link the service to. |
None |
model_version_id |
Optional[uuid.UUID] |
The ID of the model version to link the service to. |
None |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If no model is provided and the model context cannot be identified. |
Source code in zenml/model/utils.py
def link_service_to_model(
    service_id: UUID,
    model: Optional["Model"] = None,
    model_version_id: Optional[UUID] = None,
) -> None:
    """Links a service to a model.

    An explicitly passed `model_version_id` takes precedence over `model`.
    If neither is given, the model configured for the currently running step
    is used.

    Args:
        service_id: The ID of the service to link to the model.
        model: The model to link the service to.
        model_version_id: The ID of the model version to link the service to.

    Raises:
        RuntimeError: If no model is provided and the model context cannot be
            identified.
    """
    client = Client()
    # If no model is provided, try to get it from the context
    if not model and not model_version_id:
        try:
            model = get_step_context().model
        except (StepContextError, RuntimeError):
            # Either no step is running or no model is configured for it.
            model = None
        if model is None:
            raise RuntimeError(
                "`link_service_to_model` called without `model` parameter "
                "and configured model context cannot be identified. Consider "
                "passing the `model` explicitly or configuring it in "
                "@step or @pipeline decorator."
            )

    # Resolve the ID from the model only when none was passed explicitly.
    # The guard above ensures `model` is set in that case. A single
    # `id or model...id if model else None` expression would parse as
    # `(id or ...) if model else None` and drop an explicit ID whenever
    # `model` is omitted.
    if not model_version_id and model:
        model_version_id = model._get_or_create_model_version().id

    update_service = ServiceUpdate(model_version_id=model_version_id)
    client.zen_store.update_service(
        service_id=service_id, update=update_service
    )
log_model_metadata(metadata, model_name=None, model_version=None)
Log model version metadata.
This function can be used to log metadata for existing model versions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to log. |
required |
model_name |
Optional[str] |
The name of the model to log metadata for. Can
be omitted when being called inside a step with configured
`model` in decorator. |
None |
model_version |
Union[zenml.enums.ModelStages, int, str] |
The version of the model to log metadata for. Can
be omitted when being called inside a step with configured
`model` in decorator. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If no model name/version is provided and the function is not
called inside a step with configured `model` in decorator. |
Source code in zenml/model/utils.py
def log_model_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.
    It is deprecated in favour of `log_metadata` and emits a warning.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.

    Raises:
        ValueError: If no model name/version is provided and the function is not
            called inside a step with configured `model` in decorator.
    """
    logger.warning(
        "The `log_model_metadata` function is deprecated and will soon be "
        "removed. Please use `log_metadata` instead."
    )
    if model_name and model_version:
        from zenml import Model

        mv = Model(name=model_name, version=model_version)
    else:
        no_model_message = (
            "Model name and version must be provided unless the function is "
            "called inside a step with configured `model` in decorator."
        )
        try:
            # Resolving `.model` can itself fail when the step has no model
            # configured, so it belongs inside the `try` as well.
            mv = get_step_context().model
        except (RuntimeError, StepContextError) as err:
            raise ValueError(no_model_message) from err
        if mv is None:
            raise ValueError(no_model_message)
    mv.log_metadata(metadata)