Model
zenml.model
special
Initialization of the ZenML model package. ZenML models support the Model Control Plane feature.
artifact_config
Artifact Config classes to support Model Control Plane feature.
ArtifactConfig (BaseModel)
pydantic-model
Used to link a generic Artifact to the model version.
- model_name: The name of the model to link the artifact to.
- model_version: The identifier of the model version to link the artifact to. It can be an exact version name ("23"), an exact version number (42), a stage (ModelStages.PRODUCTION), or None for the latest version.
- model_stage: The stage of the model version to link the artifact to (note: the class declares no separate `model_stage` field; a stage is passed via `model_version`).
- artifact_name: An override name for the link, used instead of the artifact name.
- overwrite: Whether to overwrite an existing link or create new versions.
Source code in zenml/model/artifact_config.py
class ArtifactConfig(BaseModel):
    """Used to link a generic Artifact to the model version.

    model_name: The name of the model to link artifact to.
    model_version: The identifier of the model version to link artifact to.
        It can be exact version ("23"), exact version number (42), stage
        (ModelStages.PRODUCTION) or None for the latest version.
    artifact_name: The override name of a link instead of an artifact name.
    overwrite: Whether to overwrite an existing link or create new versions.
    """

    model_name: Optional[str]
    model_version: Optional[Union[ModelStages, str, int]]
    artifact_name: Optional[str]
    overwrite: bool = False

    # Private attributes without defaults: they must be assigned before
    # `_link_to_model_version` runs, because it reads both of them.
    _pipeline_name: str = PrivateAttr()
    _step_name: str = PrivateAttr()

    # Overridden by subclasses to select the kind of link that
    # `link_to_model` creates.
    IS_MODEL_ARTIFACT: ClassVar[bool] = False
    IS_DEPLOYMENT_ARTIFACT: ClassVar[bool] = False

    class Config:
        """Config class for ArtifactConfig."""

        # Let pydantic pick the Union member matching the input's type
        # (e.g. keep an int version number as int) instead of coercing to
        # the first member that validates.
        smart_union = True

    @property
    def _model_config(self) -> "ModelConfig":
        """Property that returns the model configuration.

        The model configuration of the active step context is used, unless
        `model_name` is set and differs from the context's model name (or no
        context configuration exists). In that case a new `ModelConfig` is
        built from `model_name` and `model_version` on every access, with
        `create_new_model_version=False`.

        Returns:
            ModelConfig: The model configuration.

        Raises:
            RuntimeError: If model configuration cannot be acquired from @step
                or @pipeline or built on the fly from fields of this class.
        """
        try:
            model_config = get_step_context().model_config
        except StepContextError:
            model_config = None

        # Check if a specific model name is provided and it doesn't match the
        # context name.
        if (self.model_name is not None) and (
            model_config is None or model_config.name != self.model_name
        ):
            # Create a new ModelConfig instance with the provided model name
            # and version.
            from zenml.model.model_config import ModelConfig

            on_the_fly_config = ModelConfig(
                name=self.model_name,
                version=self.model_version,
                create_new_model_version=False,
                suppress_warnings=True,
            )
            return on_the_fly_config

        if model_config is None:
            raise RuntimeError(
                "No model configuration found in @step or @pipeline. "
                "You can configure ModelConfig inside ArtifactConfig as well, but "
                "`model_name` and `model_version` must be provided."
            )
        # Return the model from the context.
        return model_config

    @property
    def _model(self) -> "ModelResponseModel":
        """Get the `ModelResponseModel`.

        Resolved through `_model_config` on every access.

        Returns:
            ModelResponseModel: The fetched or created model.
        """
        return self._model_config.get_or_create_model()

    @property
    def _model_version(self) -> "ModelVersionResponseModel":
        """Get the `ModelVersionResponseModel`.

        Resolved through `_model_config` on every access.

        Returns:
            ModelVersionResponseModel: The model version.
        """
        return self._model_config.get_or_create_model_version()

    def _link_to_model_version(
        self,
        artifact_uuid: UUID,
        is_model_object: bool = False,
        is_deployment: bool = False,
    ) -> None:
        """Link artifact to the model version.

        This method is used on exit from the step context to link artifact to
        the model version. Existing links with the same name and kind are
        deleted first when `overwrite` is True; otherwise the new link is
        added next to them.

        Args:
            artifact_uuid: The UUID of the artifact to link.
            is_model_object: Whether the artifact is a model object. Defaults
                to False.
            is_deployment: Whether the artifact is a deployment. Defaults to
                False.
        """
        from zenml.client import Client

        # Create a ZenML client
        client = Client()

        artifact_name = self.artifact_name
        if artifact_name is None:
            artifact = client.zen_store.get_artifact(artifact_id=artifact_uuid)
            artifact_name = artifact.name

        # Resolve every identifier once. `self._model` / `self._model_version`
        # re-resolve the model configuration on each access (possibly
        # building a fresh, uncached on-the-fly `ModelConfig`), so reading
        # them repeatedly would repeat the same lookups.
        user_id = client.active_user.id
        workspace_id = client.active_workspace.id
        model_id = self._model.id
        model_version_id = self._model_version.id

        # Create a request model for the model version artifact link
        request = ModelVersionArtifactRequestModel(
            user=user_id,
            workspace=workspace_id,
            name=artifact_name,
            artifact=artifact_uuid,
            model=model_id,
            model_version=model_version_id,
            is_model_object=is_model_object,
            is_deployment=is_deployment,
            overwrite=self.overwrite,
            pipeline_name=self._pipeline_name,
            step_name=self._step_name,
        )

        # Look up existing links with the same name and the same link kind.
        existing_links = client.list_model_version_artifact_links(
            ModelVersionArtifactFilterModel(
                user_id=user_id,
                workspace_id=workspace_id,
                name=artifact_name,
                model_id=model_id,
                model_version_id=model_version_id,
                only_artifacts=not (is_model_object or is_deployment),
                only_deployments=is_deployment,
                only_model_objects=is_model_object,
            )
        )
        if len(existing_links):
            if self.overwrite:
                # delete all model version artifact links by name
                logger.warning(
                    f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
                )
                client.zen_store.delete_model_version_artifact_link(
                    model_name_or_id=model_id,
                    model_version_name_or_id=model_version_id,
                    model_version_artifact_link_name_or_id=artifact_name,
                )
            else:
                logger.info(
                    f"Artifact link `{artifact_name}` already exists, adding new version."
                )

        # Create the model version artifact link using the ZenML client
        client.zen_store.create_model_version_artifact_link(request)

    def link_to_model(
        self,
        artifact_uuid: UUID,
    ) -> None:
        """Link artifact to the model version.

        The link kind is taken from the class flags `IS_MODEL_ARTIFACT` and
        `IS_DEPLOYMENT_ARTIFACT`.

        Args:
            artifact_uuid: The UUID of the artifact to link.
        """
        self._link_to_model_version(
            artifact_uuid,
            is_model_object=self.IS_MODEL_ARTIFACT,
            is_deployment=self.IS_DEPLOYMENT_ARTIFACT,
        )
Config
Config class for ArtifactConfig.
Source code in zenml/model/artifact_config.py
class Config:
    """Pydantic settings applied to ArtifactConfig."""

    # Match Union members by the input's type before falling back to coercion.
    smart_union = True
link_to_model(self, artifact_uuid)
Link artifact to the model version.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| artifact_uuid | UUID | The UUID of the artifact to link. | required |
Source code in zenml/model/artifact_config.py
def link_to_model(
    self,
    artifact_uuid: UUID,
) -> None:
    """Attach an artifact to the model version.

    Whether the link marks a model object or a deployment is decided by the
    class-level flags `IS_MODEL_ARTIFACT` and `IS_DEPLOYMENT_ARTIFACT`.

    Args:
        artifact_uuid: The UUID of the artifact to link.
    """
    as_model_object = self.IS_MODEL_ARTIFACT
    as_deployment = self.IS_DEPLOYMENT_ARTIFACT
    self._link_to_model_version(
        artifact_uuid,
        is_model_object=as_model_object,
        is_deployment=as_deployment,
    )
DeploymentArtifactConfig (ArtifactConfig)
pydantic-model
Used to link a Deployment to the model version.
Source code in zenml/model/artifact_config.py
class DeploymentArtifactConfig(ArtifactConfig):
    """Link configuration that attaches a Deployment to the model version."""

    # Class-level flag read by `link_to_model` to create a deployment link.
    IS_DEPLOYMENT_ARTIFACT = True
ModelArtifactConfig (ArtifactConfig)
pydantic-model
Used to link a Model Object to the model version.
save_to_model_registry: Whether to save the model object to the model registry.
Source code in zenml/model/artifact_config.py
class ModelArtifactConfig(ArtifactConfig):
    """Link configuration that attaches a Model Object to the model version.

    save_to_model_registry: Whether to save the model object to the model
        registry.
    """

    # Class-level flag read by `link_to_model` to create a model-object link.
    IS_MODEL_ARTIFACT = True

    save_to_model_registry: bool = True
link_output_to_model
Utility functions for linking step outputs to model versions.
link_output_to_model(artifact_config, output_name=None)
Link an output artifact of the current step to a model version.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| output_name | Optional[str] | The name of the step output to link. Can be omitted if the step has only one output artifact. | None |
| artifact_config | ArtifactConfig | The ArtifactConfig describing how to link this output. | required |
Source code in zenml/model/link_output_to_model.py
def link_output_to_model(
    artifact_config: "ArtifactConfig",
    output_name: Optional[str] = None,
) -> None:
    """Link an output artifact of the current step to a model version.

    This only registers `artifact_config` for the given output on the active
    step context (via `_set_artifact_config`); it does not create the link
    itself.

    NOTE(review): `get_step_context()` presumably raises when called outside
    of a running step; confirm and document under `Raises` if so.

    Args:
        artifact_config: The ArtifactConfig describing how to link this
            output.
        output_name: The name of the step output to link. Can be omitted if
            the step has only one output artifact.
    """
    from zenml.new.steps.step_context import get_step_context

    step_context = get_step_context()
    step_context._set_artifact_config(
        output_name=output_name, artifact_config=artifact_config
    )
model_config
User-facing ModelConfig interface to pass into a pipeline or step.
ModelConfig (ModelConfigModel)
pydantic-model
ModelConfig class to pass into pipeline or step to set it into a model context.
- name: The name of the model.
- version: The model version name, number, or stage. It is optional and points the model context to a specific version/stage; if it is skipped and `create_new_model_version` is False, the latest model version will be used.
- version_description: The description of the model version.
- create_new_model_version: Whether to create a new model version during execution.
- save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if one is available in the active stack.
- delete_new_version_on_failure: Whether to delete a newly created model version when the run fails. If False, the leftover version can be reused by a later run to recover from the failure.
Source code in zenml/model/model_config.py
class ModelConfig(ModelConfigModel):
    """ModelConfig class to pass into pipeline or step to set it into a model context.

    name: The name of the model.
    version: The model version name, number or stage. It is optional and
        points the model context to a specific version/stage; if skipped and
        `create_new_model_version` is False, the latest model version will be
        used.
    version_description: The description of the model version.
    create_new_model_version: Whether to create a new model version during
        execution.
    save_models_to_registry: Whether to save all ModelArtifacts to Model
        Registry, if available in active stack.
    delete_new_version_on_failure: Whether to delete a newly created model
        version when the run fails (if False, the leftover version can be
        reused later to recover from the failure).
    """

    # Caches for the resolved entities: once set, the getters below return
    # them without asking the client again.
    _model: Optional["ModelResponseModel"] = PrivateAttr(default=None)
    _model_version: Optional["ModelVersionResponseModel"] = PrivateAttr(
        default=None
    )

    def get_or_create_model(self) -> "ModelResponseModel":
        """This method should get or create a model from Model Control Plane.

        New model is created implicitly, if missing, otherwise fetched.

        Returns:
            The model based on configuration.
        """
        if self._model is not None:
            return self._model

        from zenml.client import Client
        from zenml.models.model_models import ModelRequestModel

        zenml_client = Client()
        try:
            self._model = zenml_client.get_model(model_name_or_id=self.name)
        except KeyError:
            model_request = ModelRequestModel(
                name=self.name,
                license=self.license,
                description=self.description,
                audience=self.audience,
                use_cases=self.use_cases,
                limitations=self.limitations,
                trade_offs=self.trade_offs,
                ethic=self.ethic,
                tags=self.tags,
                user=zenml_client.active_user.id,
                workspace=zenml_client.active_workspace.id,
            )
            model_request = ModelRequestModel.parse_obj(model_request)
            try:
                self._model = zenml_client.create_model(model=model_request)
                logger.info(f"New model `{self.name}` was created implicitly.")
            except EntityExistsError:
                # this is backup logic, if model was created somehow in
                # between get and create calls
                self._model = zenml_client.get_model(
                    model_name_or_id=self.name
                )

        return self._model

    def _create_model_version(
        self, model: "ModelResponseModel"
    ) -> "ModelVersionResponseModel":
        """This method creates a model version for Model Control Plane.

        If a model version matching `self.version` already exists (e.g. a
        leftover from a previous failed run), it is reused instead.

        Args:
            model: The model containing the model version.

        Returns:
            The model version based on configuration.
        """
        if self._model_version is not None:
            return self._model_version

        from zenml.client import Client
        from zenml.models.model_models import ModelVersionRequestModel

        zenml_client = Client()
        # NOTE(review): unlike `_get_model_version`, `self.version` is passed
        # here even when it is None; confirm how `get_model_version` treats a
        # None identifier in this path.
        try:
            self._model_version = zenml_client.get_model_version(
                model_name_or_id=self.name,
                model_version_name_or_number_or_id=self.version,
            )
        except KeyError:
            # Build the request only when a new version really has to be
            # created, so the active user/workspace are not resolved when an
            # existing version is reused.
            model_version_request = ModelVersionRequestModel(
                user=zenml_client.active_user.id,
                workspace=zenml_client.active_workspace.id,
                name=self.version,
                description=self.version_description,
                model=model.id,
            )
            mv_request = ModelVersionRequestModel.parse_obj(
                model_version_request
            )
            self._model_version = zenml_client.create_model_version(
                model_version=mv_request
            )
            logger.info(f"New model version `{self.version}` was created.")
        return self._model_version

    def _get_model_version(self) -> "ModelVersionResponseModel":
        """This method gets a model version from Model Control Plane.

        Returns:
            The model version based on configuration.
        """
        if self._model_version is not None:
            return self._model_version

        from zenml.client import Client

        zenml_client = Client()
        if self.version is None:
            # No version given: fetch the model's latest version
            # (raise if not found).
            self._model_version = zenml_client.get_model_version(
                model_name_or_id=self.name
            )
        else:
            # by version name or stage or number
            # raise if not found
            self._model_version = zenml_client.get_model_version(
                model_name_or_id=self.name,
                model_version_name_or_number_or_id=self.version,
            )
        return self._model_version

    def get_or_create_model_version(self) -> "ModelVersionResponseModel":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model
        is fetched. Model name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model
        configuration:
        - If there is an existing model version leftover from the previous
          failed run with `delete_new_version_on_failure` is set to False and
          `create_new_model_version` is True, leftover model version will be
          reused.
        - Otherwise if `create_new_model_version` is True, a new model version
          is created.
        - If `create_new_model_version` is False a model version will be
          fetched based on the version:
            - If `version` is not set, the latest model version will be
              fetched.
            - If `version` is set to a string, the model version with the
              matching version will be fetched.
            - If `version` is set to a `ModelStage`, the model version with
              the matching stage will be fetched.

        Returns:
            The model version based on configuration.
        """
        model = self.get_or_create_model()
        if self.create_new_model_version:
            mv = self._create_model_version(model)
        else:
            mv = self._get_model_version()
        return mv

    def _merge_with_config(self, model_config: ModelConfigModel) -> None:
        """Fill in this config from another model config, in place.

        Descriptive fields keep their own value when truthy and otherwise take
        the value from `model_config`. Tags from `model_config` (if any) are
        appended to this config's tags. `delete_new_version_on_failure` is
        combined with `&=`, so it stays set only if it is set in both configs.

        Args:
            model_config: The model config to merge values from.
        """
        self.license = self.license or model_config.license
        self.description = self.description or model_config.description
        self.audience = self.audience or model_config.audience
        self.use_cases = self.use_cases or model_config.use_cases
        self.limitations = self.limitations or model_config.limitations
        self.trade_offs = self.trade_offs or model_config.trade_offs
        self.ethic = self.ethic or model_config.ethic
        if model_config.tags is not None:
            self.tags = (self.tags or []) + model_config.tags

        self.delete_new_version_on_failure &= (
            model_config.delete_new_version_on_failure
        )
get_or_create_model(self)
This method should get or create a model from Model Control Plane.
New model is created implicitly, if missing, otherwise fetched.
Returns:
| Type | Description |
|---|---|
| ModelResponseModel | The model based on configuration. |
Source code in zenml/model/model_config.py
def get_or_create_model(self) -> "ModelResponseModel":
    """Fetch the configured model from the Model Control Plane, creating it if absent.

    The resolved model is cached on `self._model`, and later calls return the
    cached value directly.

    Returns:
        The model based on configuration.
    """
    cached_model = self._model
    if cached_model is not None:
        return cached_model

    from zenml.client import Client
    from zenml.models.model_models import ModelRequestModel

    client = Client()
    try:
        self._model = client.get_model(model_name_or_id=self.name)
    except KeyError:
        draft = ModelRequestModel(
            name=self.name,
            license=self.license,
            description=self.description,
            audience=self.audience,
            use_cases=self.use_cases,
            limitations=self.limitations,
            trade_offs=self.trade_offs,
            ethic=self.ethic,
            tags=self.tags,
            user=client.active_user.id,
            workspace=client.active_workspace.id,
        )
        creation_request = ModelRequestModel.parse_obj(draft)
        try:
            self._model = client.create_model(model=creation_request)
        except EntityExistsError:
            # Someone else created the model between our lookup and our
            # create call; fall back to fetching it.
            self._model = client.get_model(model_name_or_id=self.name)
        else:
            logger.info(f"New model `{self.name}` was created implicitly.")
    return self._model
get_or_create_model_version(self)
This method should get or create a model and a model version from Model Control Plane.
A new model is created implicitly if missing; otherwise the existing model is fetched. The model
name is controlled by the `name` parameter.
The model version returned by this method is resolved based on the model configuration:
- If there is a model version left over from a previous failed run with
`delete_new_version_on_failure` set to False and `create_new_model_version` set to True,
the leftover model version will be reused.
- Otherwise, if `create_new_model_version` is True, a new model version is created.
- If `create_new_model_version` is False, a model version will be fetched based on `version`:
  - If `version` is not set, the latest model version will be fetched.
  - If `version` is set to a string, the model version with the matching version will be fetched.
  - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.
Returns:
| Type | Description |
|---|---|
| ModelVersionResponseModel | The model version based on configuration. |
Source code in zenml/model/model_config.py
def get_or_create_model_version(self) -> "ModelVersionResponseModel":
    """Resolve the model version described by this configuration.

    The model itself is always resolved first (created implicitly if it does
    not exist yet; its name comes from `name`). The version is then chosen as
    follows:
    - `create_new_model_version` is True: a version is created. A leftover
      version from a previous failed run (kept because
      `delete_new_version_on_failure` was False) is reused instead.
    - `create_new_model_version` is False: an existing version is fetched:
        - the latest one when `version` is not set,
        - the one whose version matches when `version` is a string,
        - the one in that stage when `version` is a `ModelStage`.

    Returns:
        The model version based on configuration.
    """
    # Resolve the model even when only an existing version will be fetched.
    resolved_model = self.get_or_create_model()
    return (
        self._create_model_version(resolved_model)
        if self.create_new_model_version
        else self._get_model_version()
    )