Model
zenml.model
special
Concepts related to the Model Control Plane feature.
lazy_load
Model Version Data Lazy Loader definition.
ModelVersionDataLazyLoader (BaseModel)
Model Version Data Lazy Loader helper class.
It helps internal code fetch the proper artifact, model version metadata or artifact metadata from the model version while the step is running.
Source code in zenml/model/lazy_load.py
class ModelVersionDataLazyLoader(BaseModel):
"""Model Version Data Lazy Loader helper class.
It helps internal code fetch the proper artifact, model version
metadata or artifact metadata from the model version while the
step is running.
"""
model: Model
artifact_name: Optional[str] = None
artifact_version: Optional[str] = None
metadata_name: Optional[str] = None
model
Model user facing interface to pass into pipeline or step.
Model (BaseModel)
Model class to pass into pipeline or step to set it into a model context.
name: The name of the model.
license: The license under which the model is created.
description: The description of the model.
audience: The target audience of the model.
use_cases: The use cases of the model.
limitations: The known limitations of the model.
trade_offs: The tradeoffs of the model.
ethics: The ethical implications of the model.
tags: Tags associated with the model.
version: The version name, version number or stage is optional and points model context to a specific version/stage. If skipped new version will be created.
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, if available in active stack.
model_version_id: The ID of a specific Model Version, if given - it will override `name` and `version` settings. Used mostly internally.
Source code in zenml/model/model.py
class Model(BaseModel):
"""Model class to pass into pipeline or step to set it into a model context.
Attributes:
name: The name of the model.
license: The license under which the model is created.
description: The description of the model.
audience: The target audience of the model.
use_cases: The use cases of the model.
limitations: The known limitations of the model.
trade_offs: The tradeoffs of the model.
ethics: The ethical implications of the model.
tags: Tags associated with the model.
version: The version name, version number or stage is optional and points model context
to a specific version/stage. If skipped new version will be created.
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
if available in active stack.
model_version_id: The ID of a specific Model Version, if given - it will override
`name` and `version` settings. Used mostly internally.
suppress_class_validation_warnings: Whether to skip the informational log
messages emitted while validating `version`.
was_created_in_this_run: Set to True once a new model version has been
created through this object.
"""
name: str
license: Optional[str] = None
description: Optional[str] = None
audience: Optional[str] = None
use_cases: Optional[str] = None
limitations: Optional[str] = None
trade_offs: Optional[str] = None
ethics: Optional[str] = None
tags: Optional[List[str]] = None
version: Optional[Union[ModelStages, int, str]] = Field(
default=None, union_mode="smart"
)
save_models_to_registry: bool = True
model_version_id: Optional[UUID] = None
suppress_class_validation_warnings: bool = False
was_created_in_this_run: bool = False
# Private caches; both start as None and are filled in once the model /
# model version has been resolved.
_model_id: UUID = PrivateAttr(None)
_number: int = PrivateAttr(None)
# TODO: In Pydantic v2, the `model_` is a protected namespaces for all
# fields defined under base models. If not handled, this raises a warning.
# It is possible to suppress this warning message with the following
# configuration, however the ultimate solution is to rename these fields.
# Even though they do not cause any problems right now, if we are not
# careful we might overwrite some fields protected by pydantic.
model_config = ConfigDict(protected_namespaces=())
#########################
# Public methods #
#########################
@property
def id(self) -> UUID:
    """Return the ID of this model version in the Model Control Plane.

    The ID is resolved on first access and stored in `model_version_id`,
    so later accesses return the stored value directly.

    Returns:
        ID of the model version.

    Raises:
        RuntimeError: if model version doesn't exist and
            cannot be fetched from the Model Control Plane.
    """
    if self.model_version_id is not None:
        return self.model_version_id

    try:
        resolved = self._get_or_create_model_version()
        self.model_version_id = resolved.id
    except RuntimeError as err:
        raise RuntimeError(
            f"Version `{self.version}` of `{self.name}` model doesn't "
            "exist and cannot be fetched from the Model Control Plane."
        ) from err
    return self.model_version_id
@property
def model_id(self) -> UUID:
    """Return the ID of the model that contains this model version.

    Returns:
        The UUID of the model containing this model version.
    """
    if self._model_id is not None:
        return self._model_id

    # resolving the model stores its ID in `_model_id` as a side effect
    self._get_or_create_model()
    return self._model_id
@property
def number(self) -> int:
    """Return the number of this model version in the Model Control Plane.

    Returns:
        Number of the model version or None, if model version
        doesn't exist and can only be read given current
        config (you used stage name or number as
        a version name).
    """
    if self._number is not None:
        return self._number

    try:
        # resolving the model version fills in `_number`
        self._get_or_create_model_version()
    except RuntimeError:
        logger.info(
            f"Version `{self.version}` of `{self.name}` model doesn't "
            "exist and cannot be fetched from the Model Control Plane."
        )
    return self._number
@property
def stage(self) -> Optional[ModelStages]:
    """Return the stage of this model version in the Model Control Plane.

    Returns:
        Stage of the model version or None, if model version
        doesn't exist and can only be read given current
        config (you used stage name or number as
        a version name).
    """
    try:
        raw_stage = self._get_or_create_model_version().stage
    except RuntimeError:
        logger.info(
            f"Version `{self.version}` of `{self.name}` model doesn't "
            "exist and cannot be fetched from the Model Control Plane."
        )
        return None

    # a model version without a stage yields None as well
    return ModelStages(raw_stage) if raw_stage else None
def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
    """Load the content of an artifact linked to this model version.

    Args:
        name: Name of the artifact to load.
        version: Version of the artifact to load.

    Returns:
        The loaded artifact.

    Raises:
        ValueError: if the model version is not linked to any artifact with
            the given name and version.
    """
    from zenml.artifacts.utils import load_artifact
    from zenml.models import ArtifactVersionResponse

    linked = self.get_artifact(name=name, version=version)
    if isinstance(linked, ArtifactVersionResponse):
        return load_artifact(linked.id, str(linked.version))

    raise ValueError(
        f"Version {self.version} of model {self.name} does not have "
        f"an artifact with name {name} and version {version}."
    )
def get_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Return the artifact linked to this model version.

    Args:
        name: The name of the artifact to retrieve.
        version: The version of the artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the artifact or placeholder in the design time
        of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_artifact(name=name, version=version)
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Return the model artifact linked to this model version.

    Args:
        name: The name of the model artifact to retrieve.
        version: The version of the model artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the model artifact or placeholder in the design
        time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_model_artifact(name=name, version=version)
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Return the data artifact linked to this model version.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the data artifact or placeholder in the design
        time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_data_artifact(name=name, version=version)
def get_deployment_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Return the deployment artifact linked to this model version.

    Args:
        name: The name of the deployment artifact to retrieve.
        version: The version of the deployment artifact to retrieve (None
            for latest/non-versioned)

    Returns:
        Specific version of the deployment artifact or placeholder in the
        design time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_deployment_artifact(name=name, version=version)
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Return a pipeline run linked to this model version.

    Args:
        name: The name of the pipeline run to retrieve.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    model_version = self._get_or_create_model_version()
    return model_version.get_pipeline_run(name=name)
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> None:
    """Move this model version to the desired stage.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in
            target stage or raise.
    """
    model_version = self._get_or_create_model_version()
    model_version.set_stage(stage=stage, force=force)
def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Log metadata for the current model version.

    Args:
        metadata: The metadata to log.
    """
    from zenml.client import Client

    model_version = self._get_or_create_model_version()
    zenml_client = Client()
    zenml_client.create_run_metadata(
        metadata=metadata,
        resource_id=model_version.id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )
@property
def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
    """Return the run metadata of this model version.

    When a pipeline context is available, a lazy getter is returned in
    place of the real metadata; otherwise the metadata is read from the
    hydrated model version.

    Returns:
        The model version run metadata.

    Raises:
        RuntimeError: If the model version run metadata cannot be fetched.
    """
    from zenml.metadata.lazy_load import RunMetadataLazyGetter
    from zenml.new.pipelines.pipeline_context import (
        get_pipeline_context,
    )

    try:
        # raises RuntimeError when there is no pipeline context
        get_pipeline_context()
        # avoid exposing too much of internal details by keeping the return type
        return RunMetadataLazyGetter(self, None, None)  # type: ignore[return-value]
    except RuntimeError:
        pass

    hydrated = self._get_or_create_model_version(hydrate=True)
    if hydrated.run_metadata is not None:
        return hydrated.run_metadata

    raise RuntimeError("Failed to fetch metadata of this model version.")
# TODO: deprecate me
@property
def metadata(self) -> Dict[str, "MetadataType"]:
    """DEPRECATED, use `run_metadata` instead.

    Logs a deprecation warning on every access and exposes only the
    `value` of each run metadata entry.

    Returns:
        The model version run metadata.
    """
    logger.warning(
        "Model `metadata` property is deprecated. Please use "
        "`run_metadata` instead."
    )
    plain_values = {}
    for key, entry in self.run_metadata.items():
        plain_values[key] = entry.value
    return plain_values
def delete_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    only_link: bool = True,
    delete_metadata: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete the artifact linked to this model version.

    Nothing happens when the lookup does not yield an
    `ArtifactVersionResponse`.

    Args:
        name: The name of the artifact to delete.
        version: The version of the artifact to delete (None for
            latest/non-versioned)
        only_link: Whether to only delete the link to the artifact.
        delete_metadata: Whether to delete the metadata of the artifact.
        delete_from_artifact_store: Whether to delete the artifact from the
            artifact store.
    """
    from zenml.client import Client
    from zenml.models import ArtifactVersionResponse

    linked = self.get_artifact(name, version)
    if not isinstance(linked, ArtifactVersionResponse):
        return

    zenml_client = Client()
    # the link is always removed first
    zenml_client.delete_model_version_artifact_link(
        model_version_id=self.id,
        artifact_version_id=linked.id,
    )
    if only_link:
        return

    zenml_client.delete_artifact_version(
        name_id_or_prefix=linked.id,
        delete_metadata=delete_metadata,
        delete_from_artifact_store=delete_from_artifact_store,
    )
def delete_all_artifacts(
    self,
    only_link: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete all artifacts linked to this model version.

    When `only_link` is False and `delete_from_artifact_store` is True,
    every data, model and deployment artifact version of the model version
    is first removed from the artifact store. The links are deleted
    afterwards in all cases.

    Args:
        only_link: Whether to only delete the link to the artifact.
        delete_from_artifact_store: Whether to delete the artifact from
            the artifact store.
    """
    from itertools import chain

    from zenml.client import Client

    client = Client()

    if not only_link and delete_from_artifact_store:
        mv = self._get_model_version()
        # Walk the three mappings one after another instead of merging
        # them into `mv.data_artifacts`: merging mutates the response's
        # dict and drops entries when the same key occurs in more than
        # one category.
        for versions_ in chain(
            mv.data_artifacts.values(),
            mv.model_artifacts.values(),
            mv.deployment_artifacts.values(),
        ):
            for artifact_response_ in versions_.values():
                client._delete_artifact_from_artifact_store(
                    artifact_version=artifact_response_
                )

    client.delete_all_model_version_artifact_links(self.id, only_link)
def _lazy_artifact_get(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Build a lazy artifact placeholder when a pipeline context exists.

    Args:
        name: The name of the artifact to reference.
        version: The version of the artifact to reference.

    Returns:
        A lazy artifact version response, or None when a `RuntimeError`
        is raised while looking up the pipeline context or building the
        placeholder.
    """
    from zenml import get_pipeline_context
    from zenml.models.v2.core.artifact_version import (
        LazyArtifactVersionResponse,
    )

    try:
        get_pipeline_context()
        model_reference = Model(
            name=self.name, version=self.version or self.number
        )
        return LazyArtifactVersionResponse(
            lazy_load_name=name,
            lazy_load_version=version,
            lazy_load_model=model_reference,
        )
    except RuntimeError:
        return None
def __eq__(self, other: object) -> bool:
    """Check two Models for equality.

    Models with different names are never equal. Models with the same
    name and the same configured version are equal. Otherwise both are
    resolved and their model version IDs are compared.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, Model):
        return NotImplemented
    if self.name != other.name:
        return False
    # names are known to match from here on
    if self.version == other.version:
        return True

    own_version = self._get_or_create_model_version()
    their_version = other._get_or_create_model_version()
    return own_version.id == their_version.id
@model_validator(mode="before")
@classmethod
@before_validator_handler
def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
    """Validate all in one.

    Logs how a stage-like or numeric `version` will be interpreted, unless
    the messages are suppressed, and marks them as suppressed afterwards.

    Args:
        data: Dict of values.

    Returns:
        Dict of validated values.
    """
    # messages are silenced explicitly or whenever a model version ID is given
    silenced = (
        data.get("suppress_class_validation_warnings", False)
        or data.get("model_version_id", None) is not None
    )
    version = data.get("version", None)

    if not silenced:
        stage_values = [stage.value for stage in ModelStages]
        if version in stage_values:
            logger.info(
                f"Version `{version}` matches one of the possible "
                "`ModelStages` and will be fetched using stage."
            )
        if str(version).isnumeric():
            logger.info(
                f"`version` `{version}` is numeric and will be fetched "
                "using version number."
            )

    data["suppress_class_validation_warnings"] = True
    return data
def _validate_config_in_runtime(self) -> "ModelVersionResponse":
    """Validate that config doesn't conflict with runtime environment.

    Returns:
        The model version based on configuration.
    """
    model_version = self._get_or_create_model_version()
    return model_version
def _get_or_create_model(self) -> "ModelResponse":
    """Get or create the model in the Model Control Plane.

    With a `model_version_id` the model is taken from that model version.
    Otherwise the model is fetched by name and created implicitly if it is
    missing. The resolved model ID is stored in `_model_id`.

    Returns:
        The model based on configuration.
    """
    from zenml.client import Client
    from zenml.models import ModelRequest

    client = Client()

    if self.model_version_id:
        pinned_version = client.get_model_version(
            model_version_name_or_number_or_id=self.model_version_id,
        )
        resolved = pinned_version.model
    else:
        try:
            resolved = client.zen_store.get_model(model_name_or_id=self.name)
        except KeyError:
            request = ModelRequest.model_validate(
                ModelRequest(
                    name=self.name,
                    license=self.license,
                    description=self.description,
                    audience=self.audience,
                    use_cases=self.use_cases,
                    limitations=self.limitations,
                    trade_offs=self.trade_offs,
                    ethics=self.ethics,
                    tags=self.tags,
                    user=client.active_user.id,
                    workspace=client.active_workspace.id,
                    save_models_to_registry=self.save_models_to_registry,
                )
            )
            try:
                resolved = client.zen_store.create_model(model=request)
                logger.info(
                    f"New model `{self.name}` was created implicitly."
                )
            except EntityExistsError:
                # the model already exists by now, so fetch it instead
                resolved = client.zen_store.get_model(
                    model_name_or_id=self.name
                )

    self._model_id = resolved.id
    return resolved
def _get_model_version(
self, hydrate: bool = True
) -> "ModelVersionResponse":
"""This method gets a model version from Model Control Plane.
The model version is looked up by `model_version_id` when it is set,
otherwise by model `name` and `version`. A warning is logged when the
configured description or tags differ from the fetched model version.
Args:
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Returns:
The model version based on configuration.
"""
from zenml.client import Client
zenml_client = Client()
# an explicit model version ID takes precedence over name/version lookup
if self.model_version_id:
mv = zenml_client.get_model_version(
model_version_name_or_number_or_id=self.model_version_id,
hydrate=hydrate,
)
else:
mv = zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version,
hydrate=hydrate,
)
self.model_version_id = mv.id
# collect mismatches between this configuration and the stored version
difference: Dict[str, Any] = {}
if mv.metadata:
if self.description and mv.description != self.description:
difference["description"] = {
"config": self.description,
"db": mv.description,
}
if self.tags:
configured_tags = set(self.tags)
db_tags = {t.name for t in mv.tags}
if db_tags != configured_tags:
difference["tags added"] = list(configured_tags - db_tags)
difference["tags removed"] = list(db_tags - configured_tags)
# mismatches are only reported, nothing is updated here
if difference:
logger.warning(
"Provided model version configuration does not match existing model "
f"version `{self.name}::{self.version}` with the following "
f"changes: {difference}. If you want to update the model version "
"configuration, please use the `zenml model version update` command."
)
return mv
def _get_or_create_model_version(
self, hydrate: bool = False
) -> "ModelVersionResponse":
"""This method should get or create a model and a model version from Model Control Plane.
A new model is created implicitly if missing, otherwise existing model
is fetched. Model name is controlled by the `name` parameter.
Model Version returned by this method is resolved based on model version:
- If `version` is None, a new model version is created, if not created
by other steps in same run.
- If `version` is not None a model version will be fetched based on the
version:
- If `version` is set to an integer or digit string, the model
version with the matching number will be fetched.
- If `version` is set to a string, the model version with the
matching version will be fetched.
- If `version` is set to a `ModelStage`, the model version with the
matching stage will be fetched.
Args:
hydrate: Whether to return a hydrated version of the model version.
NOTE(review): the visible code does not forward this flag to
`_get_model_version()`, which defaults to `hydrate=True` -
confirm whether that is intended.
Returns:
The model version based on configuration.
Raises:
RuntimeError: if the model version needs to be created, but
provided name is reserved.
RuntimeError: if the model version cannot be created.
"""
from zenml.client import Client
from zenml.models import ModelVersionRequest
model = self._get_or_create_model()
# backup logic, if the Model class is used directly from the code
if isinstance(self.version, str):
self.version = format_name_template(self.version)
zenml_client = Client()
# the request is built up front and only sent if the lookup below fails
model_version_request = ModelVersionRequest(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=str(self.version) if self.version else None,
description=self.description,
model=model.id,
tags=self.tags,
)
mv_request = ModelVersionRequest.model_validate(model_version_request)
try:
if not self.version:
try:
from zenml import get_step_context
context = get_step_context()
except RuntimeError:
pass
else:
# if inside a step context we loop over all
# model version configuration to find, if the
# model version for current model was already
# created in the current run, not to create
# new model versions
pipeline_mv = context.pipeline_run.config.model
if (
pipeline_mv
and pipeline_mv.was_created_in_this_run
and pipeline_mv.name == self.name
and pipeline_mv.version is not None
):
self.version = pipeline_mv.version
self.model_version_id = pipeline_mv.model_version_id
else:
for step in context.pipeline_run.steps.values():
step_mv = step.config.model
if (
step_mv
and step_mv.was_created_in_this_run
and step_mv.name == self.name
and step_mv.version is not None
):
self.version = step_mv.version
self.model_version_id = (
step_mv.model_version_id
)
break
if self.version or self.model_version_id:
model_version = self._get_model_version()
else:
# nothing to look up: fall through to the `except KeyError`
# handler, which holds the creation logic
raise KeyError
except KeyError:
# stage names and numeric names are reserved and cannot be created
if (
self.version
and str(self.version).lower() in ModelStages.values()
):
raise RuntimeError(
f"Cannot create a model version named {str(self.version)} as "
"it matches one of the possible model version stages. If you "
"are aiming to fetch model version by stage, check if the "
"model version in given stage exists. It might be missing, if "
"the pipeline promoting model version to this stage failed,"
" as an example. You can explore model versions using "
f"`zenml model version list -n {self.name}` CLI command."
)
if str(self.version).isnumeric():
raise RuntimeError(
f"Cannot create a model version named {str(self.version)} as "
"numeric model version names are reserved. If you "
"are aiming to fetch model version by number, check if the "
"model version with given number exists. It might be missing, if "
"the pipeline creating model version failed,"
" as an example. You can explore model versions using "
f"`zenml model version list -n {self.name}` CLI command."
)
# creation is retried when the store reports an already existing entity
retries_made = 0
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
try:
model_version = (
zenml_client.zen_store.create_model_version(
model_version=mv_request
)
)
break
except EntityExistsError as e:
# last attempt failed: give up and surface a RuntimeError
if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
raise RuntimeError(
f"Failed to create model version "
f"`{self.version if self.version else 'new'}` "
f"in model `{self.name}`. Retried {retries_made} times. "
"This could be driven by exceptionally high concurrency of "
"pipeline runs. Please, reach out to us on ZenML Slack for support."
) from e
# smoothed exponential back-off, it will go as 0.2, 0.3,
# 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
sleep = 0.2 * 1.5**i
logger.debug(
f"Failed to create new model version for "
f"model `{self.name}`. Retrying in {sleep}..."
)
time.sleep(sleep)
retries_made += 1
self.version = model_version.name
self.was_created_in_this_run = True
logger.info(f"New model version `{self.version}` was created.")
# cache the resolved identifiers on the instance
self.model_version_id = model_version.id
self._model_id = model_version.model.id
self._number = model_version.number
return model_version
def _merge(self, model: "Model") -> None:
    """Fill unset descriptive fields from another model and union the tags.

    Values already set on this object win; a field is only taken from
    `model` when the own value is falsy.

    Args:
        model: The model to take missing values and extra tags from.
    """
    for field_name in (
        "license",
        "description",
        "audience",
        "use_cases",
        "limitations",
        "trade_offs",
        "ethics",
    ):
        merged_value = getattr(self, field_name) or getattr(model, field_name)
        setattr(self, field_name, merged_value)

    if model.tags is not None:
        own_tags = set(self.tags or [])
        self.tags = list(own_tags.union(set(model.tags)))
def __hash__(self) -> int:
    """Get hash of the `Model`.

    The hash is derived from the string forms of `name` and `version`
    joined by `::`.

    Returns:
        Hash function results
    """
    identity = "::".join(map(str, (self.name, self.version)))
    return hash(identity)
def _prepare_model_version_before_step_launch(
self,
pipeline_run: "PipelineRunResponse",
step_run: Optional["StepRunResponse"],
return_logs: bool,
) -> str:
"""Prepares model version inside pipeline run.
Works on a copy of this object, resolves (or creates) the model version
when its ID is not known yet, records that ID on the step run or the
pipeline run and finally copies the ID back to this object.
Args:
pipeline_run: pipeline run
step_run: step run (passed only if model version is defined in a step explicitly)
return_logs: whether to return logs or not
Returns:
Logs related to the Dashboard URL to show later.
"""
from zenml.client import Client
from zenml.models import PipelineRunUpdate, StepRunUpdate
logs = ""
# copy Model instance to prevent corrupting configs of the
# subsequent runs, if they share the same config object
self_copy = self.model_copy()
# in case request is within the step and no self-configuration is provided
# try reuse what's in the pipeline run first
if step_run is None and pipeline_run.model_version is not None:
self_copy.version = pipeline_run.model_version.name
self_copy.model_version_id = pipeline_run.model_version.id
# otherwise try to fill the templated name, if needed
elif isinstance(self_copy.version, str):
# fall back to the current UTC time when the run has no start time
if pipeline_run.start_time:
start_time = pipeline_run.start_time
else:
start_time = datetime.datetime.now(datetime.timezone.utc)
self_copy.version = format_name_template(
self_copy.version,
date=start_time.strftime("%Y_%m_%d"),
time=start_time.strftime("%H_%M_%S_%f"),
)
# if exact model not yet defined - try to get/create and update it
# back to the run accordingly
if self_copy.model_version_id is None:
model_version_response = self_copy._get_or_create_model_version()
# update the configured model version id in runs accordingly
if step_run:
Client().zen_store.update_run_step(
step_run_id=step_run.id,
step_run_update=StepRunUpdate(
model_version_id=model_version_response.id
),
)
else:
Client().zen_store.update_run(
run_id=pipeline_run.id,
run_update=PipelineRunUpdate(
model_version_id=model_version_response.id
),
)
if return_logs:
from zenml.utils.cloud_utils import try_get_model_version_url
# use the model version URL when one is available, otherwise a
# generic pointer to ZenML Pro
if logs_to_show := try_get_model_version_url(
model_version_response
):
logs = logs_to_show
else:
logs = (
"Models can be viewed in the dashboard using ZenML Pro. Sign up "
"for a free trial at https://www.zenml.io/pro/"
)
# propagate only the resolved ID from the copy back to this object
self.model_version_id = self_copy.model_version_id
return logs
id: UUID
property
readonly
Get version id from the Model Control Plane.
Returns:
Type | Description |
---|---|
UUID |
ID of the model version. If the model version doesn't exist and cannot be fetched, a RuntimeError is raised instead (see Exceptions below). |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if model version doesn't exist and cannot be fetched from the Model Control Plane. |
metadata: Dict[str, MetadataType]
property
readonly
DEPRECATED, use `run_metadata` instead.
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The model version run metadata. |
model_id: UUID
property
readonly
Get model id from the Model Control Plane.
Returns:
Type | Description |
---|---|
UUID |
The UUID of the model containing this model version. |
number: int
property
readonly
Get version number from the Model Control Plane.
Returns:
Type | Description |
---|---|
int |
Number of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
run_metadata: Dict[str, RunMetadataResponse]
property
readonly
Get model version run metadata.
Returns:
Type | Description |
---|---|
Dict[str, RunMetadataResponse] |
The model version run metadata. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the model version run metadata cannot be fetched. |
stage: Optional[zenml.enums.ModelStages]
property
readonly
Get version stage from the Model Control Plane.
Returns:
Type | Description |
---|---|
Optional[zenml.enums.ModelStages] |
Stage of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
__eq__(self, other)
special
Check two Models for equality.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other |
object |
object to compare with |
required |
Returns:
Type | Description |
---|---|
bool |
True, if equal, False otherwise. |
Source code in zenml/model/model.py
def __eq__(self, other: object) -> bool:
    """Check two Models for equality.

    Models with different names are never equal. Models with the same
    name and the same configured version are equal. Otherwise both are
    resolved and their model version IDs are compared.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, Model):
        return NotImplemented
    if self.name != other.name:
        return False
    # names are known to match from here on
    if self.version == other.version:
        return True

    own_version = self._get_or_create_model_version()
    their_version = other._get_or_create_model_version()
    return own_version.id == their_version.id
__hash__(self)
special
Get hash of the `Model`.
Returns:
Type | Description |
---|---|
int |
Hash function results |
Source code in zenml/model/model.py
def __hash__(self) -> int:
    """Get hash of the `Model`.

    The hash is derived from the string forms of `name` and `version`
    joined by `::`.

    Returns:
        Hash function results
    """
    identity = "::".join(map(str, (self.name, self.version)))
    return hash(identity)
delete_all_artifacts(self, only_link=True, delete_from_artifact_store=False)
Delete all artifacts linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
only_link |
bool |
Whether to only delete the link to the artifact. |
True |
delete_from_artifact_store |
bool |
Whether to delete the artifact from the artifact store. |
False |
Source code in zenml/model/model.py
def delete_all_artifacts(
    self,
    only_link: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete all artifacts linked to this model version.

    When `only_link` is False and `delete_from_artifact_store` is True,
    every data, model and deployment artifact version of the model version
    is first removed from the artifact store. The links are deleted
    afterwards in all cases.

    Args:
        only_link: Whether to only delete the link to the artifact.
        delete_from_artifact_store: Whether to delete the artifact from
            the artifact store.
    """
    from itertools import chain

    from zenml.client import Client

    client = Client()

    if not only_link and delete_from_artifact_store:
        mv = self._get_model_version()
        # Walk the three mappings one after another instead of merging
        # them into `mv.data_artifacts`: merging mutates the response's
        # dict and drops entries when the same key occurs in more than
        # one category.
        for versions_ in chain(
            mv.data_artifacts.values(),
            mv.model_artifacts.values(),
            mv.deployment_artifacts.values(),
        ):
            for artifact_response_ in versions_.values():
                client._delete_artifact_from_artifact_store(
                    artifact_version=artifact_response_
                )

    client.delete_all_model_version_artifact_links(self.id, only_link)
delete_artifact(self, name, version=None, only_link=True, delete_metadata=True, delete_from_artifact_store=False)
Delete the artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the artifact to delete. |
required |
version |
Optional[str] |
The version of the artifact to delete (None for latest/non-versioned) |
None |
only_link |
bool |
Whether to only delete the link to the artifact. |
True |
delete_metadata |
bool |
Whether to delete the metadata of the artifact. |
True |
delete_from_artifact_store |
bool |
Whether to delete the artifact from the artifact store. |
False |
Source code in zenml/model/model.py
def delete_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    only_link: bool = True,
    delete_metadata: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete the artifact linked to this model version.

    Nothing happens when the lookup does not yield an
    `ArtifactVersionResponse`.

    Args:
        name: The name of the artifact to delete.
        version: The version of the artifact to delete (None for
            latest/non-versioned)
        only_link: Whether to only delete the link to the artifact.
        delete_metadata: Whether to delete the metadata of the artifact.
        delete_from_artifact_store: Whether to delete the artifact from the
            artifact store.
    """
    from zenml.client import Client
    from zenml.models import ArtifactVersionResponse

    linked = self.get_artifact(name, version)
    if not isinstance(linked, ArtifactVersionResponse):
        return

    zenml_client = Client()
    # the link is always removed first
    zenml_client.delete_model_version_artifact_link(
        model_version_id=self.id,
        artifact_version_id=linked.id,
    )
    if only_link:
        return

    zenml_client.delete_artifact_version(
        name_id_or_prefix=linked.id,
        delete_metadata=delete_metadata,
        delete_from_artifact_store=delete_from_artifact_store,
    )
get_artifact(self, name, version=None)
Get the artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the artifact to retrieve. |
required |
version |
Optional[str] |
The version of the artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactVersionResponse] |
Specific version of the artifact or placeholder in the design time of the pipeline. |
Source code in zenml/model/model.py
def get_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Fetch an artifact linked to this model version.

    Args:
        name: The name of the artifact to retrieve.
        version: The version of the artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the artifact or placeholder in the design time
        of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        # A truthy lazy result short-circuits the model version lookup.
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_artifact(name=name, version=version)
get_data_artifact(self, name, version=None)
Get the data artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the data artifact to retrieve. |
required |
version |
Optional[str] |
The version of the data artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactVersionResponse] |
Specific version of the data artifact or placeholder in the design time of the pipeline. |
Source code in zenml/model/model.py
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Fetch a data artifact linked to this model version.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the data artifact or placeholder in the design
        time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        # A truthy lazy result short-circuits the model version lookup.
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_data_artifact(name=name, version=version)
get_deployment_artifact(self, name, version=None)
Get the deployment artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the deployment artifact to retrieve. |
required |
version |
Optional[str] |
The version of the deployment artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactVersionResponse] |
Specific version of the deployment artifact or placeholder in the design time of the pipeline. |
Source code in zenml/model/model.py
def get_deployment_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Fetch a deployment artifact linked to this model version.

    Args:
        name: The name of the deployment artifact to retrieve.
        version: The version of the deployment artifact to retrieve (None
            for latest/non-versioned)

    Returns:
        Specific version of the deployment artifact or placeholder in the
        design time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        # A truthy lazy result short-circuits the model version lookup.
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_deployment_artifact(name=name, version=version)
get_model_artifact(self, name, version=None)
Get the model artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the model artifact to retrieve. |
required |
version |
Optional[str] |
The version of the model artifact to retrieve (None for latest/non-versioned) |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactVersionResponse] |
Specific version of the model artifact or placeholder in the design time of the pipeline. |
Source code in zenml/model/model.py
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Fetch a model artifact linked to this model version.

    Args:
        name: The name of the model artifact to retrieve.
        version: The version of the model artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the model artifact or placeholder in the design
        time of the pipeline.
    """
    placeholder = self._lazy_artifact_get(name, version)
    if placeholder:
        # A truthy lazy result short-circuits the model version lookup.
        return placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_model_artifact(name=name, version=version)
get_pipeline_run(self, name)
Get pipeline run linked to this version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the pipeline run to retrieve. |
required |
Returns:
Type | Description |
---|---|
PipelineRunResponse |
PipelineRun as PipelineRunResponse |
Source code in zenml/model/model.py
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Fetch a pipeline run linked to this model version.

    Args:
        name: The name of the pipeline run to retrieve.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    model_version = self._get_or_create_model_version()
    return model_version.get_pipeline_run(name=name)
load_artifact(self, name, version=None)
Load artifact from the Model Control Plane.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
Name of the artifact to load. |
required |
version |
Optional[str] |
Version of the artifact to load. |
None |
Returns:
Type | Description |
---|---|
Any |
The loaded artifact. |
Exceptions:
Type | Description |
---|---|
ValueError |
if the model version is not linked to any artifact with the given name and version. |
Source code in zenml/model/model.py
def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
    """Load artifact from the Model Control Plane.

    Args:
        name: Name of the artifact to load.
        version: Version of the artifact to load.

    Returns:
        The loaded artifact.

    Raises:
        ValueError: if the model version is not linked to any artifact with
            the given name and version.
    """
    from zenml.artifacts.utils import load_artifact
    from zenml.models import ArtifactVersionResponse

    linked = self.get_artifact(name=name, version=version)
    if isinstance(linked, ArtifactVersionResponse):
        return load_artifact(linked.id, str(linked.version))

    # Anything other than a concrete artifact version cannot be loaded.
    raise ValueError(
        f"Version {self.version} of model {self.name} does not have "
        f"an artifact with name {name} and version {version}."
    )
log_metadata(self, metadata)
Log model version metadata.
This function can be used to log metadata for current model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to log. |
required |
Source code in zenml/model/model.py
def log_metadata(self, metadata: Dict[str, "MetadataType"]) -> None:
    """Log model version metadata.

    This function can be used to log metadata for current model version.
    The metadata is registered as run metadata on the model version obtained
    from `_get_or_create_model_version`.

    Args:
        metadata: The metadata to log.
    """
    from zenml.client import Client

    model_version = self._get_or_create_model_version()
    zen_client = Client()
    zen_client.create_run_metadata(
        metadata=metadata,
        resource_id=model_version.id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )
model_post_init(self, context, /)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
self |
BaseModel |
The BaseModel instance. |
required |
context |
Any |
The context. |
required |
Source code in zenml/model/model.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """Initialise the private attributes of a BaseModel instance.

    This function is meant to behave like a BaseModel method. It takes
    context as an argument since that's what pydantic-core passes when
    calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is not None:
        # Private attributes were already set up; leave them untouched.
        return

    # Collect only the private attributes that define a default value.
    defaults = {
        attr_name: value
        for attr_name, attr in self.__private_attributes__.items()
        if (value := attr.get_default()) is not PydanticUndefined
    }
    object_setattr(self, '__pydantic_private__', defaults)
set_stage(self, stage, force=False)
Sets this Model to a desired stage.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage |
Union[str, zenml.enums.ModelStages] |
the target stage for model version. |
required |
force |
bool |
whether to force archiving of current model version in target stage or raise. |
False |
Source code in zenml/model/model.py
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> None:
    """Move this Model to a desired stage.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in
            target stage or raise.
    """
    model_version = self._get_or_create_model_version()
    model_version.set_stage(stage=stage, force=force)
model_version
DEPRECATED, use `from zenml import Model` instead.
ModelVersion (Model)
DEPRECATED, use `from zenml import Model` instead.
Source code in zenml/model/model_version.py
class ModelVersion(Model):
    """DEPRECATED, use `from zenml import Model` instead."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Warn about the deprecation, then initialise like `Model`.

        DEPRECATED, use `from zenml import Model` instead.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        deprecation_notice = (
            "`ModelVersion` is deprecated. Please use `Model` instead."
        )
        logger.warning(deprecation_notice)
        super().__init__(*args, **kwargs)
__init__(self, *args, **kwargs)
special
DEPRECATED, use `from zenml import Model` instead.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Variable length argument list. |
() |
**kwargs |
Any |
Arbitrary keyword arguments. |
{} |
Source code in zenml/model/model_version.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Warn about the deprecation, then initialise like the parent class.

    DEPRECATED, use `from zenml import Model` instead.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    deprecation_notice = (
        "`ModelVersion` is deprecated. Please use `Model` instead."
    )
    logger.warning(deprecation_notice)
    super().__init__(*args, **kwargs)
model_post_init(self, context, /)
We need to both initialize private attributes and call the user-defined model_post_init method.
Source code in zenml/model/model_version.py
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """Initialise private attributes, then run the user-defined hook.

    We need to both initialize private attributes and call the user-defined
    model_post_init method.

    Args:
        self: The BaseModel instance.
        context: The context, forwarded unchanged to both calls.
    """
    # Private attributes are set up first, before the user-defined hook runs.
    init_private_attributes(self, context)
    original_model_post_init(self, context)
utils
Utility functions for linking step outputs to model versions.
link_artifact_config_to_model(artifact_config, artifact_version_id, model=None)
Link an artifact config to its model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_config |
ArtifactConfig |
The artifact config to link. |
required |
artifact_version_id |
UUID |
The ID of the artifact to link. |
required |
model |
Optional[Model] |
The model version from the step or pipeline context. |
None |
Source code in zenml/model/utils.py
def link_artifact_config_to_model(
    artifact_config: ArtifactConfig,
    artifact_version_id: UUID,
    model: Optional["Model"] = None,
) -> None:
    """Link an artifact config to its model version.

    A model named in the artifact config takes precedence over the `model`
    argument. When no model can be determined, nothing is linked.

    Args:
        artifact_config: The artifact config to link.
        artifact_version_id: The ID of the artifact to link.
        model: The model version from the step or pipeline context.
    """
    client = Client()

    # If the artifact config specifies a model itself then always use that
    if artifact_config.model_name is not None:
        from zenml.model.model import Model

        model = Model(
            name=artifact_config.model_name,
            version=artifact_config.model_version,
        )

    if not model:
        return

    logger.debug(
        f"Linking artifact `{artifact_config.name}` to model "
        f"`{model.name}` version `{model.version}` using config "
        f"`{artifact_config}`."
    )
    link_request = ModelVersionArtifactRequest(
        user=client.active_user.id,
        workspace=client.active_workspace.id,
        artifact_version=artifact_version_id,
        model=model.model_id,
        model_version=model.id,
        is_model_artifact=artifact_config.is_model_artifact,
        is_deployment_artifact=artifact_config.is_deployment_artifact,
    )
    client.zen_store.create_model_version_artifact_link(link_request)
link_artifact_to_model(artifact_version_id, model=None, is_model_artifact=False, is_deployment_artifact=False)
Link the artifact to the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version_id |
UUID |
The ID of the artifact version. |
required |
model |
Optional[Model] |
The model to link to. |
None |
is_model_artifact |
bool |
Whether the artifact is a model artifact. |
False |
is_deployment_artifact |
bool |
Whether the artifact is a deployment artifact. |
False |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If called outside of a step. |
Source code in zenml/model/utils.py
def link_artifact_to_model(
    artifact_version_id: UUID,
    model: Optional["Model"] = None,
    is_model_artifact: bool = False,
    is_deployment_artifact: bool = False,
) -> None:
    """Link the artifact to the model.

    When `model` is not given, the model of the current step context is
    used instead.

    Args:
        artifact_version_id: The ID of the artifact version.
        model: The model to link to.
        is_model_artifact: Whether the artifact is a model artifact.
        is_deployment_artifact: Whether the artifact is a deployment artifact.

    Raises:
        RuntimeError: If called outside of a step.
    """
    if not model:
        try:
            model = get_step_context().model
        except StepContextError:
            # Treated the same as a context that has no model configured.
            model = None
        if model is None:
            raise RuntimeError(
                "`link_artifact_to_model` called without `model` parameter "
                "and configured model context cannot be identified. Consider "
                "passing the `model` explicitly or configuring it in "
                "@step or @pipeline decorator."
            )

    config = ArtifactConfig(
        is_model_artifact=is_model_artifact,
        is_deployment_artifact=is_deployment_artifact,
    )
    link_artifact_config_to_model(
        artifact_config=config,
        artifact_version_id=artifact_version_id,
        model=model,
    )
link_service_to_model(service_id, model=None, model_version_id=None)
Links a service to a model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_id |
UUID |
The ID of the service to link to the model. |
required |
model |
Optional[Model] |
The model to link the service to. |
None |
model_version_id |
Optional[uuid.UUID] |
The ID of the model version to link the service to. |
None |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If no model is provided and the model context cannot be identified. |
Source code in zenml/model/utils.py
def link_service_to_model(
    service_id: UUID,
    model: Optional["Model"] = None,
    model_version_id: Optional[UUID] = None,
) -> None:
    """Links a service to a model.

    An explicitly passed `model_version_id` always wins. Without it, the ID
    is taken from `model`, and if neither is given, from the model configured
    in the current step context.

    Args:
        service_id: The ID of the service to link to the model.
        model: The model to link the service to.
        model_version_id: The ID of the model version to link the service to.

    Raises:
        RuntimeError: If no model is provided and the model context cannot be
            identified.
    """
    client = Client()

    # If no model is provided, try to get it from the context
    if not model and not model_version_id:
        try:
            model = get_step_context().model
        except StepContextError:
            # Treated the same as a context that has no model configured.
            model = None
        if model is None:
            raise RuntimeError(
                "`link_service_to_model` called without `model` parameter "
                "and configured model context cannot be identified. Consider "
                "passing the `model` explicitly or configuring it in "
                "@step or @pipeline decorator."
            )

    # Written as an explicit `if` instead of
    # `model_version_id or <model id> if model else None`: a conditional
    # expression binds looser than `or`, so that form would discard an
    # explicitly passed ID whenever `model` is omitted.
    if not model_version_id and model:
        model_version_id = model._get_or_create_model_version().id

    update_service = ServiceUpdate(model_version_id=model_version_id)
    client.zen_store.update_service(
        service_id=service_id, update=update_service
    )
link_step_artifacts_to_model(artifact_version_ids)
Links the output artifacts of a step to the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_version_ids |
Dict[str, uuid.UUID] |
The IDs of the published output artifacts. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If called outside of a step. |
Source code in zenml/model/utils.py
def link_step_artifacts_to_model(
    artifact_version_ids: Dict[str, UUID],
) -> None:
    """Links the output artifacts of a step to the model.

    Outputs that carry their own artifact config are linked using it.
    Outputs without one get a default config named after the output, but
    only when the step context has a model.

    Args:
        artifact_version_ids: The IDs of the published output artifacts.

    Raises:
        RuntimeError: If called outside of a step.
    """
    try:
        context = get_step_context()
    except StepContextError:
        raise RuntimeError(
            "`link_step_artifacts_to_model` can only be called from within a "
            "step."
        )

    try:
        model = context.model
    except StepContextError:
        model = None
        logger.debug("No model context found, unable to auto-link artifacts.")

    for output_name, version_id in artifact_version_ids.items():
        config = context._get_output(output_name).artifact_config
        if config is None and model is not None:
            config = ArtifactConfig(name=output_name)

        if not config:
            # No explicit config and no model to fall back on: skip.
            continue

        link_artifact_config_to_model(
            artifact_config=config,
            artifact_version_id=version_id,
            model=model,
        )
log_model_metadata(metadata, model_name=None, model_version=None)
Log model version metadata.
This function can be used to log metadata for existing model versions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to log. |
required |
model_name |
Optional[str] |
The name of the model to log metadata for. Can be omitted when being called inside a step with configured `model` in decorator. |
None |
model_version |
Union[zenml.enums.ModelStages, str, int] |
The version of the model to log metadata for. Can be omitted when being called inside a step with configured `model` in decorator. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If no model name/version is provided and the function is not called inside a step with configured `model` in decorator. |
Source code in zenml/model/utils.py
def log_model_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.
    The step context model is used unless both `model_name` and
    `model_version` are given.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.

    Raises:
        ValueError: If no model name/version is provided and the function is not
            called inside a step with configured `model` in decorator.
    """
    if not (model_name and model_version):
        # Incomplete explicit target: fall back to the step's model.
        try:
            context = get_step_context()
        except RuntimeError:
            raise ValueError(
                "Model name and version must be provided unless the function is "
                "called inside a step with configured `model` in decorator."
            )
        target = context.model
    else:
        from zenml import Model

        target = Model(name=model_name, version=model_version)

    target.log_metadata(metadata)
log_model_version_metadata(metadata, model_name=None, model_version=None)
Log model version metadata.
This function can be used to log metadata for existing model versions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to log. |
required |
model_name |
Optional[str] |
The name of the model to log metadata for. Can be omitted when being called inside a step with configured `model` in decorator. |
None |
model_version |
Union[zenml.enums.ModelStages, str, int] |
The version of the model to log metadata for. Can be omitted when being called inside a step with configured `model` in decorator. |
None |
Source code in zenml/model/utils.py
def log_model_version_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata (deprecated alias).

    Emits a deprecation warning and forwards all arguments to
    `log_model_metadata`.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.
    """
    deprecation_notice = (
        "`log_model_version_metadata` is deprecated. Please use "
        "`log_model_metadata` instead."
    )
    logger.warning(deprecation_notice)
    log_model_metadata(
        metadata=metadata,
        model_name=model_name,
        model_version=model_version,
    )