Skip to content

Model

zenml.model special

Initialization of the ZenML model module. ZenML models support the Model Control Plane feature.

artifact_config

Artifact Config classes to support Model Control Plane feature.

ArtifactConfig (BaseModel) pydantic-model

Used to link a generic Artifact to the model version.

Attributes:

- model_name: The name of the model to link the artifact to.
- model_version: The identifier of the model version to link the artifact to. It can be an exact version ("23"), an exact version number (42), a stage (ModelStages.PRODUCTION) or None for the latest version.
- model_stage: The stage of the model version to link the artifact to.
- artifact_name: The override name of a link, used instead of the artifact name.
- overwrite: Whether to overwrite an existing link or create new versions.

Source code in zenml/model/artifact_config.py
class ArtifactConfig(BaseModel):
    """Used to link a generic Artifact to the model version.

    Attributes:
        model_name: The name of the model to link artifact to.
        model_version: The identifier of the model version to link artifact
            to. It can be exact version ("23"), exact version number (42),
            stage (ModelStages.PRODUCTION) or None for the latest version.
        artifact_name: The override name of a link instead of an artifact
            name.
        overwrite: Whether to overwrite an existing link or create new
            versions.
    """

    model_name: Optional[str]
    model_version: Optional[Union[ModelStages, str, int]]
    artifact_name: Optional[str]
    overwrite: bool = False

    # Names of the pipeline and step producing the artifact. They have no
    # default here and are read when the link request is built.
    _pipeline_name: str = PrivateAttr()
    _step_name: str = PrivateAttr()
    # Subclasses flip these flags to mark the link as a model object or a
    # deployment; `link_to_model` forwards them to the link request.
    IS_MODEL_ARTIFACT: ClassVar[bool] = False
    IS_DEPLOYMENT_ARTIFACT: ClassVar[bool] = False

    class Config:
        """Config class for ArtifactConfig."""

        smart_union = True

    @property
    def _model_config(self) -> "ModelConfig":
        """Property that returns the model configuration.

        The configuration of the active step context is used unless
        `model_name` is set and differs from it (or there is no context
        configuration at all); in that case a new `ModelConfig` is built from
        `model_name` and `model_version` on every access.

        Returns:
            ModelConfig: The model configuration.

        Raises:
            RuntimeError: If model configuration cannot be acquired from @step
                or @pipeline or built on the fly from fields of this class.
        """
        try:
            model_config = get_step_context().model_config
        except (StepContextError, RuntimeError):
            # No usable step context: fall back to the fields of this class.
            model_config = None

        # An explicitly requested model that is not the context model gets its
        # own configuration pointing at an already existing model version.
        if (self.model_name is not None) and (
            model_config is None or model_config.name != self.model_name
        ):
            from zenml.model.model_config import ModelConfig

            on_the_fly_config = ModelConfig(
                name=self.model_name,
                version=self.model_version,
                create_new_model_version=False,
                suppress_warnings=True,
            )
            return on_the_fly_config

        if model_config is None:
            raise RuntimeError(
                "No model configuration found in @step or @pipeline. "
                "You can configure ModelConfig inside ArtifactConfig as well, but "
                "`model_name` and `model_version` must be provided."
            )
        # Return the model configuration from the context
        return model_config

    @property
    def _model(self) -> "ModelResponseModel":
        """Get the `ModelResponseModel`.

        Returns:
            ModelResponseModel: The fetched or created model.
        """
        return self._model_config.get_or_create_model()

    @property
    def _model_version(self) -> "ModelVersionResponseModel":
        """Get the `ModelVersionResponseModel`.

        Returns:
            ModelVersionResponseModel: The model version.
        """
        return self._model_config.get_or_create_model_version()

    def _link_to_model_version(
        self,
        artifact_uuid: UUID,
        is_model_object: bool = False,
        is_deployment: bool = False,
    ) -> None:
        """Link artifact to the model version.

        This method is used on exit from the step context to link artifact to
        the model version.

        Args:
            artifact_uuid: The UUID of the artifact to link.
            is_model_object: Whether the artifact is a model object. Defaults
                to False.
            is_deployment: Whether the artifact is a deployment. Defaults to
                False.

        Raises:
            RuntimeError: Propagated from `_model_config` when no model
                configuration is available.
        """
        from zenml.client import Client

        client = Client()

        # Fall back to the stored artifact's own name when no override is set.
        artifact_name = self.artifact_name
        if artifact_name is None:
            artifact = client.zen_store.get_artifact(artifact_id=artifact_uuid)
            artifact_name = artifact.name

        # Resolve the model and model version exactly once. Every access to
        # `self._model` / `self._model_version` re-evaluates `_model_config`,
        # which may build a fresh `ModelConfig` and repeat the lookups, so the
        # ids are cached locally and reused for the request, the filter and
        # the optional delete below.
        model_config = self._model_config
        model_id = model_config.get_or_create_model().id
        model_version_id = model_config.get_or_create_model_version().id

        # Request describing the new model version artifact link
        request = ModelVersionArtifactRequestModel(
            user=client.active_user.id,
            workspace=client.active_workspace.id,
            name=artifact_name,
            artifact=artifact_uuid,
            model=model_id,
            model_version=model_version_id,
            is_model_object=is_model_object,
            is_deployment=is_deployment,
            overwrite=self.overwrite,
            pipeline_name=self._pipeline_name,
            step_name=self._step_name,
        )

        # Look for links of the same kind and name on this model version
        existing_links = client.list_model_version_artifact_links(
            ModelVersionArtifactFilterModel(
                user_id=client.active_user.id,
                workspace_id=client.active_workspace.id,
                name=artifact_name,
                model_id=model_id,
                model_version_id=model_version_id,
                only_artifacts=not (is_model_object or is_deployment),
                only_deployments=is_deployment,
                only_model_objects=is_model_object,
            )
        )
        if len(existing_links):
            if self.overwrite:
                # delete all model version artifact links by name
                logger.warning(
                    f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
                )
                client.zen_store.delete_model_version_artifact_link(
                    model_name_or_id=model_id,
                    model_version_name_or_id=model_version_id,
                    model_version_artifact_link_name_or_id=artifact_name,
                )
            else:
                logger.info(
                    f"Artifact link `{artifact_name}` already exists, adding new version."
                )
        # The link is created in every case, after the optional cleanup above.
        client.zen_store.create_model_version_artifact_link(request)

    def link_to_model(
        self,
        artifact_uuid: UUID,
    ) -> None:
        """Link artifact to the model version.

        The kind of link (generic artifact, model object or deployment) is
        taken from the `IS_MODEL_ARTIFACT` / `IS_DEPLOYMENT_ARTIFACT` class
        flags.

        Args:
            artifact_uuid (UUID): The UUID of the artifact to link.
        """
        self._link_to_model_version(
            artifact_uuid,
            is_model_object=self.IS_MODEL_ARTIFACT,
            is_deployment=self.IS_DEPLOYMENT_ARTIFACT,
        )
Config

Config class for ArtifactConfig.

Source code in zenml/model/artifact_config.py
class Config:
    """Pydantic configuration for ArtifactConfig.

    Turns on `smart_union` for the model's union-typed fields.
    """

    smart_union = True

link_to_model(self, artifact_uuid)

Link artifact to the model version.

Parameters:

Name Type Description Default
artifact_uuid UUID

The UUID of the artifact to link.

required
Source code in zenml/model/artifact_config.py
def link_to_model(
    self,
    artifact_uuid: UUID,
) -> None:
    """Create the model version link for the given artifact.

    The link kind is decided by the class-level flags `IS_MODEL_ARTIFACT`
    and `IS_DEPLOYMENT_ARTIFACT`, which are forwarded to
    `_link_to_model_version`.

    Args:
        artifact_uuid (UUID): The UUID of the artifact to link.
    """
    as_model_object = self.IS_MODEL_ARTIFACT
    as_deployment = self.IS_DEPLOYMENT_ARTIFACT
    self._link_to_model_version(
        artifact_uuid,
        is_model_object=as_model_object,
        is_deployment=as_deployment,
    )

DeploymentArtifactConfig (ArtifactConfig) pydantic-model

Used to link a Deployment to the model version.

Source code in zenml/model/artifact_config.py
class DeploymentArtifactConfig(ArtifactConfig):
    """Artifact config that links a Deployment to the model version.

    Only overrides the `IS_DEPLOYMENT_ARTIFACT` class flag of
    `ArtifactConfig`.
    """

    IS_DEPLOYMENT_ARTIFACT = True

ModelArtifactConfig (ArtifactConfig) pydantic-model

Used to link a Model Object to the model version.

save_to_model_registry: Whether to save the model object to the model registry.

Source code in zenml/model/artifact_config.py
class ModelArtifactConfig(ArtifactConfig):
    """Artifact config that links a Model Object to the model version.

    Attributes:
        save_to_model_registry: Whether to save the model object to the
            model registry.
    """

    save_to_model_registry: bool = True
    IS_MODEL_ARTIFACT = True

Utility functions for linking step outputs to model versions.

Link a step output to a model version.

Parameters:

Name Type Description Default
output_name Optional[str]

The name of the step output to link. Can be omitted if there is only one output artifact.

None
artifact_config ArtifactConfig

The ArtifactConfig of how to link this output.

required
Source code in zenml/model/link_output_to_model.py
def link_output_to_model(
    artifact_config: "ArtifactConfig",
    output_name: Optional[str] = None,
) -> None:
    """Link a step output to a model version.

    Registers `artifact_config` for the given output on the current step
    context; the context is obtained through `get_step_context()`.
    NOTE(review): presumably this must be called from inside a running step,
    since it relies on an active step context - confirm against
    `get_step_context`.

    Args:
        artifact_config: The ArtifactConfig of how to link this output.
        output_name: The name of the step output to link. Can be omitted if
            there is only one output artifact.
    """
    # Imported lazily, as in the original implementation.
    from zenml.new.steps.step_context import get_step_context

    step_context = get_step_context()
    step_context._set_artifact_config(
        output_name=output_name, artifact_config=artifact_config
    )

model_config

ModelConfig user facing interface to pass into pipeline or step.

ModelConfig (ModelConfigModel) pydantic-model

ModelConfig class to pass into pipeline or step to set it into a model context.

Attributes:

- name: The name of the model.
- version: The model version name, number or stage. It is optional and points the model context to a specific version/stage; if skipped and create_new_model_version is False, the latest model version will be used.
- version_description: The description of the model version.
- create_new_model_version: Whether to create a new model version during execution.
- save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if available in the active stack.
- delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it.

Source code in zenml/model/model_config.py
class ModelConfig(ModelConfigModel):
    """Model context configuration that is passed into a pipeline or step.

    Attributes:
        name: The name of the model.
        version: Optional model version name, number or stage that points
            the model context to a specific version/stage. If skipped and
            `create_new_model_version` is False, the latest model version
            will be used.
        version_description: The description of the model version.
        create_new_model_version: Whether to create a new model version
            during execution.
        save_models_to_registry: Whether to save all ModelArtifacts to Model
            Registry, if available in active stack.
        delete_new_version_on_failure: Whether to delete failed runs with
            new versions for later recovery from it.
    """

    # Per-instance caches: once a model / model version has been resolved it
    # is returned directly by the methods below.
    _model: Optional["ModelResponseModel"] = PrivateAttr(default=None)
    _model_version: Optional["ModelVersionResponseModel"] = PrivateAttr(
        default=None
    )

    def get_or_create_model(self) -> "ModelResponseModel":
        """Fetch the configured model, creating it implicitly if missing.

        The result is cached on the instance and reused by later calls.

        Returns:
            The model based on configuration.
        """
        if self._model is None:
            from zenml.client import Client
            from zenml.models.model_models import ModelRequestModel

            client = Client()
            try:
                self._model = client.get_model(model_name_or_id=self.name)
            except KeyError:
                # The lookup failed: build a creation request from this
                # configuration.
                draft_request = ModelRequestModel(
                    name=self.name,
                    license=self.license,
                    description=self.description,
                    audience=self.audience,
                    use_cases=self.use_cases,
                    limitations=self.limitations,
                    trade_offs=self.trade_offs,
                    ethic=self.ethic,
                    tags=self.tags,
                    user=client.active_user.id,
                    workspace=client.active_workspace.id,
                )
                create_request = ModelRequestModel.parse_obj(draft_request)
                try:
                    self._model = client.create_model(model=create_request)
                    logger.info(f"New model `{self.name}` was created implicitly.")
                except EntityExistsError:
                    # Backup path: the model appeared between the get and
                    # the create call, so fetch it instead.
                    self._model = client.get_model(model_name_or_id=self.name)

        return self._model

    def _create_model_version(
        self, model: "ModelResponseModel"
    ) -> "ModelVersionResponseModel":
        """Create the configured model version, reusing it if it exists.

        Args:
            model: The model containing the model version.

        Returns:
            The model version based on configuration.
        """
        if self._model_version is None:
            from zenml.client import Client
            from zenml.models.model_models import ModelVersionRequestModel

            client = Client()
            draft_request = ModelVersionRequestModel(
                user=client.active_user.id,
                workspace=client.active_workspace.id,
                name=self.version,
                description=self.version_description,
                model=model.id,
            )
            create_request = ModelVersionRequestModel.parse_obj(draft_request)
            try:
                # An already existing version with this identifier wins.
                self._model_version = client.get_model_version(
                    model_name_or_id=self.name,
                    model_version_name_or_number_or_id=self.version,
                )
            except KeyError:
                self._model_version = client.create_model_version(
                    model_version=create_request
                )
                logger.info(f"New model version `{self.version}` was created.")

        return self._model_version

    def _get_model_version(self) -> "ModelVersionResponseModel":
        """Fetch the configured model version without creating anything.

        Returns:
            The model version based on configuration.
        """
        if self._model_version is None:
            from zenml.client import Client

            client = Client()
            if self.version is not None:
                # Lookup by version name, stage or number; raises if not
                # found.
                self._model_version = client.get_model_version(
                    model_name_or_id=self.name,
                    model_version_name_or_number_or_id=self.version,
                )
            else:
                # No version given: lookup by model name only; raises if not
                # found.
                self._model_version = client.get_model_version(
                    model_name_or_id=self.name
                )

        return self._model_version

    def get_or_create_model_version(self) -> "ModelVersionResponseModel":
        """Get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise the existing
        model is fetched. Model name is controlled by the `name` parameter.

        The returned Model Version is resolved based on model configuration:
        - If there is an existing model version leftover from the previous
        failed run with `delete_new_version_on_failure` set to False and
        `create_new_model_version` is True, the leftover model version is
        reused.
        - Otherwise if `create_new_model_version` is True, a new model version
        is created.
        - If `create_new_model_version` is False a model version is fetched
        based on the version:
            - If `version` is not set, the latest model version is fetched.
            - If `version` is set to a string, the model version with the
            matching version is fetched.
            - If `version` is set to a `ModelStage`, the model version with
            the matching stage is fetched.

        Returns:
            The model version based on configuration.
        """
        # The model is always resolved first, even when only fetching.
        model = self.get_or_create_model()
        if self.create_new_model_version:
            return self._create_model_version(model)
        return self._get_model_version()

    def _merge_with_config(self, model_config: ModelConfigModel) -> None:
        """Merge values of another model configuration into this one.

        Falsy descriptive fields of this instance are filled from
        `model_config`, the tags of `model_config` are appended to the own
        tags, and `delete_new_version_on_failure` stays True only if it is
        True in both configurations.

        Args:
            model_config: The configuration to take missing values from.
        """
        for field_name in (
            "license",
            "description",
            "audience",
            "use_cases",
            "limitations",
            "trade_offs",
            "ethic",
        ):
            own_value = getattr(self, field_name)
            setattr(
                self, field_name, own_value or getattr(model_config, field_name)
            )

        if model_config.tags is not None:
            self.tags = (self.tags or []) + model_config.tags

        self.delete_new_version_on_failure &= (
            model_config.delete_new_version_on_failure
        )
get_or_create_model(self)

This method should get or create a model from Model Control Plane.

New model is created implicitly, if missing, otherwise fetched.

Returns:

Type Description
ModelResponseModel

The model based on configuration.

Source code in zenml/model/model_config.py
def get_or_create_model(self) -> "ModelResponseModel":
    """Fetch the configured model, creating it implicitly if missing.

    The result is cached on the instance and reused by later calls.

    Returns:
        The model based on configuration.
    """
    if self._model is None:
        from zenml.client import Client
        from zenml.models.model_models import ModelRequestModel

        client = Client()
        try:
            self._model = client.get_model(model_name_or_id=self.name)
        except KeyError:
            # The lookup failed: build a creation request from this
            # configuration.
            draft_request = ModelRequestModel(
                name=self.name,
                license=self.license,
                description=self.description,
                audience=self.audience,
                use_cases=self.use_cases,
                limitations=self.limitations,
                trade_offs=self.trade_offs,
                ethic=self.ethic,
                tags=self.tags,
                user=client.active_user.id,
                workspace=client.active_workspace.id,
            )
            create_request = ModelRequestModel.parse_obj(draft_request)
            try:
                self._model = client.create_model(model=create_request)
                logger.info(f"New model `{self.name}` was created implicitly.")
            except EntityExistsError:
                # Backup path: the model appeared between the get and the
                # create call, so fetch it instead.
                self._model = client.get_model(model_name_or_id=self.name)

    return self._model
get_or_create_model_version(self)

This method should get or create a model and a model version from Model Control Plane.

A new model is created implicitly if missing, otherwise existing model is fetched. Model name is controlled by the name parameter.

Model Version returned by this method is resolved based on model configuration:

- If there is an existing model version left over from a previous failed run where delete_new_version_on_failure is set to False and create_new_model_version is True, the leftover model version will be reused.
- Otherwise, if create_new_model_version is True, a new model version is created.
- If create_new_model_version is False, a model version will be fetched based on the version:
    - If version is not set, the latest model version will be fetched.
    - If version is set to a string, the model version with the matching version will be fetched.
    - If version is set to a ModelStage, the model version with the matching stage will be fetched.

Returns:

Type Description
ModelVersionResponseModel

The model version based on configuration.

Source code in zenml/model/model_config.py
def get_or_create_model_version(self) -> "ModelVersionResponseModel":
    """Get or create a model and a model version from Model Control Plane.

    A new model is created implicitly if missing, otherwise the existing
    model is fetched. Model name is controlled by the `name` parameter.

    The returned Model Version is resolved based on model configuration:
    - If there is an existing model version leftover from the previous failed
    run with `delete_new_version_on_failure` set to False and
    `create_new_model_version` is True, the leftover model version is reused.
    - Otherwise if `create_new_model_version` is True, a new model version is
    created.
    - If `create_new_model_version` is False a model version is fetched based
    on the version:
        - If `version` is not set, the latest model version is fetched.
        - If `version` is set to a string, the model version with the matching
        version is fetched.
        - If `version` is set to a `ModelStage`, the model version with the
        matching stage is fetched.

    Returns:
        The model version based on configuration.
    """
    # The model is always resolved first, even when only fetching a version.
    model = self.get_or_create_model()
    if self.create_new_model_version:
        return self._create_model_version(model)
    return self._get_model_version()