Skip to content

Model

zenml.model special

Initialization of the ZenML model package. ZenML models support the Model Control Plane feature.

artifact_config

Artifact config classes that support the Model Control Plane feature.

ArtifactConfig (BaseModel) pydantic-model

Used to link a generic Artifact to the model version.

- `model_name`: The name of the model to link the artifact to.
- `model_version`: The identifier of the model version to link the artifact to. It can be an exact version name ("23"), an exact version number (42), a stage (`ModelStages.PRODUCTION`), or None for the latest version.
- `model_stage`: The stage of the model version to link the artifact to. It is selected by passing a `ModelStages` value as `model_version`; there is no separate `model_stage` field.
- `artifact_name`: An override name for the link, used instead of the artifact name.
- `overwrite`: Whether to overwrite an existing link or create new versions.

Source code in zenml/model/artifact_config.py
class ArtifactConfig(BaseModel):
    """Used to link a generic Artifact to the model version.

    The target model version is resolved by `_model_config`. The model of
    the surrounding @step/@pipeline context is used, unless `model_name` is
    set and differs from it (or there is no context model). In that case a
    ModelConfig is built on the fly from `model_name`/`model_version`.
    Subclasses override the `IS_MODEL_ARTIFACT` / `IS_DEPLOYMENT_ARTIFACT`
    class flags to mark the kind of link.

    model_name: The name of the model to link artifact to.
    model_version: The identifier of the model version to link artifact to.
        It can be exact version ("23"), exact version number (42), stage
        (ModelStages.PRODUCTION) or None for the latest version. A stage is
        selected by passing a `ModelStages` value here; there is no
        separate `model_stage` field.
    artifact_name: The override name of a link instead of an artifact name.
        If None, the name of the linked artifact itself is used.
    overwrite: Whether to overwrite an existing link or create new versions.
    """

    model_name: Optional[str]
    model_version: Optional[Union[ModelStages, str, int]]
    artifact_name: Optional[str]
    overwrite: bool = False

    # Pipeline/step scope of the link. These are assigned outside this class
    # (not visible here) and read by `_link_to_model_version`.
    _pipeline_name: str = PrivateAttr()
    _step_name: str = PrivateAttr()
    # Class-level flags (not pydantic fields) that `link_to_model` forwards
    # as `is_model_object` / `is_deployment`; subclasses override them.
    IS_MODEL_ARTIFACT: ClassVar[bool] = False
    IS_DEPLOYMENT_ARTIFACT: ClassVar[bool] = False

    class Config:
        """Config class for ArtifactConfig."""

        # Ask pydantic to prefer an exact type match among Union members
        # (e.g. keep "23" vs. 42 for `model_version`) before coercing.
        smart_union = True

    @property
    def _model_config(self) -> "ModelConfig":
        """Property that returns the model configuration.

        Returns the model configuration of the current step context, unless
        `model_name` is set and differs from the context model's name (or no
        context model is available). In that case it builds a new ModelConfig
        from `model_name` and `model_version` with
        `create_new_model_version=False`.

        NOTE(review): when `model_name` equals the context model's name,
        `model_version` is ignored and the context model configuration is
        returned as-is. Confirm this is intended.

        Returns:
            ModelConfig: The model configuration.

        Raises:
            RuntimeError: If model configuration cannot be acquired from @step
                or @pipeline or built on the fly from fields of this class.
        """
        try:
            model_config = get_step_context().model_config
        except (StepContextError, RuntimeError):
            # No usable step context; rely on this config's own fields below.
            model_config = None
        # Check if a specific model name is provided and it doesn't match the context name
        if (self.model_name is not None) and (
            model_config is None or model_config.name != self.model_name
        ):
            # Create a new ModelConfig instance with the provided model name and version
            from zenml.model.model_config import ModelConfig

            on_the_fly_config = ModelConfig(
                name=self.model_name,
                version=self.model_version,
                create_new_model_version=False,
            )
            return on_the_fly_config

        if model_config is None:
            raise RuntimeError(
                "No model configuration found in @step or @pipeline. "
                "You can configure ModelConfig inside ArtifactConfig as well, but "
                "`model_name` and `model_version` must be provided."
            )
        # Return the model from the context
        return model_config

    def _link_to_model_version(
        self,
        artifact_uuid: UUID,
        model_config: "ModelConfig",
        is_model_object: bool = False,
        is_deployment: bool = False,
    ) -> None:
        """Link artifact to the model version.

        This method is used on exit from the step context to link artifact to the model version.

        The model version is resolved from `model_config`. The method first
        looks up existing links with the same name and kind in that model
        version, pipeline and step. If any exist and `overwrite` is set, they
        are deleted by link name before the new link is created. Otherwise
        the new link is added as a new version.

        NOTE(review): the delete call is scoped only by model, model version
        and link name, not by pipeline/step or link kind. It may therefore
        remove more links than the lookup found. Confirm this is intended.

        Args:
            artifact_uuid: The UUID of the artifact to link.
            model_config: The model configuration from caller.
            is_model_object: Whether the artifact is a model object. Defaults to False.
            is_deployment: Whether the artifact is a deployment. Defaults to False.
        """
        from zenml.client import Client
        from zenml.models.model_models import (
            ModelVersionArtifactFilterModel,
            ModelVersionArtifactRequestModel,
        )

        # Create a ZenML client
        client = Client()

        # Presumably raises (KeyError) if the configured version does not exist.
        model_version = model_config._get_model_version()

        # Default the link name to the stored artifact's own name.
        artifact_name = self.artifact_name
        if artifact_name is None:
            artifact = client.zen_store.get_artifact(artifact_id=artifact_uuid)
            artifact_name = artifact.name

        # Create a request model for the model version artifact link
        request = ModelVersionArtifactRequestModel(
            user=client.active_user.id,
            workspace=client.active_workspace.id,
            name=artifact_name,
            artifact=artifact_uuid,
            model=model_version.model.id,
            model_version=model_version.id,
            is_model_object=is_model_object,
            is_deployment=is_deployment,
            overwrite=self.overwrite,
            pipeline_name=self._pipeline_name,
            step_name=self._step_name,
        )

        # Look for links already registered under the same name and kind in
        # this model version, pipeline and step.
        existing_links = client.list_model_version_artifact_links(
            ModelVersionArtifactFilterModel(
                user_id=client.active_user.id,
                workspace_id=client.active_workspace.id,
                name=artifact_name,
                model_id=model_version.model.id,
                model_version_id=model_version.id,
                only_artifacts=not (is_model_object or is_deployment),
                only_deployments=is_deployment,
                only_model_objects=is_model_object,
                pipeline_name=self._pipeline_name,
                step_name=self._step_name,
            )
        )
        if len(existing_links):
            if self.overwrite:
                # delete all model version artifact links by name
                logger.warning(
                    f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
                )
                client.zen_store.delete_model_version_artifact_link(
                    model_name_or_id=model_version.model.id,
                    model_version_name_or_id=model_version.id,
                    model_version_artifact_link_name_or_id=artifact_name,
                )
            else:
                logger.info(
                    f"Artifact link `{artifact_name}` already exists, adding new version."
                )
        # Create the model version artifact link using the ZenML client
        client.zen_store.create_model_version_artifact_link(request)

    def link_to_model(
        self, artifact_uuid: UUID, model_config: "ModelConfig"
    ) -> None:
        """Link artifact to the model version.

        The link kind comes from the class flags `IS_MODEL_ARTIFACT` and
        `IS_DEPLOYMENT_ARTIFACT`.

        Args:
            artifact_uuid: The UUID of the artifact to link.
            model_config: The model configuration from caller.
        """
        self._link_to_model_version(
            artifact_uuid,
            model_config=model_config,
            is_model_object=self.IS_MODEL_ARTIFACT,
            is_deployment=self.IS_DEPLOYMENT_ARTIFACT,
        )
Config

Config class for ArtifactConfig.

Source code in zenml/model/artifact_config.py
class Config:
    """Pydantic settings for ArtifactConfig.

    With `smart_union` enabled, pydantic prefers an exact type match among
    the members of a Union field before falling back to coercion.
    """

    smart_union: bool = True

Link artifact to the model version.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `artifact_uuid` | `UUID` | The UUID of the artifact to link. | required |
| `model_config` | `ModelConfig` | The model configuration from the caller. | required |
Source code in zenml/model/artifact_config.py
def link_to_model(
    self, artifact_uuid: UUID, model_config: "ModelConfig"
) -> None:
    """Attach an artifact to the model version described by `model_config`.

    The kind of link (model object, deployment or plain artifact) is taken
    from the class-level `IS_MODEL_ARTIFACT` / `IS_DEPLOYMENT_ARTIFACT`
    flags.

    Args:
        artifact_uuid: The UUID of the artifact to link.
        model_config: The model configuration from caller.
    """
    as_model_object = self.IS_MODEL_ARTIFACT
    as_deployment = self.IS_DEPLOYMENT_ARTIFACT
    self._link_to_model_version(
        artifact_uuid,
        model_config=model_config,
        is_model_object=as_model_object,
        is_deployment=as_deployment,
    )

DeploymentArtifactConfig (ArtifactConfig) pydantic-model

Used to link a Deployment to the model version.

Source code in zenml/model/artifact_config.py
class DeploymentArtifactConfig(ArtifactConfig):
    """Artifact config that links a deployment to a model version.

    It behaves like `ArtifactConfig`, except that the created link is
    flagged as a deployment.
    """

    # Class-level flag (not a pydantic field), forwarded as `is_deployment`
    # when the artifact is linked.
    IS_DEPLOYMENT_ARTIFACT: ClassVar[bool] = True

ModelArtifactConfig (ArtifactConfig) pydantic-model

Used to link a Model Object to the model version.

save_to_model_registry: Whether to save the model object to the model registry.

Source code in zenml/model/artifact_config.py
class ModelArtifactConfig(ArtifactConfig):
    """Artifact config that links a model object to a model version.

    save_to_model_registry: Whether to save the model object to the model registry.
    """

    # Class-level flag (not a pydantic field), forwarded as `is_model_object`
    # when the artifact is linked.
    IS_MODEL_ARTIFACT: ClassVar[bool] = True

    save_to_model_registry: bool = True

Utility functions for linking step outputs to model versions.

Link a step output artifact to a model version according to an `ArtifactConfig`.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `output_name` | `Optional[str]` | The output name of the artifact to link. Can be omitted if there is only one output artifact. | `None` |
| `artifact_config` | `ArtifactConfig` | The `ArtifactConfig` describing how to link this output. | required |
Source code in zenml/model/link_output_to_model.py
def link_output_to_model(
    artifact_config: "ArtifactConfig",
    output_name: Optional[str] = None,
) -> None:
    """Link a step output artifact to a model version.

    Registers `artifact_config` for the given output on the current step
    context. The link itself is not created here; this call only records how
    the output should be linked.

    Args:
        artifact_config: The ArtifactConfig describing how to link this output.
        output_name: The output name of the artifact to link. Can be omitted
            if there is only one output artifact.
    """
    # Imported lazily, mirroring the original function.
    from zenml.new.steps.step_context import get_step_context

    step_context = get_step_context()
    step_context._set_artifact_config(
        output_name=output_name, artifact_config=artifact_config
    )

model_config

User-facing `ModelConfig` interface to pass into a pipeline or step.

ModelConfig (BaseModel) pydantic-model

ModelConfig class to pass into pipeline or step to set it into a model context.

- `name`: The name of the model.
- `license`: The license under which the model is created.
- `description`: The description of the model.
- `audience`: The target audience of the model.
- `use_cases`: The use cases of the model.
- `limitations`: The known limitations of the model.
- `trade_offs`: The trade-offs of the model.
- `ethics`: The ethical implications of the model.
- `tags`: Tags associated with the model.
- `version`: The optional model version name, number, or stage that points the model context to a specific version/stage. If it is skipped and `create_new_model_version` is False, the latest model version is used.
- `version_description`: The description of the model version.
- `create_new_model_version`: Whether to create a new model version during execution.
- `save_models_to_registry`: Whether to save all ModelArtifacts to the Model Registry, if one is available in the active stack.
- `delete_new_version_on_failure`: Whether to delete a new model version when its run fails. If False, the leftover version is kept so that a later run can recover from it.

Source code in zenml/model/model_config.py
class ModelConfig(BaseModel):
    """ModelConfig class to pass into pipeline or step to set it into a model context.

    name: The name of the model.
    license: The license under which the model is created.
    description: The description of the model.
    audience: The target audience of the model.
    use_cases: The use cases of the model.
    limitations: The known limitations of the model.
    trade_offs: The tradeoffs of the model.
    ethics: The ethical implications of the model.
    tags: Tags associated with the model.
    version: The model version name, number or stage is optional and points model context
        to a specific version/stage. If skipped and `create_new_model_version` is False -
        latest model version will be used. If skipped and `create_new_model_version`
        is True, it is set to `RUNNING_MODEL_VERSION`.
    version_description: The description of the model version.
    create_new_model_version: Whether to create a new model version during execution
    save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
        if available in active stack.
    delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it.
    suppress_class_validation_warnings: Whether to suppress the informational
        logs emitted by the validator. The validator always sets it to True
        after running.
    """

    name: str
    license: Optional[str]
    description: Optional[str]
    audience: Optional[str]
    use_cases: Optional[str]
    limitations: Optional[str]
    trade_offs: Optional[str]
    ethics: Optional[str]
    tags: Optional[List[str]]
    version: Optional[Union[ModelStages, int, str]]
    version_description: Optional[str]
    create_new_model_version: bool = False
    save_models_to_registry: bool = True
    delete_new_version_on_failure: bool = True

    suppress_class_validation_warnings: bool = False

    class Config:
        """Config class."""

        smart_union = True

    @root_validator(pre=True)
    def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        When a new model version is requested, the validator rejects `version`
        values that are stages or numeric, because those would resolve to
        existing versions. It defaults a missing `version` to
        `RUNNING_MODEL_VERSION`. It also logs how stage-like or numeric
        versions will be resolved.

        Args:
            values: Dict of values.

        Returns:
            Dict of validated values.

        Raises:
            ValueError: If validation failed on one of the checks.
        """
        create_new_model_version = values.get(
            "create_new_model_version", False
        )
        suppress_class_validation_warnings = values.get(
            "suppress_class_validation_warnings", False
        )
        version = values.get("version", None)
        stage_values = [stage.value for stage in ModelStages]

        if create_new_model_version:
            misuse_message = (
                "`version` set to {set} cannot be used with `create_new_model_version`. "
                "You can leave it default or set to a non-stage and non-numeric string.\n"
                "Examples:\n"
                " - `version` set to 1 or '1' is interpreted as a version number\n"
                " - `version` set to 'production' is interpreted as a stage\n"
                " - `version` set to 'my_first_version_in_2023' is a valid version to be created\n"
                " - `version` set to 'My Second Version!' is a valid version to be created\n"
            )
            if isinstance(version, ModelStages) or version in stage_values:
                raise ValueError(
                    misuse_message.format(set="a `ModelStages` instance")
                )
            if str(version).isnumeric():
                raise ValueError(misuse_message.format(set="a numeric value"))
            if version is None:
                if not suppress_class_validation_warnings:
                    logger.info(
                        "Creation of new model version was requested, but no version name was explicitly provided. "
                        f"Setting `version` to `{RUNNING_MODEL_VERSION}`."
                    )
                values["version"] = RUNNING_MODEL_VERSION
        if version in stage_values and not suppress_class_validation_warnings:
            logger.info(
                f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage."
            )
        if str(version).isnumeric() and not suppress_class_validation_warnings:
            logger.info(
                f"`version` `{version}` is numeric and will be fetched using version number."
            )
        # Mark as validated so any later validation of these values skips
        # the informational logs above.
        values["suppress_class_validation_warnings"] = True
        return values

    def _validate_config_in_runtime(self) -> None:
        """Validate that config doesn't conflict with runtime environment.

        If the configured model version cannot be found (KeyError),
        `get_or_create_model_version` is called instead. That call creates
        the version when `create_new_model_version` is set.

        Raises:
            RuntimeError: If recovery not requested, but model version already exists.
            RuntimeError: If there is unfinished pipeline run for requested new model version.
        """
        # Keep the try minimal: only the lookup signals "missing" via KeyError.
        try:
            model_version = self._get_model_version()
        except KeyError:
            self.get_or_create_model_version()
            return

        if not self.create_new_model_version:
            return

        for run_name, run in model_version.pipeline_runs.items():
            if run.status == ExecutionStatus.RUNNING:
                raise RuntimeError(
                    f"New model version was requested, but pipeline run `{run_name}` "
                    f"is still running with version `{model_version.name}`."
                )

        # An existing version may only be reused (recovered) when the user
        # opted out of deleting new versions on failure.
        if self.delete_new_version_on_failure:
            raise RuntimeError(
                f"Cannot create version `{self.version}` "
                f"for model `{self.name}` since it already exists"
            )

    def get_or_create_model(self) -> "ModelResponseModel":
        """This method should get or create a model from Model Control Plane.

        New model is created implicitly, if missing, otherwise fetched.

        Returns:
            The model based on configuration.
        """
        from zenml.client import Client
        from zenml.models.model_models import ModelRequestModel

        zenml_client = Client()
        try:
            model = zenml_client.get_model(model_name_or_id=self.name)
        except KeyError:
            model_request = ModelRequestModel(
                name=self.name,
                license=self.license,
                description=self.description,
                audience=self.audience,
                use_cases=self.use_cases,
                limitations=self.limitations,
                trade_offs=self.trade_offs,
                ethics=self.ethics,
                tags=self.tags,
                user=zenml_client.active_user.id,
                workspace=zenml_client.active_workspace.id,
            )
            model_request = ModelRequestModel.parse_obj(model_request)
            try:
                model = zenml_client.create_model(model=model_request)
                logger.info(f"New model `{self.name}` was created implicitly.")
            except EntityExistsError:
                # this is backup logic, if model was created somehow in between get and create calls
                model = zenml_client.get_model(model_name_or_id=self.name)

        return model

    def _create_model_version(
        self, model: "ModelResponseModel"
    ) -> "ModelVersionResponseModel":
        """This method gets or creates a model version for Model Control Plane.

        A version named `self.version` that already exists (e.g. one left over
        from a previous failed run) is reused. Otherwise a new one is created.

        Args:
            model: The model containing the model version.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client
        from zenml.models.model_models import ModelVersionRequestModel

        zenml_client = Client()
        try:
            return zenml_client.get_model_version(
                model_name_or_id=self.name,
                model_version_name_or_number_or_id=self.version,
            )
        except KeyError:
            pass

        # Only build the request (user/workspace lookups) when creation is needed.
        model_version_request = ModelVersionRequestModel(
            user=zenml_client.active_user.id,
            workspace=zenml_client.active_workspace.id,
            name=self.version,
            description=self.version_description,
            model=model.id,
        )
        mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
        try:
            model_version = zenml_client.create_model_version(
                model_version=mv_request
            )
        except EntityExistsError:
            # Backup logic, mirroring get_or_create_model: the version was
            # created in between the get and create calls.
            return zenml_client.get_model_version(
                model_name_or_id=self.name,
                model_version_name_or_number_or_id=self.version,
            )
        logger.info(f"New model version `{self.version}` was created.")
        return model_version

    def _get_model_version(self) -> "ModelVersionResponseModel":
        """This method gets a model version from Model Control Plane.

        The client lookup raises if no matching version exists; callers in
        this class catch KeyError for that case.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client

        zenml_client = Client()
        if self.version is None:
            # raise if not found
            model_version = zenml_client.get_model_version(
                model_name_or_id=self.name
            )
        else:
            # by version name or stage or number
            # raise if not found
            model_version = zenml_client.get_model_version(
                model_name_or_id=self.name,
                model_version_name_or_number_or_id=self.version,
            )
        return model_version

    def get_or_create_model_version(self) -> "ModelVersionResponseModel":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model is fetched. Model
        name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model configuration:
        - If `create_new_model_version` is True and a model version named `version` already
        exists (e.g. left over from a previous failed run), it is reused.
        - Otherwise if `create_new_model_version` is True, a new model version is created.
        - If `create_new_model_version` is False a model version will be fetched based on the version:
            - If `version` is not set, the latest model version will be fetched.
            - If `version` is set to a string or number, the matching model version will be fetched.
            - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

        Returns:
            The model version based on configuration.
        """
        model = self.get_or_create_model()
        if self.create_new_model_version:
            mv = self._create_model_version(model)
        else:
            mv = self._get_model_version()
        return mv

    def _merge(self, model_config: "ModelConfig") -> None:
        """Fill unset metadata from another config and merge their settings.

        Text fields of this config that are falsy are taken from
        `model_config`. Tags are unioned, keeping this config's tags first.
        `delete_new_version_on_failure` stays True only if both configs
        request it.

        Args:
            model_config: The config whose values fill the gaps in this one.
        """
        self.license = self.license or model_config.license
        self.description = self.description or model_config.description
        self.audience = self.audience or model_config.audience
        self.use_cases = self.use_cases or model_config.use_cases
        self.limitations = self.limitations or model_config.limitations
        self.trade_offs = self.trade_offs or model_config.trade_offs
        self.ethics = self.ethics or model_config.ethics
        if model_config.tags is not None:
            # dict.fromkeys de-duplicates while keeping a stable order,
            # unlike a set whose iteration order varies between processes.
            self.tags = list(
                dict.fromkeys([*(self.tags or []), *model_config.tags])
            )

        self.delete_new_version_on_failure &= (
            model_config.delete_new_version_on_failure
        )
Config

Config class.

Source code in zenml/model/model_config.py
class Config:
    """Pydantic settings for ModelConfig.

    With `smart_union` enabled, pydantic prefers an exact type match among
    the members of a Union field before falling back to coercion.
    """

    smart_union: bool = True
get_or_create_model(self)

This method should get or create a model from Model Control Plane.

New model is created implicitly, if missing, otherwise fetched.

Returns:

| Type | Description |
|------|-------------|
| `ModelResponseModel` | The model based on configuration. |

Source code in zenml/model/model_config.py
def get_or_create_model(self) -> "ModelResponseModel":
    """Return the Model Control Plane model named ``self.name``.

    The model is looked up first. If the lookup raises ``KeyError``, a new
    model is registered from this config's metadata. If another caller
    registers the same name in between (``EntityExistsError``), the stored
    model is fetched instead.

    Returns:
        The model based on configuration.
    """
    from zenml.client import Client
    from zenml.models.model_models import ModelRequestModel

    client = Client()
    try:
        result = client.get_model(model_name_or_id=self.name)
    except KeyError:
        request_fields = dict(
            name=self.name,
            license=self.license,
            description=self.description,
            audience=self.audience,
            use_cases=self.use_cases,
            limitations=self.limitations,
            trade_offs=self.trade_offs,
            ethics=self.ethics,
            tags=self.tags,
            user=client.active_user.id,
            workspace=client.active_workspace.id,
        )
        request = ModelRequestModel.parse_obj(
            ModelRequestModel(**request_fields)
        )
        try:
            result = client.create_model(model=request)
        except EntityExistsError:
            # Lost a race: the model appeared between lookup and creation.
            result = client.get_model(model_name_or_id=self.name)
        else:
            logger.info(f"New model `{self.name}` was created implicitly.")

    return result
get_or_create_model_version(self)

This method should get or create a model and a model version from Model Control Plane.

A new model is created implicitly if missing, otherwise existing model is fetched. Model name is controlled by the name parameter.

The model version returned by this method is resolved based on the model configuration:

- If there is an existing model version left over from a previous failed run with `delete_new_version_on_failure` set to False and `create_new_model_version` set to True, the leftover model version is reused.
- Otherwise, if `create_new_model_version` is True, a new model version is created.
- If `create_new_model_version` is False, a model version is fetched based on `version`:
    - If `version` is not set, the latest model version is fetched.
    - If `version` is set to a string, the model version with the matching version is fetched.
    - If `version` is set to a `ModelStage`, the model version with the matching stage is fetched.

Returns:

| Type | Description |
|------|-------------|
| `ModelVersionResponseModel` | The model version based on configuration. |

Source code in zenml/model/model_config.py
def get_or_create_model_version(self) -> "ModelVersionResponseModel":
    """Resolve the model version this config points at, creating it if asked.

    The model itself (named by `name`) is always fetched or implicitly
    created first. After that:

    - When `create_new_model_version` is True, `_create_model_version` is
      used. It reuses an existing version with the configured name (e.g. a
      leftover from a failed run) or creates a new one.
    - Otherwise `_get_model_version` fetches an existing version: the latest
      one when `version` is unset, or the one matching `version` by name,
      number or `ModelStage`.

    Returns:
        The model version based on configuration.
    """
    parent_model = self.get_or_create_model()
    return (
        self._create_model_version(parent_model)
        if self.create_new_model_version
        else self._get_model_version()
    )