Model
zenml.model
special
Initialization of ZenML model. ZenML models support the Model Control Plane feature.
artifact_config
Artifact Config classes to support Model Control Plane feature.
DataArtifactConfig (BaseModel)
pydantic-model
Used to link a data artifact to the model version.
model_name: The name of the model to link the data artifact to. model_version: The identifier of the model version to link the data artifact to. It can be an exact version ("23"), an exact version number (42), a stage (ModelStages.PRODUCTION) or ModelStages.LATEST for the latest version. model_stage: The stage of the model version to link the artifact to. artifact_name: The override name of the link instead of the artifact name. overwrite: Whether to overwrite an existing link or create new versions.
Source code in zenml/model/artifact_config.py
class DataArtifactConfig(BaseModel):
    """Used to link a data artifact to the model version.

    Attributes:
        model_name: The name of the model to link the data artifact to.
        model_version: The identifier of the model version to link the data
            artifact to. It can be an exact version ("23"), an exact version
            number (42), a stage (ModelStages.PRODUCTION) or
            ModelStages.LATEST for the latest version. Must be provided
            whenever `model_name` is provided.
        artifact_name: The override name of the link instead of the
            artifact name.
        overwrite: Whether to overwrite an existing link or create new
            versions.
    """

    model_name: Optional[str]
    model_version: Optional[Union[ModelStages, str, int]]
    artifact_name: Optional[str]
    overwrite: bool = False

    # Pipeline/step that produced the artifact; used both to filter existing
    # links and to tag the new link. No default is declared, so they are
    # presumably assigned by the caller before linking (not visible here).
    _pipeline_name: str = PrivateAttr()
    _step_name: str = PrivateAttr()

    # Subclasses flip these to mark the link type (see `link_to_model`).
    IS_MODEL_ARTIFACT: ClassVar[bool] = False
    IS_ENDPOINT_ARTIFACT: ClassVar[bool] = False

    @root_validator
    def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject a `model_name` given without a `model_version`.

        Args:
            values: The field values of the model being validated.

        Returns:
            The unchanged field values.

        Raises:
            ValueError: If `model_name` is set but `model_version` is None,
                which would otherwise imply creating a new model version.
        """
        model_name = values.get("model_name", None)
        if model_name and values.get("model_version", None) is None:
            raise ValueError(
                f"Creation of new model version from `{cls}` is not allowed. "
                "Please either keep `model_name` and `model_version` both "
                "`None` to get the model version from the step context or "
                "specify both at the same time. You can use `ModelStages.LATEST` "
                "as `model_version` when latest model version is desired."
            )
        return values

    class Config:
        """Pydantic config class for DataArtifactConfig."""

        smart_union = True

    @property
    def _model_version(self) -> "ModelVersion":
        """Property that returns the model version.

        The step context's model version is used unless `model_name` is set
        and differs from it; in that case a new `ModelVersion` is built from
        `model_name`/`model_version`.

        Returns:
            ModelVersion: The model version.

        Raises:
            RuntimeError: If model version cannot be acquired from @step
                or @pipeline or built on the fly from fields of this class.
        """
        try:
            model_version = get_step_context().model_version
        except (StepContextError, RuntimeError):
            # Outside of a step (or no model configured): fall back below.
            model_version = None

        # NOTE(review): a `model_version` given without `model_name` is
        # ignored here and the context model version wins — confirm intended.
        if (self.model_name is not None) and (
            model_version is None or model_version.name != self.model_name
        ):
            # Build a ModelVersion from this config's own name/version.
            from zenml.model.model_version import ModelVersion

            on_the_fly_config = ModelVersion(
                name=self.model_name,
                version=self.model_version,
            )
            return on_the_fly_config

        if model_version is None:
            raise RuntimeError(
                "No model version configuration found in @step or @pipeline. "
                "You can configure model version inside ArtifactConfig as well, but "
                "`model_name` and `model_version` must be provided."
            )
        # Return the model version from the context.
        return model_version

    def _link_to_model_version(
        self,
        artifact_uuid: UUID,
        model_version: "ModelVersion",
        is_model_artifact: bool = False,
        is_endpoint_artifact: bool = False,
    ) -> None:
        """Link artifact to the model version.

        This method is used on exit from the step context to link artifact to
        the model version. If links with the same name already exist for the
        same pipeline/step and link type, they are deleted first when
        `overwrite` is set; otherwise a new link version is added.

        Args:
            artifact_uuid: The UUID of the artifact to link.
            model_version: The model version from caller.
            is_model_artifact: Whether the artifact is a model artifact.
                Defaults to False.
            is_endpoint_artifact: Whether the artifact is an endpoint
                artifact. Defaults to False.
        """
        from zenml.client import Client
        from zenml.models.model_models import (
            ModelVersionArtifactFilterModel,
            ModelVersionArtifactRequestModel,
        )

        client = Client()

        # Default the link name to the artifact's own name.
        artifact_name = self.artifact_name
        if artifact_name is None:
            artifact = client.zen_store.get_artifact(artifact_id=artifact_uuid)
            artifact_name = artifact.name

        # Request model for the model version <-> artifact link.
        request = ModelVersionArtifactRequestModel(
            user=client.active_user.id,
            workspace=client.active_workspace.id,
            name=artifact_name,
            artifact=artifact_uuid,
            model=model_version.model_id,
            model_version=model_version.id,
            is_model_artifact=is_model_artifact,
            is_endpoint_artifact=is_endpoint_artifact,
            overwrite=self.overwrite,
            pipeline_name=self._pipeline_name,
            step_name=self._step_name,
        )

        # Look up links with the same name, link type, pipeline and step.
        existing_links = client.list_model_version_artifact_links(
            model_version_id=model_version.id,
            model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(
                user_id=client.active_user.id,
                workspace_id=client.active_workspace.id,
                name=artifact_name,
                only_data_artifacts=not (
                    is_model_artifact or is_endpoint_artifact
                ),
                only_endpoint_artifacts=is_endpoint_artifact,
                only_model_artifacts=is_model_artifact,
                pipeline_name=self._pipeline_name,
                step_name=self._step_name,
            ),
        )
        if len(existing_links):
            if self.overwrite:
                # Delete all model version artifact links with this name.
                logger.warning(
                    f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
                )
                client.zen_store.delete_model_version_artifact_link(
                    model_version_id=model_version.id,
                    model_version_artifact_link_name_or_id=artifact_name,
                )
            else:
                logger.info(
                    f"Artifact link `{artifact_name}` already exists, adding new version."
                )
        client.zen_store.create_model_version_artifact_link(request)

    def link_to_model(
        self, artifact_uuid: UUID, model_version: "ModelVersion"
    ) -> None:
        """Link artifact to the model version.

        The link type is taken from the class-level `IS_MODEL_ARTIFACT` and
        `IS_ENDPOINT_ARTIFACT` flags.

        Args:
            artifact_uuid: The UUID of the artifact to link.
            model_version: The model version from caller.
        """
        self._link_to_model_version(
            artifact_uuid,
            model_version=model_version,
            is_model_artifact=self.IS_MODEL_ARTIFACT,
            is_endpoint_artifact=self.IS_ENDPOINT_ARTIFACT,
        )
Config
Config class for ArtifactConfig.
Source code in zenml/model/artifact_config.py
class Config:
    """Pydantic configuration used by the artifact config model."""

    # Turn on pydantic's smart-union handling for Union-typed fields.
    smart_union = True
link_to_model(self, artifact_uuid, model_version)
Link artifact to the model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_uuid |
UUID |
The UUID of the artifact to link. |
required |
model_version |
ModelVersion |
The model version from caller. |
required |
Source code in zenml/model/artifact_config.py
def link_to_model(
    self, artifact_uuid: UUID, model_version: "ModelVersion"
) -> None:
    """Create the link between an artifact and a model version.

    The link flavour (model / endpoint / plain data) comes from the class
    flags `IS_MODEL_ARTIFACT` and `IS_ENDPOINT_ARTIFACT`.

    Args:
        artifact_uuid: The UUID of the artifact to link.
        model_version: The model version from caller.
    """
    link_kind = {
        "is_model_artifact": self.IS_MODEL_ARTIFACT,
        "is_endpoint_artifact": self.IS_ENDPOINT_ARTIFACT,
    }
    self._link_to_model_version(
        artifact_uuid, model_version=model_version, **link_kind
    )
EndpointArtifactConfig (DataArtifactConfig)
pydantic-model
Used to link an endpoint artifact to the model version.
Source code in zenml/model/artifact_config.py
class EndpointArtifactConfig(DataArtifactConfig):
    """Link an endpoint artifact to the model version."""

    # Marks links created through this config as endpoint artifacts.
    IS_ENDPOINT_ARTIFACT = True
ModelArtifactConfig (DataArtifactConfig)
pydantic-model
Used to link a model artifact to the model version.
save_to_model_registry: Whether to save the model artifact to the model registry.
Source code in zenml/model/artifact_config.py
class ModelArtifactConfig(DataArtifactConfig):
    """Link a model artifact to the model version.

    Attributes:
        save_to_model_registry: Whether to save the model artifact to the
            model registry.
    """

    save_to_model_registry: bool = True

    # Marks links created through this config as model artifacts.
    IS_MODEL_ARTIFACT = True
link_output_to_model
Utility functions for linking step outputs to model versions.
link_output_to_model(artifact_config, output_name=None)
Link a step output to a model version using the given artifact config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
The output name of the artifact to link. Can be omitted if there is only one output artifact. |
None |
artifact_config |
DataArtifactConfig |
The ArtifactConfig of how to link this output. |
required |
Source code in zenml/model/link_output_to_model.py
def link_output_to_model(
    artifact_config: "DataArtifactConfig",
    output_name: Optional[str] = None,
) -> None:
    """Link a step output to a model version.

    Stores `artifact_config` on the active step context for `output_name`.
    The link itself is presumably created later, when the output artifact
    is saved (not visible here).

    Args:
        artifact_config: The ArtifactConfig describing how to link this
            output.
        output_name: The output name of the artifact to link. Can be
            omitted if there is only one output artifact.
    """
    # Local import, as in the original; presumably avoids an import cycle.
    from zenml.new.steps.step_context import get_step_context

    step_context = get_step_context()
    step_context._set_artifact_config(
        output_name=output_name, artifact_config=artifact_config
    )
model_version
ModelVersion user facing interface to pass into pipeline or step.
ModelVersion (BaseModel)
pydantic-model
ModelVersion class to pass into pipeline or step to set it into a model context.
name: The name of the model. license: The license under which the model is created. description: The description of the model. audience: The target audience of the model. use_cases: The use cases of the model. limitations: The known limitations of the model. trade_offs: The tradeoffs of the model. ethics: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number or stage. It is optional and points the model context to a specific version/stage; if skipped, a new model version will be created. save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if available in the active stack.
Source code in zenml/model/model_version.py
class ModelVersion(BaseModel):
    """ModelVersion class to pass into pipeline or step to set it into a model context.

    Attributes:
        name: The name of the model.
        license: The license under which the model is created.
        description: The description of the model.
        audience: The target audience of the model.
        use_cases: The use cases of the model.
        limitations: The known limitations of the model.
        trade_offs: The tradeoffs of the model.
        ethics: The ethical implications of the model.
        tags: Tags associated with the model.
        version: The model version name, number or stage. Optional; points
            the model context to a specific version/stage. If skipped, a new
            model version will be created.
        save_models_to_registry: Whether to save all ModelArtifacts to the
            Model Registry, if available in the active stack.
        suppress_class_validation_warnings: Whether to skip the informational
            logs emitted by the class validator. The validator sets it to
            True after its first run, so re-validation stays quiet.
        was_created_in_this_run: Set to True once this object has created a
            new model version. Other steps of the same run use it to reuse
            that version instead of creating another one.
    """

    name: str
    license: Optional[str]
    description: Optional[str]
    audience: Optional[str]
    use_cases: Optional[str]
    limitations: Optional[str]
    trade_offs: Optional[str]
    ethics: Optional[str]
    tags: Optional[List[str]]
    version: Optional[Union[ModelStages, int, str]]
    save_models_to_registry: bool = True
    suppress_class_validation_warnings: bool = False
    was_created_in_this_run: bool = False

    # Lazily populated caches of server-side identifiers.
    _model_id: UUID = PrivateAttr(None)
    _id: UUID = PrivateAttr(None)
    _number: int = PrivateAttr(None)

    #########################
    #    Public methods     #
    #########################
    @property
    def id(self) -> UUID:
        """Get version id from the Model Control Plane.

        Returns:
            ID of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).
        """
        if self._id is None:
            try:
                self._get_or_create_model_version()
            except RuntimeError:
                logger.info(
                    f"Model version `{self.version}` doesn't exist "
                    "and cannot be fetched from the Model Control Plane."
                )
        return self._id

    @property
    def model_id(self) -> UUID:
        """Get model id from the Model Control Plane.

        Returns:
            The UUID of the model containing this model version.
        """
        if self._model_id is None:
            self._get_or_create_model()
        return self._model_id

    @property
    def number(self) -> int:
        """Get version number from the Model Control Plane.

        Returns:
            Number of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).
        """
        if self._number is None:
            try:
                self._get_or_create_model_version()
            except RuntimeError:
                logger.info(
                    f"Model version `{self.version}` doesn't exist "
                    "and cannot be fetched from the Model Control Plane."
                )
        return self._number

    @property
    def stage(self) -> Optional[ModelStages]:
        """Get version stage from the Model Control Plane.

        Returns:
            Stage of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).
        """
        try:
            stage = self._get_or_create_model_version().stage
            if stage:
                return ModelStages(stage)
        except RuntimeError:
            logger.info(
                f"Model version `{self.version}` doesn't exist "
                "and cannot be fetched from the Model Control Plane."
            )
        return None

    def get_model_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        pipeline_name: Optional[str] = None,
        step_name: Optional[str] = None,
    ) -> Optional["ArtifactResponse"]:
        """Get the model artifact linked to this model version.

        Args:
            name: The name of the model artifact to retrieve.
            version: The version of the model artifact to retrieve (None for
                latest/non-versioned).
            pipeline_name: The name of the pipeline that generated the model
                artifact.
            step_name: The name of the step that generated the model artifact.

        Returns:
            Specific version of the model artifact or None
        """
        return self._get_or_create_model_version().get_model_artifact(
            name=name,
            version=version,
            pipeline_name=pipeline_name,
            step_name=step_name,
        )

    def get_data_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        pipeline_name: Optional[str] = None,
        step_name: Optional[str] = None,
    ) -> Optional["ArtifactResponse"]:
        """Get the data artifact linked to this model version.

        Args:
            name: The name of the data artifact to retrieve.
            version: The version of the data artifact to retrieve (None for
                latest/non-versioned).
            pipeline_name: The name of the pipeline that generated the data
                artifact.
            step_name: The name of the step that generated the data artifact.

        Returns:
            Specific version of the data artifact or None
        """
        return self._get_or_create_model_version().get_data_artifact(
            name=name,
            version=version,
            pipeline_name=pipeline_name,
            step_name=step_name,
        )

    def get_endpoint_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        pipeline_name: Optional[str] = None,
        step_name: Optional[str] = None,
    ) -> Optional["ArtifactResponse"]:
        """Get the endpoint artifact linked to this model version.

        Args:
            name: The name of the endpoint artifact to retrieve.
            version: The version of the endpoint artifact to retrieve (None
                for latest/non-versioned).
            pipeline_name: The name of the pipeline that generated the
                endpoint artifact.
            step_name: The name of the step that generated the endpoint
                artifact.

        Returns:
            Specific version of the endpoint artifact or None
        """
        return self._get_or_create_model_version().get_endpoint_artifact(
            name=name,
            version=version,
            pipeline_name=pipeline_name,
            step_name=step_name,
        )

    def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
        """Get pipeline run linked to this version.

        Args:
            name: The name of the pipeline run to retrieve.

        Returns:
            PipelineRun as PipelineRunResponse
        """
        return self._get_or_create_model_version().get_pipeline_run(name=name)

    def set_stage(
        self, stage: Union[str, ModelStages], force: bool = False
    ) -> "ModelVersion":
        """Sets this Model Version to a desired stage.

        Args:
            stage: the target stage for model version.
            force: whether to force archiving of current model version in
                target stage or raise.

        Returns:
            Updated Model Version object.
        """
        return self._get_or_create_model_version().set_stage(
            stage=stage, force=force
        )

    #########################
    #   Internal methods    #
    #########################

    class Config:
        """Config class."""

        smart_union = True

    def __eq__(self, other: object) -> bool:
        """Check two ModelVersions for equality.

        Two objects are equal when they share the model name and either the
        configured `version` values are equal, or both resolve to the same
        model version id on the server.

        Args:
            other: object to compare with

        Returns:
            True, if equal, False otherwise.
        """
        if not isinstance(other, ModelVersion):
            return NotImplemented
        if self.name != other.name:
            return False
        if self.version == other.version:
            return True
        # Same model, versions specified differently (e.g. number vs. stage):
        # resolve both against the server and compare ids.
        self_mv = self._get_or_create_model_version()
        other_mv = other._get_or_create_model_version()
        return self_mv.id == other_mv.id

    @root_validator(pre=True)
    def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        Logs how a stage-like or numeric `version` will be resolved, then
        sets `suppress_class_validation_warnings` so repeated validation
        does not log again.

        Args:
            values: Dict of values.

        Returns:
            Dict of validated values.
        """
        suppress_class_validation_warnings = values.get(
            "suppress_class_validation_warnings", False
        )
        version = values.get("version", None)

        if (
            version in [stage.value for stage in ModelStages]
            and not suppress_class_validation_warnings
        ):
            logger.info(
                f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage."
            )
        if str(version).isnumeric() and not suppress_class_validation_warnings:
            logger.info(
                f"`version` `{version}` is numeric and will be fetched using version number."
            )
        values["suppress_class_validation_warnings"] = True
        return values

    def _validate_config_in_runtime(self) -> None:
        """Validate that config doesn't conflict with runtime environment."""
        self._get_or_create_model_version()

    def _get_or_create_model(self) -> "ModelResponseModel":
        """This method should get or create a model from Model Control Plane.

        New model is created implicitly, if missing, otherwise fetched.

        Returns:
            The model based on configuration.
        """
        from zenml.client import Client
        from zenml.models.model_models import ModelRequestModel

        zenml_client = Client()
        try:
            model = zenml_client.zen_store.get_model(
                model_name_or_id=self.name
            )
        except KeyError:
            model_request = ModelRequestModel(
                name=self.name,
                license=self.license,
                description=self.description,
                audience=self.audience,
                use_cases=self.use_cases,
                limitations=self.limitations,
                trade_offs=self.trade_offs,
                ethics=self.ethics,
                tags=self.tags,
                user=zenml_client.active_user.id,
                workspace=zenml_client.active_workspace.id,
            )
            model_request = ModelRequestModel.parse_obj(model_request)
            try:
                model = zenml_client.zen_store.create_model(
                    model=model_request
                )
                logger.info(f"New model `{self.name}` was created implicitly.")
            except EntityExistsError:
                # Backup logic: the model was created by someone else between
                # the `get_model` and `create_model` calls, so fetch it.
                # (Re-fetching only here, not in a `finally`, keeps any other
                # creation error from being masked by a `KeyError`.)
                model = zenml_client.zen_store.get_model(
                    model_name_or_id=self.name
                )

        self._model_id = model.id
        return model

    def _get_model_version(self) -> "ModelVersionResponseModel":
        """This method gets a model version from Model Control Plane.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client

        zenml_client = Client()
        mv = zenml_client._get_model_version(
            model_name_or_id=self.name,
            model_version_name_or_number_or_id=self.version,
        )
        if not self._id:
            self._id = mv.id
        return mv

    def _get_or_create_model_version(self) -> "ModelVersionResponseModel":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model is fetched. Model
        name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model version:
        - If `version` is None, a new model version is created, if not created by other steps in same run.
        - If `version` is not None a model version will be fetched based on the version:
            - If `version` is set to an integer or digit string, the model version with the matching number will be fetched.
            - If `version` is set to a string, the model version with the matching version will be fetched.
            - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

        Returns:
            The model version based on configuration.

        Raises:
            RuntimeError: if the model version needs to be created, but provided name is reserved
        """
        from zenml.client import Client
        from zenml.models.model_models import ModelVersionRequestModel

        model = self._get_or_create_model()

        zenml_client = Client()
        try:
            if not self.version:
                try:
                    from zenml import get_step_context

                    context = get_step_context()
                except RuntimeError:
                    # Not running inside a step: nothing to reuse.
                    pass
                else:
                    # Inside a step context: look through the pipeline and
                    # step model version configurations to find out whether
                    # a version of this model was already created in the
                    # current run, so that no new version is created.
                    pipeline_mv = context.pipeline_run.config.model_version
                    if (
                        pipeline_mv
                        and pipeline_mv.was_created_in_this_run
                        and pipeline_mv.name == self.name
                        and pipeline_mv.version is not None
                    ):
                        self.version = pipeline_mv.version
                    else:
                        for step in context.pipeline_run.steps.values():
                            step_mv = step.config.model_version
                            if (
                                step_mv
                                and step_mv.was_created_in_this_run
                                and step_mv.name == self.name
                                and step_mv.version is not None
                            ):
                                self.version = step_mv.version
                                break
            if self.version:
                model_version = self._get_model_version()
            else:
                # No version to fetch: fall through to creation below.
                raise KeyError
        except KeyError:
            if (
                self.version
                and str(self.version).lower() in ModelStages.values()
            ):
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "it matches one of the possible model version stages. If you "
                    "are aiming to fetch model version by stage, check if the "
                    "model version in given stage exists. It might be missing, if "
                    "the pipeline promoting model version to this stage failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list {self.name}` CLI command."
                )
            if str(self.version).isnumeric():
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "numeric model version names are reserved. If you "
                    "are aiming to fetch model version by number, check if the "
                    "model version with given number exists. It might be missing, if "
                    "the pipeline creating model version failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list {self.name}` CLI command."
                )
            # Build the request only now, so its name reflects any version
            # resolved from the run context above rather than a stale value.
            model_version_request = ModelVersionRequestModel(
                user=zenml_client.active_user.id,
                workspace=zenml_client.active_workspace.id,
                name=self.version,
                description=self.description,
                model=model.id,
            )
            mv_request = ModelVersionRequestModel.parse_obj(
                model_version_request
            )
            model_version = zenml_client.zen_store.create_model_version(
                model_version=mv_request
            )
            self.version = model_version.name
            self.was_created_in_this_run = True
            logger.info(f"New model version `{self.version}` was created.")

        # Cache the server-side identifiers for the id/number/model_id properties.
        self._id = model_version.id
        self._model_id = model_version.model.id
        self._number = model_version.number
        return model_version

    def _merge(self, model_version: "ModelVersion") -> None:
        """Fill unset model metadata from another ModelVersion.

        Each text field keeps its own value when truthy and otherwise takes
        the value from `model_version`. Tags are unioned (as a set, so
        duplicates are dropped and order is not preserved).

        Args:
            model_version: The ModelVersion to take missing values from.
        """
        self.license = self.license or model_version.license
        self.description = self.description or model_version.description
        self.audience = self.audience or model_version.audience
        self.use_cases = self.use_cases or model_version.use_cases
        self.limitations = self.limitations or model_version.limitations
        self.trade_offs = self.trade_offs or model_version.trade_offs
        self.ethics = self.ethics or model_version.ethics
        if model_version.tags is not None:
            self.tags = list(set(self.tags or []).union(model_version.tags))

    def __hash__(self) -> int:
        """Get hash of the `ModelVersion`.

        Only `name` and `version` contribute to the hash.

        Returns:
            Hash function results
        """
        return hash(
            "::".join(
                (
                    str(v)
                    for v in (
                        self.name,
                        self.version,
                    )
                )
            )
        )
id: UUID
property
readonly
Get version id from the Model Control Plane.
Returns:
Type | Description |
---|---|
UUID |
ID of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
model_id: UUID
property
readonly
Get model id from the Model Control Plane.
Returns:
Type | Description |
---|---|
UUID |
The UUID of the model containing this model version. |
number: int
property
readonly
Get version number from the Model Control Plane.
Returns:
Type | Description |
---|---|
int |
Number of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
stage: Optional[zenml.enums.ModelStages]
property
readonly
Get version stage from the Model Control Plane.
Returns:
Type | Description |
---|---|
Optional[zenml.enums.ModelStages] |
Stage of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name). |
Config
Config class.
Source code in zenml/model/model_version.py
class Config:
    """Pydantic configuration for ModelVersion."""

    # Turn on pydantic's smart-union handling for Union-typed fields.
    smart_union = True
__eq__(self, other)
special
Check two ModelVersions for equality.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other |
object |
object to compare with |
required |
Returns:
Type | Description |
---|---|
bool |
True, if equal, False otherwise. |
Source code in zenml/model/model_version.py
def __eq__(self, other: object) -> bool:
    """Compare two ModelVersions.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, ModelVersion):
        return NotImplemented
    # Different models can never be equal.
    if self.name != other.name:
        return False
    # Names match here, so identical version specs mean equality.
    if self.version == other.version:
        return True
    # Otherwise resolve both against the server and compare ids.
    mine = self._get_or_create_model_version()
    theirs = other._get_or_create_model_version()
    return mine.id == theirs.id
__hash__(self)
special
Get hash of the `ModelVersion`.
Returns:
Type | Description |
---|---|
int |
Hash function results |
Source code in zenml/model/model_version.py
def __hash__(self) -> int:
    """Hash a ModelVersion by its model name and version.

    Returns:
        Hash of the string "<name>::<version>".
    """
    parts = [str(self.name), str(self.version)]
    return hash("::".join(parts))
get_data_artifact(self, name, version=None, pipeline_name=None, step_name=None)
Get the data artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the data artifact to retrieve. |
required |
version |
Optional[str] |
The version of the data artifact to retrieve (None for latest/non-versioned) |
None |
pipeline_name |
Optional[str] |
The name of the pipeline that generated the data artifact. |
None |
step_name |
Optional[str] |
The name of the step that generated the data artifact. |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactResponse] |
Specific version of the data artifact or None |
Source code in zenml/model/model_version.py
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    pipeline_name: Optional[str] = None,
    step_name: Optional[str] = None,
) -> Optional["ArtifactResponse"]:
    """Fetch a data artifact linked to this model version.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for
            latest/non-versioned).
        pipeline_name: The name of the pipeline that generated the data
            artifact.
        step_name: The name of the step that generated the data artifact.

    Returns:
        Specific version of the data artifact or None
    """
    resolved = self._get_or_create_model_version()
    lookup = dict(
        name=name,
        version=version,
        pipeline_name=pipeline_name,
        step_name=step_name,
    )
    return resolved.get_data_artifact(**lookup)
get_endpoint_artifact(self, name, version=None, pipeline_name=None, step_name=None)
Get the endpoint artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the endpoint artifact to retrieve. |
required |
version |
Optional[str] |
The version of the endpoint artifact to retrieve (None for latest/non-versioned) |
None |
pipeline_name |
Optional[str] |
The name of the pipeline that generated the endpoint artifact. |
None |
step_name |
Optional[str] |
The name of the step that generated the endpoint artifact. |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactResponse] |
Specific version of the endpoint artifact or None |
Source code in zenml/model/model_version.py
def get_endpoint_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    pipeline_name: Optional[str] = None,
    step_name: Optional[str] = None,
) -> Optional["ArtifactResponse"]:
    """Fetch an endpoint artifact linked to this model version.

    Args:
        name: The name of the endpoint artifact to retrieve.
        version: The version of the endpoint artifact to retrieve (None for
            latest/non-versioned).
        pipeline_name: The name of the pipeline that generated the endpoint
            artifact.
        step_name: The name of the step that generated the endpoint artifact.

    Returns:
        Specific version of the endpoint artifact or None
    """
    resolved = self._get_or_create_model_version()
    lookup = dict(
        name=name,
        version=version,
        pipeline_name=pipeline_name,
        step_name=step_name,
    )
    return resolved.get_endpoint_artifact(**lookup)
get_model_artifact(self, name, version=None, pipeline_name=None, step_name=None)
Get the model artifact linked to this model version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the model artifact to retrieve. |
required |
version |
Optional[str] |
The version of the model artifact to retrieve (None for latest/non-versioned) |
None |
pipeline_name |
Optional[str] |
The name of the pipeline that generated the model artifact. |
None |
step_name |
Optional[str] |
The name of the step that generated the model artifact. |
None |
Returns:
Type | Description |
---|---|
Optional[ArtifactResponse] |
Specific version of the model artifact or None |
Source code in zenml/model/model_version.py
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    pipeline_name: Optional[str] = None,
    step_name: Optional[str] = None,
) -> Optional["ArtifactResponse"]:
    """Fetch a model artifact linked to this model version.

    Args:
        name: The name of the model artifact to retrieve.
        version: The version of the model artifact to retrieve (None for
            latest/non-versioned).
        pipeline_name: The name of the pipeline that generated the model
            artifact.
        step_name: The name of the step that generated the model artifact.

    Returns:
        Specific version of the model artifact or None
    """
    resolved = self._get_or_create_model_version()
    lookup = dict(
        name=name,
        version=version,
        pipeline_name=pipeline_name,
        step_name=step_name,
    )
    return resolved.get_model_artifact(**lookup)
get_pipeline_run(self, name)
Get pipeline run linked to this version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the pipeline run to retrieve. |
required |
Returns:
Type | Description |
---|---|
PipelineRunResponse |
PipelineRun as PipelineRunResponse |
Source code in zenml/model/model_version.py
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Fetch a pipeline run linked to this model version.

    Args:
        name: The name of the pipeline run to retrieve.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    resolved = self._get_or_create_model_version()
    return resolved.get_pipeline_run(name=name)
set_stage(self, stage, force=False)
Sets this Model Version to a desired stage.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage |
Union[str, zenml.enums.ModelStages] |
the target stage for model version. |
required |
force |
bool |
whether to force archiving of current model version in target stage or raise. |
False |
Returns:
Type | Description |
---|---|
ModelVersion |
Updated Model Version object. |
Source code in zenml/model/model_version.py
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> "ModelVersion":
    """Move this model version to the requested stage.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in
            target stage or raise.

    Returns:
        Updated Model Version object.
    """
    resolved = self._get_or_create_model_version()
    return resolved.set_stage(stage=stage, force=force)