Skip to content

Model

zenml.model special

Initialization of the ZenML model package. ZenML models support the Model Control Plane feature.

artifact_config

Artifact Config classes to support Model Control Plane feature.

ArtifactConfig (BaseModel) pydantic-model

Used to link a generic Artifact to the model version.

model_name: The name of the model to link the artifact to. model_version: The identifier of the model version to link the artifact to. It can be an exact version ("23"), an exact version number (42), a stage (ModelStages.PRODUCTION), or ModelStages.LATEST for the latest version. model_stage: The stage of the model version to link the artifact to. artifact_name: An override name for the link, used instead of the artifact name. overwrite: Whether to overwrite an existing link or create a new version.

Source code in zenml/model/artifact_config.py
class ArtifactConfig(BaseModel):
    """Configuration for linking a generic artifact to a model version.

    model_name: Name of the model the artifact should be linked to.
    model_version: Identifier of the target model version. Accepts an exact
        version name ("23"), an exact version number (42), a stage
        (ModelStages.PRODUCTION) or ModelStages.LATEST for the latest version.
        (No separate `model_stage` field is declared; pass a ModelStages
        value here instead.)
    artifact_name: Name to use for the link instead of the artifact's own name.
    overwrite: If True, existing links found for this name are deleted before
        the new link is created; otherwise the new link is added alongside
        them as a new version.
    """

    model_name: Optional[str]
    model_version: Optional[Union[ModelStages, str, int]]
    artifact_name: Optional[str]
    overwrite: bool = False

    # Not assigned anywhere in this class; presumably set by the caller
    # before `link_to_model` runs — TODO confirm.
    _pipeline_name: str = PrivateAttr()
    _step_name: str = PrivateAttr()
    # Subclasses override these to choose the kind of link that is created.
    IS_MODEL_ARTIFACT: ClassVar[bool] = False
    IS_DEPLOYMENT_ARTIFACT: ClassVar[bool] = False

    class Config:
        """Pydantic settings for ArtifactConfig."""

        smart_union = True

    @property
    def _model_config(self) -> "ModelConfig":
        """Resolve the model configuration this artifact should be linked with.

        Resolution order:
        1. If `model_name` is set and there is no step-context model, or the
           step-context model has a different name, a new `ModelConfig` is
           built from `model_name` and `model_version`.
        2. Otherwise, the model configuration from the step context is used.

        Returns:
            ModelConfig: The resolved model configuration.

        Raises:
            RuntimeError: If the step context provides no model configuration
                and `model_name` is not set.
        """
        try:
            context_config = get_step_context().model_config
        except (StepContextError, RuntimeError):
            # No usable step context (e.g. not called from inside a step).
            context_config = None

        wants_own_model = self.model_name is not None and (
            context_config is None or context_config.name != self.model_name
        )
        if wants_own_model:
            from zenml.model.model_config import ModelConfig

            return ModelConfig(
                name=self.model_name,
                version=self.model_version,
            )

        if context_config is not None:
            return context_config

        raise RuntimeError(
            "No model configuration found in @step or @pipeline. "
            "You can configure ModelConfig inside ArtifactConfig as well, but "
            "`model_name` and `model_version` must be provided."
        )

    def _link_to_model_version(
        self,
        artifact_uuid: UUID,
        model_config: "ModelConfig",
        is_model_object: bool = False,
        is_deployment: bool = False,
    ) -> None:
        """Create a link between an artifact and the configured model version.

        Used on exit from the step context. The model version is fetched or
        created through `model_config`, and then a model-version/artifact link
        is created. If matching links already exist, they are either deleted
        first (`overwrite=True`) or kept, and the new link is added as a new
        version.

        Args:
            artifact_uuid: UUID of the artifact to link.
            model_config: Model configuration supplied by the caller.
            is_model_object: Flag the link as a model object. Defaults to False.
            is_deployment: Flag the link as a deployment. Defaults to False.
        """
        from zenml.client import Client
        from zenml.models.model_models import (
            ModelVersionArtifactFilterModel,
            ModelVersionArtifactRequestModel,
        )

        zen_client = Client()
        target_version = model_config.get_or_create_model_version()

        # Fall back to the artifact's own name when no override is configured.
        if self.artifact_name is not None:
            artifact_name = self.artifact_name
        else:
            artifact_name = zen_client.zen_store.get_artifact(
                artifact_id=artifact_uuid
            ).name

        link_request = ModelVersionArtifactRequestModel(
            user=zen_client.active_user.id,
            workspace=zen_client.active_workspace.id,
            name=artifact_name,
            artifact=artifact_uuid,
            model=target_version.model.id,
            model_version=target_version.id,
            is_model_object=is_model_object,
            is_deployment=is_deployment,
            overwrite=self.overwrite,
            pipeline_name=self._pipeline_name,
            step_name=self._step_name,
        )

        # Existing links of the same kind, name, pipeline and step.
        link_filter = ModelVersionArtifactFilterModel(
            user_id=zen_client.active_user.id,
            workspace_id=zen_client.active_workspace.id,
            name=artifact_name,
            only_artifacts=not (is_model_object or is_deployment),
            only_deployments=is_deployment,
            only_model_objects=is_model_object,
            pipeline_name=self._pipeline_name,
            step_name=self._step_name,
        )
        existing_links = zen_client.list_model_version_artifact_links(
            model_name_or_id=target_version.model.id,
            model_version_name_or_number_or_id=target_version.id,
            model_version_artifact_link_filter_model=link_filter,
        )

        if len(existing_links):
            if not self.overwrite:
                logger.info(
                    f"Artifact link `{artifact_name}` already exists, adding new version."
                )
            else:
                logger.warning(
                    f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
                )
                # Deletion is addressed by link name, not by the filtered list.
                zen_client.zen_store.delete_model_version_artifact_link(
                    model_name_or_id=target_version.model.id,
                    model_version_name_or_id=target_version.id,
                    model_version_artifact_link_name_or_id=artifact_name,
                )
        zen_client.zen_store.create_model_version_artifact_link(link_request)

    def link_to_model(
        self, artifact_uuid: UUID, model_config: "ModelConfig"
    ) -> None:
        """Link an artifact to the model version described by `model_config`.

        The kind of link (model object, deployment, or plain artifact) comes
        from the class-level `IS_MODEL_ARTIFACT` and `IS_DEPLOYMENT_ARTIFACT`
        flags.

        Args:
            artifact_uuid: UUID of the artifact to link.
            model_config: Model configuration supplied by the caller.
        """
        link_kind = {
            "is_model_object": self.IS_MODEL_ARTIFACT,
            "is_deployment": self.IS_DEPLOYMENT_ARTIFACT,
        }
        self._link_to_model_version(
            artifact_uuid, model_config=model_config, **link_kind
        )
Config

Config class for ArtifactConfig.

Source code in zenml/model/artifact_config.py
class Config:
    """Pydantic configuration options used by ArtifactConfig."""

    # Turn on pydantic's "smart union" matching for Union-typed fields.
    smart_union: bool = True

Link artifact to the model version.

Parameters:

Name Type Description Default
artifact_uuid UUID

The UUID of the artifact to link.

required
model_config ModelConfig

The model configuration from caller.

required
Source code in zenml/model/artifact_config.py
def link_to_model(
    self, artifact_uuid: UUID, model_config: "ModelConfig"
) -> None:
    """Link an artifact to the model version described by `model_config`.

    The kind of link (model object, deployment, or plain artifact) comes
    from the class-level `IS_MODEL_ARTIFACT` and `IS_DEPLOYMENT_ARTIFACT`
    flags.

    Args:
        artifact_uuid: UUID of the artifact to link.
        model_config: Model configuration supplied by the caller.
    """
    link_kind = {
        "is_model_object": self.IS_MODEL_ARTIFACT,
        "is_deployment": self.IS_DEPLOYMENT_ARTIFACT,
    }
    self._link_to_model_version(
        artifact_uuid, model_config=model_config, **link_kind
    )

DeploymentArtifactConfig (ArtifactConfig) pydantic-model

Used to link a Deployment to the model version.

Source code in zenml/model/artifact_config.py
class DeploymentArtifactConfig(ArtifactConfig):
    """Artifact configuration that links a deployment to the model version.

    Behaves like `ArtifactConfig`, except that created links are flagged as
    deployments.
    """

    IS_DEPLOYMENT_ARTIFACT: ClassVar[bool] = True

ModelArtifactConfig (ArtifactConfig) pydantic-model

Used to link a Model Object to the model version.

save_to_model_registry: Whether to save the model object to the model registry.

Source code in zenml/model/artifact_config.py
class ModelArtifactConfig(ArtifactConfig):
    """Artifact configuration that links a model object to the model version.

    save_to_model_registry: Whether the model object should also be saved to
        the model registry.
    """

    save_to_model_registry: bool = True
    # Created links are flagged as model objects.
    IS_MODEL_ARTIFACT: ClassVar[bool] = True

Utility functions for linking step outputs to model versions.

Link a step output artifact to a model version.

Parameters:

Name Type Description Default
output_name Optional[str]

The output name of the artifact to link. Can be omitted if there is only one output artifact.

None
artifact_config ArtifactConfig

The ArtifactConfig of how to link this output.

required
Source code in zenml/model/link_output_to_model.py
def link_output_to_model(
    artifact_config: "ArtifactConfig",
    output_name: Optional[str] = None,
) -> None:
    """Link a step output artifact to a model version.

    Stores `artifact_config` for the given output on the current step
    context. The actual link is presumably created later from this stored
    configuration (e.g. when the step finishes) — confirm in StepContext.

    Args:
        artifact_config: The ArtifactConfig describing how to link this output.
        output_name: The name of the step output to link. Can be omitted if
            the step has only one output artifact.
    """
    from zenml.new.steps.step_context import get_step_context

    step_context = get_step_context()
    step_context._set_artifact_config(
        output_name=output_name, artifact_config=artifact_config
    )

model_config

ModelConfig: the user-facing interface to pass into a pipeline or step.

ModelConfig (BaseModel) pydantic-model

ModelConfig class to pass into pipeline or step to set it into a model context.

name: The name of the model. license: The license under which the model is created. description: The description of the model. audience: The target audience of the model. use_cases: The use cases of the model. limitations: The known limitations of the model. trade_offs: The trade-offs of the model. ethics: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number or stage. It is optional and points the model context to a specific version/stage; if skipped, a new model version will be created. version_description: The description of the model version. save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if one is available in the active stack. delete_new_version_on_failure: Whether to delete a newly created model version when the run fails; if False, the version is kept so that a later run can recover from it.

Source code in zenml/model/model_config.py
class ModelConfig(BaseModel):
    """ModelConfig class to pass into pipeline or step to set it into a model context.

    name: The name of the model.
    license: The license under which the model is created.
    description: The description of the model.
    audience: The target audience of the model.
    use_cases: The use cases of the model.
    limitations: The known limitations of the model.
    trade_offs: The tradeoffs of the model.
    ethics: The ethical implications of the model.
    tags: Tags associated with the model.
    version: The model version name, number or stage. Optional; points the
        model context to a specific version/stage. If skipped, a new model
        version will be created.
    version_description: The description of the model version.
    save_models_to_registry: Whether to save all ModelArtifacts to the Model
        Registry, if one is available in the active stack.
    delete_new_version_on_failure: Whether to delete a newly created model
        version when the run fails. If False, a leftover version from a failed
        run is kept and can be reused by a later run (recovery mode).
    suppress_class_validation_warnings: Internal flag; the root validator
        sets it to True so its informational messages are not repeated.
    was_created_in_this_run: Set to True when `get_or_create_model_version`
        had to create the model version.
    """

    name: str
    license: Optional[str]
    description: Optional[str]
    audience: Optional[str]
    use_cases: Optional[str]
    limitations: Optional[str]
    trade_offs: Optional[str]
    ethics: Optional[str]
    tags: Optional[List[str]]
    version: Optional[Union[ModelStages, int, str]]
    version_description: Optional[str]
    save_models_to_registry: bool = True
    delete_new_version_on_failure: bool = True

    suppress_class_validation_warnings: bool = False
    was_created_in_this_run: bool = False

    class Config:
        """Config class."""

        smart_union = True

    @root_validator(pre=True)
    def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Log how `version` will be interpreted.

        Args:
            values: Dict of values.

        Returns:
            Dict of validated values, with `suppress_class_validation_warnings`
            forced to True so that re-validating the same values does not log
            again.
        """
        suppress_class_validation_warnings = values.get(
            "suppress_class_validation_warnings", False
        )
        version = values.get("version", None)

        if (
            version in [stage.value for stage in ModelStages]
            and not suppress_class_validation_warnings
        ):
            logger.info(
                f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage."
            )
        if str(version).isnumeric() and not suppress_class_validation_warnings:
            logger.info(
                f"`version` `{version}` is numeric and will be fetched using version number."
            )
        values["suppress_class_validation_warnings"] = True
        return values

    def _validate_config_in_runtime(self) -> None:
        """Validate that config doesn't conflict with runtime environment.

        If `version` is unset (or equals RUNNING_MODEL_VERSION) and that
        running version already exists, this checks that no pipeline run of
        it is still RUNNING and that recovery is allowed
        (`delete_new_version_on_failure` is False). If the requested version
        does not exist yet, it is created.

        Raises:
            RuntimeError: If recovery not requested, but model version already exists.
            RuntimeError: If there is unfinished pipeline run for requested new model version.
        """
        try:
            model_version = self._get_model_version()
            if self.version is None or self.version == RUNNING_MODEL_VERSION:
                self.version = RUNNING_MODEL_VERSION
                for run_name, run in model_version.pipeline_runs.items():
                    if run.status == ExecutionStatus.RUNNING:
                        raise RuntimeError(
                            "You have configured a model context without explicit "
                            "`version` argument passed in, so new a unnamed model "
                            "version has to be created, but pipeline run "
                            f"`{run_name}` has not finished yet. To proceed you can:\n"
                            "- Wait for previous run to finish\n"
                            "- Provide explicit `version` in configuration"
                        )
                if self.delete_new_version_on_failure:
                    raise RuntimeError(
                        f"Cannot create version `{self.version}` "
                        f"for model `{self.name}` since it already exists "
                        "and recovery mode is disabled. "
                        "This could happen for unforeseen reasons (e.g. unexpected "
                        "interruption of previous pipeline run flow).\n"
                        "If you would like to remove the staling version use "
                        "following CLI command:\n"
                        f"`zenml model version delete {self.name} {self.version}`"
                    )
        except KeyError:
            # The requested version does not exist yet: create it now.
            self.get_or_create_model_version()

    def get_or_create_model(self) -> "ModelResponseModel":
        """Get the model from the Model Control Plane, creating it if missing.

        If no model named `name` exists, one is created from this config's
        metadata fields. If the model appears between the lookup and the
        create call (`EntityExistsError`), the existing model is used.

        Returns:
            The model based on configuration.
        """
        from zenml.client import Client
        from zenml.models.model_models import ModelRequestModel

        zenml_client = Client()
        try:
            return zenml_client.get_model(model_name_or_id=self.name)
        except KeyError:
            pass

        model_request = ModelRequestModel(
            name=self.name,
            license=self.license,
            description=self.description,
            audience=self.audience,
            use_cases=self.use_cases,
            limitations=self.limitations,
            trade_offs=self.trade_offs,
            ethics=self.ethics,
            tags=self.tags,
            user=zenml_client.active_user.id,
            workspace=zenml_client.active_workspace.id,
        )
        model_request = ModelRequestModel.parse_obj(model_request)
        try:
            zenml_client.create_model(model=model_request)
            logger.info(f"New model `{self.name}` was created implicitly.")
        except EntityExistsError:
            # The model was created concurrently between the get and create
            # calls above; fall through and fetch it.
            pass
        # Fetched only on success or EntityExistsError, so any other error
        # from `create_model` propagates as-is instead of being replaced by
        # a KeyError from this lookup.
        return zenml_client.get_model(model_name_or_id=self.name)

    def _get_model_version(self) -> "ModelVersionResponseModel":
        """This method gets a model version from Model Control Plane.

        An unset `version` is looked up as RUNNING_MODEL_VERSION.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client

        zenml_client = Client()
        return zenml_client.get_model_version(
            model_name_or_id=self.name,
            model_version_name_or_number_or_id=self.version
            or RUNNING_MODEL_VERSION,
        )

    def get_or_create_model_version(self) -> "ModelVersionResponseModel":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model is fetched. Model
        name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model configuration:
        - If there is an existing model version leftover from the previous failed run with
        `delete_new_version_on_failure` set to False and `version` is None,
        the leftover model version will be reused.
        - Otherwise if `version` is None, a new model version is created.
        - If `version` is not None a model version will be fetched based on the version:
            - If `version` is set to an integer or digit string, the model version with the matching number will be fetched.
            - If `version` is set to a string, the model version with the matching version will be fetched.
            - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

        Side effect: an unset `version` is replaced by RUNNING_MODEL_VERSION,
        and `was_created_in_this_run` is set to True if a version is created.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client
        from zenml.models.model_models import ModelVersionRequestModel

        model = self.get_or_create_model()

        if self.version is None:
            logger.info(
                "Creation of new model version was requested, but no version name was explicitly provided. "
                f"Setting `version` to `{RUNNING_MODEL_VERSION}`."
            )
            self.version = RUNNING_MODEL_VERSION

        zenml_client = Client()
        try:
            model_version = self._get_model_version()
        except KeyError:
            # Only build the request (and resolve the active user/workspace)
            # when a new version actually has to be created.
            model_version_request = ModelVersionRequestModel(
                user=zenml_client.active_user.id,
                workspace=zenml_client.active_workspace.id,
                name=self.version,
                description=self.version_description,
                model=model.id,
            )
            mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
            model_version = zenml_client.create_model_version(
                model_version=mv_request
            )
            self.was_created_in_this_run = True
            logger.info(f"New model version `{self.version}` was created.")
        return model_version

    def _merge(self, model_config: "ModelConfig") -> None:
        """Fill in this config from another config, in place.

        Each descriptive text field keeps its own value when truthy and
        otherwise takes the value from `model_config`. When `model_config`
        has tags, the tags become the (unordered) union of both tag lists.
        `delete_new_version_on_failure` stays True only if it is True in both
        configs.

        Args:
            model_config: The config to take missing values from.
        """
        self.license = self.license or model_config.license
        self.description = self.description or model_config.description
        self.audience = self.audience or model_config.audience
        self.use_cases = self.use_cases or model_config.use_cases
        self.limitations = self.limitations or model_config.limitations
        self.trade_offs = self.trade_offs or model_config.trade_offs
        self.ethics = self.ethics or model_config.ethics
        if model_config.tags is not None:
            self.tags = list(set(self.tags or []) | set(model_config.tags))

        self.delete_new_version_on_failure &= (
            model_config.delete_new_version_on_failure
        )

    def __hash__(self) -> int:
        """Get hash of the `ModelConfig`.

        Only `name`, `version` and `delete_new_version_on_failure` are hashed.

        Returns:
            Hash function results
        """
        return hash(
            "::".join(
                (
                    str(v)
                    for v in (
                        self.name,
                        self.version,
                        self.delete_new_version_on_failure,
                    )
                )
            )
        )
Config

Config class.

Source code in zenml/model/model_config.py
class Config:
    """Pydantic configuration options used by ModelConfig."""

    # Turn on pydantic's "smart union" matching for Union-typed fields.
    smart_union: bool = True
__hash__(self) special

Get hash of the ModelConfig.

Returns:

Type Description
int

Hash function results

Source code in zenml/model/model_config.py
def __hash__(self) -> int:
    """Hash a `ModelConfig` by its name, version and failure-handling flag.

    Returns:
        Hash of the string "<name>::<version>::<delete_new_version_on_failure>".
    """
    key_parts = (self.name, self.version, self.delete_new_version_on_failure)
    return hash("::".join(map(str, key_parts)))
get_or_create_model(self)

This method should get or create a model from Model Control Plane.

New model is created implicitly, if missing, otherwise fetched.

Returns:

Type Description
ModelResponseModel

The model based on configuration.

Source code in zenml/model/model_config.py
def get_or_create_model(self) -> "ModelResponseModel":
    """Get the model from the Model Control Plane, creating it if missing.

    If no model named `name` exists, one is created from this config's
    metadata fields. If the model appears between the lookup and the
    create call (`EntityExistsError`), the existing model is used.

    Returns:
        The model based on configuration.
    """
    from zenml.client import Client
    from zenml.models.model_models import ModelRequestModel

    zenml_client = Client()
    try:
        return zenml_client.get_model(model_name_or_id=self.name)
    except KeyError:
        pass

    model_request = ModelRequestModel(
        name=self.name,
        license=self.license,
        description=self.description,
        audience=self.audience,
        use_cases=self.use_cases,
        limitations=self.limitations,
        trade_offs=self.trade_offs,
        ethics=self.ethics,
        tags=self.tags,
        user=zenml_client.active_user.id,
        workspace=zenml_client.active_workspace.id,
    )
    model_request = ModelRequestModel.parse_obj(model_request)
    try:
        zenml_client.create_model(model=model_request)
        logger.info(f"New model `{self.name}` was created implicitly.")
    except EntityExistsError:
        # The model was created concurrently between the get and create
        # calls above; fall through and fetch it.
        pass
    # Fetched only on success or EntityExistsError, so any other error from
    # `create_model` propagates as-is instead of being replaced by a
    # KeyError from this lookup.
    return zenml_client.get_model(model_name_or_id=self.name)
get_or_create_model_version(self)

This method should get or create a model and a model version from Model Control Plane.

A new model is created implicitly if missing, otherwise existing model is fetched. Model name is controlled by the name parameter.

The model version returned by this method is resolved based on the model configuration:
- If there is a leftover model version from a previous failed run with delete_new_version_on_failure set to False and version is None, the leftover model version will be reused.
- Otherwise, if version is None, a new model version is created.
- If version is not None, a model version will be fetched based on the version:
  - If version is an integer or a digit string, the model version with the matching number will be fetched.
  - If version is a string, the model version with the matching name will be fetched.
  - If version is a ModelStage, the model version in the matching stage will be fetched.

Returns:

Type Description
ModelVersionResponseModel

The model version based on configuration.

Source code in zenml/model/model_config.py
def get_or_create_model_version(self) -> "ModelVersionResponseModel":
    """This method should get or create a model and a model version from Model Control Plane.

    A new model is created implicitly if missing, otherwise existing model is fetched. Model
    name is controlled by the `name` parameter.

    Model Version returned by this method is resolved based on model configuration:
    - If there is an existing model version leftover from the previous failed run with
    `delete_new_version_on_failure` set to False and `version` is None,
    the leftover model version will be reused.
    - Otherwise if `version` is None, a new model version is created.
    - If `version` is not None a model version will be fetched based on the version:
        - If `version` is set to an integer or digit string, the model version with the matching number will be fetched.
        - If `version` is set to a string, the model version with the matching version will be fetched.
        - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

    Side effect: an unset `version` is replaced by RUNNING_MODEL_VERSION,
    and `was_created_in_this_run` is set to True if a version is created.

    Returns:
        The model version based on configuration.
    """
    from zenml.client import Client
    from zenml.models.model_models import ModelVersionRequestModel

    model = self.get_or_create_model()

    if self.version is None:
        logger.info(
            "Creation of new model version was requested, but no version name was explicitly provided. "
            f"Setting `version` to `{RUNNING_MODEL_VERSION}`."
        )
        self.version = RUNNING_MODEL_VERSION

    zenml_client = Client()
    try:
        model_version = self._get_model_version()
    except KeyError:
        # Only build the request (and resolve the active user/workspace)
        # when a new version actually has to be created.
        model_version_request = ModelVersionRequestModel(
            user=zenml_client.active_user.id,
            workspace=zenml_client.active_workspace.id,
            name=self.version,
            description=self.version_description,
            model=model.id,
        )
        mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
        model_version = zenml_client.create_model_version(
            model_version=mv_request
        )
        self.was_created_in_this_run = True
        logger.info(f"New model version `{self.version}` was created.")
    return model_version