Model

zenml.model

Concepts related to the Model Control Plane feature.

lazy_load

Model Version Data Lazy Loader definition.

ModelVersionDataLazyLoader (BaseModel)

Model Version Data Lazy Loader helper class.

It helps internal code fetch the correct artifact, model version metadata, or artifact metadata from the model version while the step is running.

Source code in zenml/model/lazy_load.py
class ModelVersionDataLazyLoader(BaseModel):
    """Model Version Data Lazy Loader helper class.

    It helps internal code fetch the correct artifact,
    model version metadata or artifact metadata from the
    model version while the step is running.
    """

    model: Model
    artifact_name: Optional[str] = None
    artifact_version: Optional[str] = None
    metadata_name: Optional[str] = None
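
The sketch below illustrates one plausible way the optional fields of such a loader could select what gets fetched at step runtime. It is a simplified stand-in that does not import ZenML: `LazyRef` and `describe_target` are hypothetical names, and the rule for combining `artifact_name` and `metadata_name` is an assumption, not something the source above confirms.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class LazyRef:
    """Hypothetical stand-in for ModelVersionDataLazyLoader."""

    model_name: str
    artifact_name: Optional[str] = None
    artifact_version: Optional[str] = None
    metadata_name: Optional[str] = None


def describe_target(ref: LazyRef) -> str:
    """Say which kind of data the reference points at (assumed rule)."""
    if ref.artifact_name is None:
        # No artifact named: the lookup targets the model version itself.
        return "model version metadata"
    if ref.metadata_name is not None:
        # Artifact plus metadata key: the lookup targets artifact metadata.
        return "artifact metadata"
    # Artifact only: the lookup targets the artifact.
    return "artifact"
```

The point of the pattern is that the reference is cheap to build at pipeline design time and is only resolved against the Model Control Plane once the step actually runs.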

model

Model user facing interface to pass into pipeline or step.

Model (BaseModel)

Model class to pass into pipeline or step to set it into a model context.

name: The name of the model.
license: The license under which the model is created.
description: The description of the model.
audience: The target audience of the model.
use_cases: The use cases of the model.
limitations: The known limitations of the model.
trade_offs: The tradeoffs of the model.
ethics: The ethical implications of the model.
tags: Tags associated with the model.
version: The version name, version number, or stage (optional). Points the model context to a specific version or stage. If omitted, a new version is created.
save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if one is available in the active stack.
model_version_id: The ID of a specific Model Version. If given, it overrides the name and version settings. Used mostly internally.
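
As a rough illustration of how a `version` value might be interpreted, the sketch below classifies it as a stage, a version number, or a version name. It is a simplified stand-in that does not import ZenML: the `Stage` enum and `classify_version` function are hypothetical names, only two stage values are included for brevity, and treating only `None` as "no version" is a simplification.

```python
from enum import Enum
from typing import Optional, Union


class Stage(str, Enum):
    """Hypothetical stand-in for ZenML's ModelStages enum."""

    STAGING = "staging"
    PRODUCTION = "production"


def classify_version(version: Optional[Union[Stage, int, str]]) -> str:
    """Return how a given `version` value would be looked up (assumed rule)."""
    if version is None:
        # Nothing requested: a new model version would be created.
        return "create new version"
    if isinstance(version, Stage):
        version = version.value
    text = str(version)
    if text.lower() in {stage.value for stage in Stage}:
        # Stage names are reserved and resolve to the version in that stage.
        return "fetch by stage"
    if text.isnumeric():
        # Integers and digit strings are reserved for version numbers.
        return "fetch by number"
    return "fetch by name"
```

For example, `classify_version("12")` and `classify_version(12)` both fall into the "fetch by number" branch, which is why numeric strings cannot be used as new version names.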

Source code in zenml/model/model.py
class Model(BaseModel):
    """Model class to pass into pipeline or step to set it into a model context.

    name: The name of the model.
    license: The license under which the model is created.
    description: The description of the model.
    audience: The target audience of the model.
    use_cases: The use cases of the model.
    limitations: The known limitations of the model.
    trade_offs: The tradeoffs of the model.
    ethics: The ethical implications of the model.
    tags: Tags associated with the model.
    version: The version name, version number or stage. Optional; points the model
        context to a specific version/stage. If skipped, a new version will be created.
    save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
        if available in active stack.
    model_version_id: The ID of a specific Model Version, if given - it will override
        `name` and `version` settings. Used mostly internally.
    """

    name: str
    license: Optional[str] = None
    description: Optional[str] = None
    audience: Optional[str] = None
    use_cases: Optional[str] = None
    limitations: Optional[str] = None
    trade_offs: Optional[str] = None
    ethics: Optional[str] = None
    tags: Optional[List[str]] = None
    version: Optional[Union[ModelStages, int, str]] = Field(
        default=None, union_mode="smart"
    )
    save_models_to_registry: bool = True
    model_version_id: Optional[UUID] = None

    suppress_class_validation_warnings: bool = False
    was_created_in_this_run: bool = False

    _model_id: UUID = PrivateAttr(None)
    _number: int = PrivateAttr(None)

    # TODO: In Pydantic v2, `model_` is a protected namespace for all
    #  fields defined under base models. If not handled, this raises a warning.
    #  It is possible to suppress this warning message with the following
    #  configuration, however the ultimate solution is to rename these fields.
    #  Even though they do not cause any problems right now, if we are not
    #  careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())

    #########################
    #    Public methods     #
    #########################
    @property
    def id(self) -> UUID:
        """Get version id from the Model Control Plane.

        Returns:
            ID of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).

        Raises:
            RuntimeError: if model version doesn't exist and
                cannot be fetched from the Model Control Plane.
        """
        if self.model_version_id is None:
            try:
                mv = self._get_or_create_model_version()
                self.model_version_id = mv.id
            except RuntimeError as e:
                raise RuntimeError(
                    f"Version `{self.version}` of `{self.name}` model doesn't "
                    "exist and cannot be fetched from the Model Control Plane."
                ) from e
        return self.model_version_id

    @property
    def model_id(self) -> UUID:
        """Get model id from the Model Control Plane.

        Returns:
            The UUID of the model containing this model version.
        """
        if self._model_id is None:
            self._get_or_create_model()
        return self._model_id

    @property
    def number(self) -> int:
        """Get version number from the Model Control Plane.

        Returns:
            Number of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).
        """
        if self._number is None:
            try:
                self._get_or_create_model_version()
            except RuntimeError:
                logger.info(
                    f"Version `{self.version}` of `{self.name}` model doesn't "
                    "exist and cannot be fetched from the Model Control Plane."
                )
        return self._number

    @property
    def stage(self) -> Optional[ModelStages]:
        """Get version stage from the Model Control Plane.

        Returns:
            Stage of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).
        """
        try:
            stage = self._get_or_create_model_version().stage
            if stage:
                return ModelStages(stage)
        except RuntimeError:
            logger.info(
                f"Version `{self.version}` of `{self.name}` model doesn't "
                "exist and cannot be fetched from the Model Control Plane."
            )
        return None

    def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
        """Load artifact from the Model Control Plane.

        Args:
            name: Name of the artifact to load.
            version: Version of the artifact to load.

        Returns:
            The loaded artifact.

        Raises:
            ValueError: if the model version is not linked to any artifact with
                the given name and version.
        """
        from zenml.artifacts.utils import load_artifact
        from zenml.models import ArtifactVersionResponse

        artifact = self.get_artifact(name=name, version=version)

        if not isinstance(artifact, ArtifactVersionResponse):
            raise ValueError(
                f"Version {self.version} of model {self.name} does not have "
                f"an artifact with name {name} and version {version}."
            )

        return load_artifact(artifact.id, str(artifact.version))

    def get_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the artifact linked to this model version.

        Args:
            name: The name of the artifact to retrieve.
            version: The version of the artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the artifact or placeholder in the design time
                of the pipeline.
        """
        if lazy := self._lazy_artifact_get(name, version):
            return lazy

        return self._get_or_create_model_version().get_artifact(
            name=name,
            version=version,
        )

    def get_model_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the model artifact linked to this model version.

        Args:
            name: The name of the model artifact to retrieve.
            version: The version of the model artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the model artifact or placeholder in the design
                time of the pipeline.
        """
        if lazy := self._lazy_artifact_get(name, version):
            return lazy

        return self._get_or_create_model_version().get_model_artifact(
            name=name,
            version=version,
        )

    def get_data_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the data artifact linked to this model version.

        Args:
            name: The name of the data artifact to retrieve.
            version: The version of the data artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the data artifact or placeholder in the design
            time of the pipeline.
        """
        if lazy := self._lazy_artifact_get(name, version):
            return lazy

        return self._get_or_create_model_version().get_data_artifact(
            name=name,
            version=version,
        )

    def get_deployment_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the deployment artifact linked to this model version.

        Args:
            name: The name of the deployment artifact to retrieve.
            version: The version of the deployment artifact to retrieve (None
                for latest/non-versioned)

        Returns:
            Specific version of the deployment artifact or placeholder in the
            design time of the pipeline.
        """
        if lazy := self._lazy_artifact_get(name, version):
            return lazy

        return self._get_or_create_model_version().get_deployment_artifact(
            name=name,
            version=version,
        )

    def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
        """Get pipeline run linked to this version.

        Args:
            name: The name of the pipeline run to retrieve.

        Returns:
            PipelineRun as PipelineRunResponse
        """
        return self._get_or_create_model_version().get_pipeline_run(name=name)

    def set_stage(
        self, stage: Union[str, ModelStages], force: bool = False
    ) -> None:
        """Sets this Model to a desired stage.

        Args:
            stage: the target stage for model version.
            force: whether to force archiving of current model version in
                target stage or raise.
        """
        self._get_or_create_model_version().set_stage(stage=stage, force=force)

    def log_metadata(
        self,
        metadata: Dict[str, "MetadataType"],
    ) -> None:
        """Log model version metadata.

        This function can be used to log metadata for current model version.

        Args:
            metadata: The metadata to log.
        """
        from zenml.client import Client

        response = self._get_or_create_model_version()
        Client().create_run_metadata(
            metadata=metadata,
            resource_id=response.id,
            resource_type=MetadataResourceTypes.MODEL_VERSION,
        )

    @property
    def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
        """Get model version run metadata.

        Returns:
            The model version run metadata.

        Raises:
            RuntimeError: If the model version run metadata cannot be fetched.
        """
        from zenml.metadata.lazy_load import RunMetadataLazyGetter
        from zenml.new.pipelines.pipeline_context import (
            get_pipeline_context,
        )

        try:
            get_pipeline_context()
            # avoid exposing too much of internal details by keeping the return type
            return RunMetadataLazyGetter(  # type: ignore[return-value]
                self,
                None,
                None,
            )
        except RuntimeError:
            pass

        response = self._get_or_create_model_version(hydrate=True)
        if response.run_metadata is None:
            raise RuntimeError(
                "Failed to fetch metadata of this model version."
            )
        return response.run_metadata

    # TODO: deprecate me
    @property
    def metadata(self) -> Dict[str, "MetadataType"]:
        """DEPRECATED, use `run_metadata` instead.

        Returns:
            The model version run metadata.
        """
        logger.warning(
            "Model `metadata` property is deprecated. Please use "
            "`run_metadata` instead."
        )
        return {k: v.value for k, v in self.run_metadata.items()}

    def delete_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        only_link: bool = True,
        delete_metadata: bool = True,
        delete_from_artifact_store: bool = False,
    ) -> None:
        """Delete the artifact linked to this model version.

        Args:
            name: The name of the artifact to delete.
            version: The version of the artifact to delete (None for
                latest/non-versioned)
            only_link: Whether to only delete the link to the artifact.
            delete_metadata: Whether to delete the metadata of the artifact.
            delete_from_artifact_store: Whether to delete the artifact from the
                artifact store.
        """
        from zenml.client import Client
        from zenml.models import ArtifactVersionResponse

        artifact_version = self.get_artifact(name, version)
        if isinstance(artifact_version, ArtifactVersionResponse):
            client = Client()
            client.delete_model_version_artifact_link(
                model_version_id=self.id,
                artifact_version_id=artifact_version.id,
            )
            if not only_link:
                client.delete_artifact_version(
                    name_id_or_prefix=artifact_version.id,
                    delete_metadata=delete_metadata,
                    delete_from_artifact_store=delete_from_artifact_store,
                )

    def delete_all_artifacts(
        self,
        only_link: bool = True,
        delete_from_artifact_store: bool = False,
    ) -> None:
        """Delete all artifacts linked to this model version.

        Args:
            only_link: Whether to only delete the link to the artifact.
            delete_from_artifact_store: Whether to delete the artifact from
                the artifact store.
        """
        from zenml.client import Client

        client = Client()

        if not only_link and delete_from_artifact_store:
            mv = self._get_model_version()
            artifact_responses = mv.data_artifacts
            artifact_responses.update(mv.model_artifacts)
            artifact_responses.update(mv.deployment_artifacts)

            for artifact_ in artifact_responses.values():
                for artifact_response_ in artifact_.values():
                    client._delete_artifact_from_artifact_store(
                        artifact_version=artifact_response_
                    )

        client.delete_all_model_version_artifact_links(self.id, only_link)

    def _lazy_artifact_get(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        from zenml import get_pipeline_context
        from zenml.models.v2.core.artifact_version import (
            LazyArtifactVersionResponse,
        )

        try:
            get_pipeline_context()
            return LazyArtifactVersionResponse(
                lazy_load_name=name,
                lazy_load_version=version,
                lazy_load_model=Model(
                    name=self.name, version=self.version or self.number
                ),
            )
        except RuntimeError:
            pass

        return None

    def __eq__(self, other: object) -> bool:
        """Check two Models for equality.

        Args:
            other: object to compare with

        Returns:
            True, if equal, False otherwise.
        """
        if not isinstance(other, Model):
            return NotImplemented
        if self.name != other.name:
            return False
        if self.name == other.name and self.version == other.version:
            return True
        self_mv = self._get_or_create_model_version()
        other_mv = other._get_or_create_model_version()
        return self_mv.id == other_mv.id

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        Args:
            data: Dict of values.

        Returns:
            Dict of validated values.
        """
        suppress_class_validation_warnings = (
            data.get(
                "suppress_class_validation_warnings",
                False,
            )
            or data.get("model_version_id", None) is not None
        )
        version = data.get("version", None)

        if (
            version in [stage.value for stage in ModelStages]
            and not suppress_class_validation_warnings
        ):
            logger.info(
                f"Version `{version}` matches one of the possible "
                "`ModelStages` and will be fetched using stage."
            )
        if str(version).isnumeric() and not suppress_class_validation_warnings:
            logger.info(
                f"`version` `{version}` is numeric and will be fetched "
                "using version number."
            )
        data["suppress_class_validation_warnings"] = True
        return data

    def _validate_config_in_runtime(self) -> "ModelVersionResponse":
        """Validate that config doesn't conflict with runtime environment.

        Returns:
            The model version based on configuration.
        """
        return self._get_or_create_model_version()

    def _get_or_create_model(self) -> "ModelResponse":
        """This method should get or create a model from Model Control Plane.

        New model is created implicitly, if missing, otherwise fetched.

        Returns:
            The model based on configuration.
        """
        from zenml.client import Client
        from zenml.models import ModelRequest

        zenml_client = Client()
        if self.model_version_id:
            mv = zenml_client.get_model_version(
                model_version_name_or_number_or_id=self.model_version_id,
            )
            model = mv.model
        else:
            try:
                model = zenml_client.zen_store.get_model(
                    model_name_or_id=self.name
                )
            except KeyError:
                model_request = ModelRequest(
                    name=self.name,
                    license=self.license,
                    description=self.description,
                    audience=self.audience,
                    use_cases=self.use_cases,
                    limitations=self.limitations,
                    trade_offs=self.trade_offs,
                    ethics=self.ethics,
                    tags=self.tags,
                    user=zenml_client.active_user.id,
                    workspace=zenml_client.active_workspace.id,
                    save_models_to_registry=self.save_models_to_registry,
                )
                model_request = ModelRequest.model_validate(model_request)
                try:
                    model = zenml_client.zen_store.create_model(
                        model=model_request
                    )
                    logger.info(
                        f"New model `{self.name}` was created implicitly."
                    )
                except EntityExistsError:
                    model = zenml_client.zen_store.get_model(
                        model_name_or_id=self.name
                    )

        self._model_id = model.id
        return model

    def _get_model_version(
        self, hydrate: bool = True
    ) -> "ModelVersionResponse":
        """This method gets a model version from Model Control Plane.

        Args:
            hydrate: Flag deciding whether to hydrate the output model(s)
                by including metadata fields in the response.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client

        zenml_client = Client()
        if self.model_version_id:
            mv = zenml_client.get_model_version(
                model_version_name_or_number_or_id=self.model_version_id,
                hydrate=hydrate,
            )
        else:
            mv = zenml_client.get_model_version(
                model_name_or_id=self.name,
                model_version_name_or_number_or_id=self.version,
                hydrate=hydrate,
            )
            self.model_version_id = mv.id

        difference: Dict[str, Any] = {}
        if mv.metadata:
            if self.description and mv.description != self.description:
                difference["description"] = {
                    "config": self.description,
                    "db": mv.description,
                }
        if self.tags:
            configured_tags = set(self.tags)
            db_tags = {t.name for t in mv.tags}
            if db_tags != configured_tags:
                difference["tags added"] = list(configured_tags - db_tags)
                difference["tags removed"] = list(db_tags - configured_tags)
        if difference:
            logger.warning(
                "Provided model version configuration does not match existing model "
                f"version `{self.name}::{self.version}` with the following "
                f"changes: {difference}. If you want to update the model version "
                "configuration, please use the `zenml model version update` command."
            )

        return mv

    def _get_or_create_model_version(
        self, hydrate: bool = False
    ) -> "ModelVersionResponse":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model
        is fetched. Model name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model version:
        - If `version` is None, a new model version is created, if not created
            by other steps in same run.
        - If `version` is not None a model version will be fetched based on the
            version:
            - If `version` is set to an integer or digit string, the model
                version with the matching number will be fetched.
            - If `version` is set to a string, the model version with the
                matching version will be fetched.
            - If `version` is set to a `ModelStages` value, the model version with the
                matching stage will be fetched.

        Args:
            hydrate: Whether to return a hydrated version of the model version.

        Returns:
            The model version based on configuration.

        Raises:
            RuntimeError: if the model version needs to be created, but
                provided name is reserved.
            RuntimeError: if the model version cannot be created.
        """
        from zenml.client import Client
        from zenml.models import ModelVersionRequest

        model = self._get_or_create_model()

        # backup logic, if the Model class is used directly from the code
        if isinstance(self.version, str):
            self.version = format_name_template(self.version)

        zenml_client = Client()
        model_version_request = ModelVersionRequest(
            user=zenml_client.active_user.id,
            workspace=zenml_client.active_workspace.id,
            name=str(self.version) if self.version else None,
            description=self.description,
            model=model.id,
            tags=self.tags,
        )
        mv_request = ModelVersionRequest.model_validate(model_version_request)
        try:
            if not self.version:
                try:
                    from zenml import get_step_context

                    context = get_step_context()
                except RuntimeError:
                    pass
                else:
                    # if inside a step context, loop over all model
                    # version configurations to check whether the
                    # model version for the current model was already
                    # created in this run, to avoid creating a new
                    # model version
                    pipeline_mv = context.pipeline_run.config.model
                    if (
                        pipeline_mv
                        and pipeline_mv.was_created_in_this_run
                        and pipeline_mv.name == self.name
                        and pipeline_mv.version is not None
                    ):
                        self.version = pipeline_mv.version
                        self.model_version_id = pipeline_mv.model_version_id
                    else:
                        for step in context.pipeline_run.steps.values():
                            step_mv = step.config.model
                            if (
                                step_mv
                                and step_mv.was_created_in_this_run
                                and step_mv.name == self.name
                                and step_mv.version is not None
                            ):
                                self.version = step_mv.version
                                self.model_version_id = (
                                    step_mv.model_version_id
                                )
                                break
            if self.version or self.model_version_id:
                model_version = self._get_model_version()
            else:
                raise KeyError
        except KeyError:
            if (
                self.version
                and str(self.version).lower() in ModelStages.values()
            ):
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "it matches one of the possible model version stages. If you "
                    "are aiming to fetch model version by stage, check if the "
                    "model version in given stage exists. It might be missing, if "
                    "the pipeline promoting model version to this stage failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list -n {self.name}` CLI command."
                )
            if str(self.version).isnumeric():
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "numeric model version names are reserved. If you "
                    "are aiming to fetch model version by number, check if the "
                    "model version with given number exists. It might be missing, if "
                    "the pipeline creating model version failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list -n {self.name}` CLI command."
                )
            retries_made = 0
            for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
                try:
                    model_version = (
                        zenml_client.zen_store.create_model_version(
                            model_version=mv_request
                        )
                    )
                    break
                except EntityExistsError as e:
                    if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
                        raise RuntimeError(
                            f"Failed to create model version "
                            f"`{self.version if self.version else 'new'}` "
                            f"in model `{self.name}`. Retried {retries_made} times. "
                            "This could be driven by exceptionally high concurrency of "
                            "pipeline runs. Please, reach out to us on ZenML Slack for support."
                        ) from e
                    # smoothed exponential back-off, it will go as 0.2, 0.3,
                    # 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
                    sleep = 0.2 * 1.5**i
                    logger.debug(
                        f"Failed to create new model version for "
                        f"model `{self.name}`. Retrying in {sleep}..."
                    )
                    time.sleep(sleep)
                    retries_made += 1
            self.version = model_version.name
            self.was_created_in_this_run = True

            logger.info(f"New model version `{self.version}` was created.")

        self.model_version_id = model_version.id
        self._model_id = model_version.model.id
        self._number = model_version.number
        return model_version

    def _merge(self, model: "Model") -> None:
        self.license = self.license or model.license
        self.description = self.description or model.description
        self.audience = self.audience or model.audience
        self.use_cases = self.use_cases or model.use_cases
        self.limitations = self.limitations or model.limitations
        self.trade_offs = self.trade_offs or model.trade_offs
        self.ethics = self.ethics or model.ethics
        if model.tags is not None:
            self.tags = list(
                {t for t in self.tags or []}.union(set(model.tags))
            )

    def __hash__(self) -> int:
        """Get hash of the `Model`.

        Returns:
            Hash function results
        """
        return hash(
            "::".join(
                (
                    str(v)
                    for v in (
                        self.name,
                        self.version,
                    )
                )
            )
        )

    def _prepare_model_version_before_step_launch(
        self,
        pipeline_run: "PipelineRunResponse",
        step_run: Optional["StepRunResponse"],
        return_logs: bool,
    ) -> str:
        """Prepares model version inside pipeline run.

        Args:
            pipeline_run: pipeline run
            step_run: step run (passed only if model version is defined in a step explicitly)
            return_logs: whether to return logs or not

        Returns:
            Logs related to the Dashboard URL to show later.
        """
        from zenml.client import Client
        from zenml.models import PipelineRunUpdate, StepRunUpdate

        logs = ""

        # copy Model instance to prevent corrupting configs of the
        # subsequent runs, if they share the same config object
        self_copy = self.model_copy()

        # if the step defines no model configuration of its own,
        # try to reuse what is already set on the pipeline run first
        if step_run is None and pipeline_run.model_version is not None:
            self_copy.version = pipeline_run.model_version.name
            self_copy.model_version_id = pipeline_run.model_version.id
        # otherwise try to fill the templated name, if needed
        elif isinstance(self_copy.version, str):
            if pipeline_run.start_time:
                start_time = pipeline_run.start_time
            else:
                start_time = datetime.datetime.now(datetime.timezone.utc)
            self_copy.version = format_name_template(
                self_copy.version,
                date=start_time.strftime("%Y_%m_%d"),
                time=start_time.strftime("%H_%M_%S_%f"),
            )

        # if exact model not yet defined - try to get/create and update it
        # back to the run accordingly
        if self_copy.model_version_id is None:
            model_version_response = self_copy._get_or_create_model_version()

            # update the configured model version id in runs accordingly
            if step_run:
                Client().zen_store.update_run_step(
                    step_run_id=step_run.id,
                    step_run_update=StepRunUpdate(
                        model_version_id=model_version_response.id
                    ),
                )
            else:
                Client().zen_store.update_run(
                    run_id=pipeline_run.id,
                    run_update=PipelineRunUpdate(
                        model_version_id=model_version_response.id
                    ),
                )

            if return_logs:
                from zenml.utils.cloud_utils import try_get_model_version_url

                if logs_to_show := try_get_model_version_url(
                    model_version_response
                ):
                    logs = logs_to_show
                else:
                    logs = (
                        "Models can be viewed in the dashboard using ZenML Pro. Sign up "
                        "for a free trial at https://www.zenml.io/pro/"
                    )
        self.model_version_id = self_copy.model_version_id
        return logs
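
The version-name templating step in `_prepare_model_version_before_step_launch` can be illustrated in isolation. The sketch below is a minimal stand-in, assuming the template uses `{date}` and `{time}` placeholders filled via `str.format`. `fill_version_template` is a hypothetical helper, not ZenML's `format_name_template`; only the two `strftime` patterns and the UTC fallback are taken from the source above.

```python
import datetime
from typing import Optional


def fill_version_template(
    template: str, start_time: Optional[datetime.datetime] = None
) -> str:
    """Hypothetical stand-in for the templating step (not the ZenML helper)."""
    # Fall back to the current UTC time when the run has no start time,
    # mirroring the branch in the source above.
    if start_time is None:
        start_time = datetime.datetime.now(datetime.timezone.utc)
    return template.format(
        date=start_time.strftime("%Y_%m_%d"),
        time=start_time.strftime("%H_%M_%S_%f"),
    )


fixed = datetime.datetime(2024, 1, 2, 3, 4, 5, 6, tzinfo=datetime.timezone.utc)
print(fill_version_template("nightly_{date}_{time}", fixed))
# prints: nightly_2024_01_02_03_04_05_000006
```

A version string without placeholders passes through unchanged, so fixed version names are unaffected by this step.
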
id: UUID property readonly

Get version id from the Model Control Plane.

Returns:

Type Description
UUID

ID of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name).

Exceptions:

Type Description
RuntimeError

if model version doesn't exist and cannot be fetched from the Model Control Plane.

metadata: Dict[str, MetadataType] property readonly

DEPRECATED, use run_metadata instead.

Returns:

Type Description
Dict[str, MetadataType]

The model version run metadata.

model_id: UUID property readonly

Get model id from the Model Control Plane.

Returns:

Type Description
UUID

The UUID of the model containing this model version.

number: int property readonly

Get version number from the Model Control Plane.

Returns:

Type Description
int

Number of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name).

run_metadata: Dict[str, RunMetadataResponse] property readonly

Get model version run metadata.

Returns:

Type Description
Dict[str, RunMetadataResponse]

The model version run metadata.

Exceptions:

Type Description
RuntimeError

If the model version run metadata cannot be fetched.

stage: Optional[zenml.enums.ModelStages] property readonly

Get version stage from the Model Control Plane.

Returns:

Type Description
Optional[zenml.enums.ModelStages]

Stage of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name).

__eq__(self, other) special

Check two Models for equality.

Parameters:

Name Type Description Default
other object

object to compare with

required

Returns:

Type Description
bool

True, if equal, False otherwise.

Source code in zenml/model/model.py
def __eq__(self, other: object) -> bool:
    """Check two Models for equality.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, Model):
        return NotImplemented
    if self.name != other.name:
        return False
    if self.version == other.version:
        return True
    self_mv = self._get_or_create_model_version()
    other_mv = other._get_or_create_model_version()
    return self_mv.id == other_mv.id
__hash__(self) special

Get hash of the Model.

Returns:

Type Description
int

Hash function results

Source code in zenml/model/model.py
def __hash__(self) -> int:
    """Get hash of the `Model`.

    Returns:
        Hash function results
    """
    return hash(
        "::".join(
            (
                str(v)
                for v in (
                    self.name,
                    self.version,
                )
            )
        )
    )
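
The hashing scheme above joins the name and the version with `::` and hashes the resulting string. The following sketch reproduces that scheme on a stand-in class so its consequences can be checked without a ZenML server. `ModelKey` is hypothetical and carries only the two fields the hash uses; its `__eq__` covers only the local name/version comparison, not the server lookup the real method can fall back to.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(eq=False)
class ModelKey:
    """Hypothetical stand-in holding only the fields the hash uses."""

    name: str
    version: Optional[Union[int, str]] = None

    def __hash__(self) -> int:
        # Same scheme as in the source: hash of "name::version".
        return hash("::".join(str(v) for v in (self.name, self.version)))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ModelKey):
            return NotImplemented
        # Only the local comparison is sketched here.
        return (self.name, self.version) == (other.name, other.version)


a = ModelKey(name="classifier", version="1")
b = ModelKey(name="classifier", version="1")

# Equal name and version give equal hashes, so both collapse in a set.
assert a == b and hash(a) == hash(b)
assert len({a, b}) == 1

# Versions are stringified, so the integer 1 and the string "1" hash alike.
assert hash(ModelKey("classifier", 1)) == hash(ModelKey("classifier", "1"))
```
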
delete_all_artifacts(self, only_link=True, delete_from_artifact_store=False)

Delete all artifacts linked to this model version.

Parameters:

Name Type Description Default
only_link bool

Whether to only delete the link to the artifact.

True
delete_from_artifact_store bool

Whether to delete the artifact from the artifact store.

False
Source code in zenml/model/model.py
def delete_all_artifacts(
    self,
    only_link: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete all artifacts linked to this model version.

    Args:
        only_link: Whether to only delete the link to the artifact.
        delete_from_artifact_store: Whether to delete the artifact from
            the artifact store.
    """
    from zenml.client import Client

    client = Client()

    if not only_link and delete_from_artifact_store:
        mv = self._get_model_version()
        artifact_responses = mv.data_artifacts
        artifact_responses.update(mv.model_artifacts)
        artifact_responses.update(mv.deployment_artifacts)

        for artifact_ in artifact_responses.values():
            for artifact_response_ in artifact_.values():
                client._delete_artifact_from_artifact_store(
                    artifact_version=artifact_response_
                )

    client.delete_all_model_version_artifact_links(self.id, only_link)
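
The guard in `delete_all_artifacts` means files are removed from the artifact store only when both flags are moved away from their defaults. A minimal sketch of that condition follows; `removes_from_store` is a hypothetical helper that mirrors the `if` in the source and nothing more.

```python
def removes_from_store(
    only_link: bool = True, delete_from_artifact_store: bool = False
) -> bool:
    """Mirror of the guard: True when stored files would be deleted."""
    return not only_link and delete_from_artifact_store


# Enumerate all four flag combinations.
for only_link in (True, False):
    for from_store in (True, False):
        print(only_link, from_store, removes_from_store(only_link, from_store))
```

Only the combination `only_link=False, delete_from_artifact_store=True` yields `True`; setting `delete_from_artifact_store=True` alone leaves stored files in place.
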
delete_artifact(self, name, version=None, only_link=True, delete_metadata=True, delete_from_artifact_store=False)

Delete the artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the artifact to delete.

required
version Optional[str]

The version of the artifact to delete (None for latest/non-versioned)

None
only_link bool

Whether to only delete the link to the artifact.

True
delete_metadata bool

Whether to delete the metadata of the artifact.

True
delete_from_artifact_store bool

Whether to delete the artifact from the artifact store.

False
Source code in zenml/model/model.py
def delete_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    only_link: bool = True,
    delete_metadata: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete the artifact linked to this model version.

    Args:
        name: The name of the artifact to delete.
        version: The version of the artifact to delete (None for
            latest/non-versioned)
        only_link: Whether to only delete the link to the artifact.
        delete_metadata: Whether to delete the metadata of the artifact.
        delete_from_artifact_store: Whether to delete the artifact from the
            artifact store.
    """
    from zenml.client import Client
    from zenml.models import ArtifactVersionResponse

    artifact_version = self.get_artifact(name, version)
    if isinstance(artifact_version, ArtifactVersionResponse):
        client = Client()
        client.delete_model_version_artifact_link(
            model_version_id=self.id,
            artifact_version_id=artifact_version.id,
        )
        if not only_link:
            client.delete_artifact_version(
                name_id_or_prefix=artifact_version.id,
                delete_metadata=delete_metadata,
                delete_from_artifact_store=delete_from_artifact_store,
            )
get_artifact(self, name, version=None)

Get the artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the artifact to retrieve.

required
version Optional[str]

The version of the artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Optional[ArtifactVersionResponse]

Specific version of the artifact or placeholder in the design time of the pipeline.

Source code in zenml/model/model.py
def get_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Get the artifact linked to this model version.

    Args:
        name: The name of the artifact to retrieve.
        version: The version of the artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the artifact or placeholder in the design time
            of the pipeline.
    """
    if lazy := self._lazy_artifact_get(name, version):
        return lazy

    return self._get_or_create_model_version().get_artifact(
        name=name,
        version=version,
    )
get_data_artifact(self, name, version=None)

Get the data artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the data artifact to retrieve.

required
version Optional[str]

The version of the data artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Optional[ArtifactVersionResponse]

Specific version of the data artifact or placeholder in the design time of the pipeline.

Source code in zenml/model/model.py
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Get the data artifact linked to this model version.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the data artifact or placeholder in the design
        time of the pipeline.
    """
    if lazy := self._lazy_artifact_get(name, version):
        return lazy

    return self._get_or_create_model_version().get_data_artifact(
        name=name,
        version=version,
    )
get_deployment_artifact(self, name, version=None)

Get the deployment artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the deployment artifact to retrieve.

required
version Optional[str]

The version of the deployment artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Optional[ArtifactVersionResponse]

Specific version of the deployment artifact or placeholder in the design time of the pipeline.

Source code in zenml/model/model.py
def get_deployment_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Get the deployment artifact linked to this model version.

    Args:
        name: The name of the deployment artifact to retrieve.
        version: The version of the deployment artifact to retrieve (None
            for latest/non-versioned)

    Returns:
        Specific version of the deployment artifact or placeholder in the
        design time of the pipeline.
    """
    if lazy := self._lazy_artifact_get(name, version):
        return lazy

    return self._get_or_create_model_version().get_deployment_artifact(
        name=name,
        version=version,
    )
get_model_artifact(self, name, version=None)

Get the model artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the model artifact to retrieve.

required
version Optional[str]

The version of the model artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Optional[ArtifactVersionResponse]

Specific version of the model artifact or placeholder in the design time of the pipeline.

Source code in zenml/model/model.py
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Get the model artifact linked to this model version.

    Args:
        name: The name of the model artifact to retrieve.
        version: The version of the model artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the model artifact or placeholder in the design
            time of the pipeline.
    """
    if lazy := self._lazy_artifact_get(name, version):
        return lazy

    return self._get_or_create_model_version().get_model_artifact(
        name=name,
        version=version,
    )
get_pipeline_run(self, name)

Get pipeline run linked to this version.

Parameters:

Name Type Description Default
name str

The name of the pipeline run to retrieve.

required

Returns:

Type Description
PipelineRunResponse

PipelineRun as PipelineRunResponse

Source code in zenml/model/model.py
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Get pipeline run linked to this version.

    Args:
        name: The name of the pipeline run to retrieve.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    return self._get_or_create_model_version().get_pipeline_run(name=name)
load_artifact(self, name, version=None)

Load artifact from the Model Control Plane.

Parameters:

Name Type Description Default
name str

Name of the artifact to load.

required
version Optional[str]

Version of the artifact to load.

None

Returns:

Type Description
Any

The loaded artifact.

Exceptions:

Type Description
ValueError

if the model version is not linked to any artifact with the given name and version.

Source code in zenml/model/model.py
def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
    """Load artifact from the Model Control Plane.

    Args:
        name: Name of the artifact to load.
        version: Version of the artifact to load.

    Returns:
        The loaded artifact.

    Raises:
        ValueError: if the model version is not linked to any artifact with
            the given name and version.
    """
    from zenml.artifacts.utils import load_artifact
    from zenml.models import ArtifactVersionResponse

    artifact = self.get_artifact(name=name, version=version)

    if not isinstance(artifact, ArtifactVersionResponse):
        raise ValueError(
            f"Version {self.version} of model {self.name} does not have "
            f"an artifact with name {name} and version {version}."
        )

    return load_artifact(artifact.id, str(artifact.version))
log_metadata(self, metadata)

Log model version metadata.

This function can be used to log metadata for the current model version.

Parameters:

Name Type Description Default
metadata Dict[str, MetadataType]

The metadata to log.

required
Source code in zenml/model/model.py
def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for the current model version.

    Args:
        metadata: The metadata to log.
    """
    from zenml.client import Client

    response = self._get_or_create_model_version()
    Client().create_run_metadata(
        metadata=metadata,
        resource_id=response.id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )
model_post_init(self, __context)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Parameters:

Name Type Description Default
self BaseModel

The BaseModel instance.

required
__context Any

The context.

required
Source code in zenml/model/model.py
def init_private_attributes(self: BaseModel, __context: Any) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        __context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        object_setattr(self, '__pydantic_private__', pydantic_private)
set_stage(self, stage, force=False)

Sets this Model to a desired stage.

Parameters:

Name Type Description Default
stage Union[str, zenml.enums.ModelStages]

the target stage for model version.

required
force bool

whether to force archiving of current model version in target stage or raise.

False
Source code in zenml/model/model.py
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> None:
    """Sets this Model to a desired stage.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in
            target stage or raise.
    """
    self._get_or_create_model_version().set_stage(stage=stage, force=force)

model_version

DEPRECATED, use from zenml import Model instead.

ModelVersion (Model)

DEPRECATED, use from zenml import Model instead.

Source code in zenml/model/model_version.py
class ModelVersion(Model):
    """DEPRECATED, use `from zenml import Model` instead."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """DEPRECATED, use `from zenml import Model` instead.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        logger.warning(
            "`ModelVersion` is deprecated. Please use `Model` instead."
        )
        super().__init__(*args, **kwargs)
__init__(self, *args, **kwargs) special

DEPRECATED, use from zenml import Model instead.

Parameters:

Name Type Description Default
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}
Source code in zenml/model/model_version.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """DEPRECATED, use `from zenml import Model` instead.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    logger.warning(
        "`ModelVersion` is deprecated. Please use `Model` instead."
    )
    super().__init__(*args, **kwargs)
model_post_init(self, _ModelMetaclass__context)

We need to both initialize private attributes and call the user-defined model_post_init method.

Source code in zenml/model/model_version.py
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
    """We need to both initialize private attributes and call the user-defined model_post_init
    method.
    """
    init_private_attributes(self, __context)
    original_model_post_init(self, __context)

utils

Utility functions for linking step outputs to model versions.

link_artifact_config_to_model(artifact_config, artifact_version_id, model=None)

Link an artifact config to its model version.

Parameters:

Name Type Description Default
artifact_config ArtifactConfig

The artifact config to link.

required
artifact_version_id UUID

The ID of the artifact to link.

required
model Optional[Model]

The model version from the step or pipeline context.

None
Source code in zenml/model/utils.py
def link_artifact_config_to_model(
    artifact_config: ArtifactConfig,
    artifact_version_id: UUID,
    model: Optional["Model"] = None,
) -> None:
    """Link an artifact config to its model version.

    Args:
        artifact_config: The artifact config to link.
        artifact_version_id: The ID of the artifact to link.
        model: The model version from the step or pipeline context.
    """
    client = Client()

    # If the artifact config specifies a model itself then always use that
    if artifact_config.model_name is not None:
        from zenml.model.model import Model

        model = Model(
            name=artifact_config.model_name,
            version=artifact_config.model_version,
        )

    if model:
        request = ModelVersionArtifactRequest(
            user=client.active_user.id,
            workspace=client.active_workspace.id,
            artifact_version=artifact_version_id,
            model=model.model_id,
            model_version=model.id,
            is_model_artifact=artifact_config.is_model_artifact,
            is_deployment_artifact=artifact_config.is_deployment_artifact,
        )
        client.zen_store.create_model_version_artifact_link(request)

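link_artifact_to_model(artifact_version_id, model=None, is_model_artifact=False, is_deployment_artifact=False)

Link the artifact to the model.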

Parameters:

Name Type Description Default
artifact_version_id UUID

The ID of the artifact version.

required
model Optional[Model]

The model to link to.

None
is_model_artifact bool

Whether the artifact is a model artifact.

False
is_deployment_artifact bool

Whether the artifact is a deployment artifact.

False

Exceptions:

Type Description
RuntimeError

If no model is passed and none can be identified from the step context.

Source code in zenml/model/utils.py
def link_artifact_to_model(
    artifact_version_id: UUID,
    model: Optional["Model"] = None,
    is_model_artifact: bool = False,
    is_deployment_artifact: bool = False,
) -> None:
    """Link the artifact to the model.

    Args:
        artifact_version_id: The ID of the artifact version.
        model: The model to link to.
        is_model_artifact: Whether the artifact is a model artifact.
        is_deployment_artifact: Whether the artifact is a deployment artifact.

    Raises:
        RuntimeError: If no `model` is passed and none can be identified
            from the step context.
    """
    if not model:
        is_issue = False
        try:
            step_context = get_step_context()
            model = step_context.model
        except StepContextError:
            is_issue = True

        if model is None or is_issue:
            raise RuntimeError(
                "`link_artifact_to_model` called without `model` parameter "
                "and configured model context cannot be identified. Consider "
                "passing the `model` explicitly or configuring it in "
                "@step or @pipeline decorator."
            )

    link_artifact_config_to_model(
        artifact_config=ArtifactConfig(
            is_model_artifact=is_model_artifact,
            is_deployment_artifact=is_deployment_artifact,
        ),
        artifact_version_id=artifact_version_id,
        model=model,
    )

link_service_to_model(service_id, model=None, model_version_id=None)

Links a service to a model.

Parameters:

Name Type Description Default
service_id UUID

The ID of the service to link to the model.

required
model Optional[Model]

The model to link the service to.

None
model_version_id Optional[uuid.UUID]

The ID of the model version to link the service to.

None

Exceptions:

Type Description
RuntimeError

If no model is provided and the model context cannot be identified.

Source code in zenml/model/utils.py
def link_service_to_model(
    service_id: UUID,
    model: Optional["Model"] = None,
    model_version_id: Optional[UUID] = None,
) -> None:
    """Links a service to a model.

    Args:
        service_id: The ID of the service to link to the model.
        model: The model to link the service to.
        model_version_id: The ID of the model version to link the service to.

    Raises:
        RuntimeError: If no model is provided and the model context cannot be
            identified.
    """
    client = Client()

    # If no model is provided, try to get it from the context
    if not model and not model_version_id:
        is_issue = False
        try:
            step_context = get_step_context()
            model = step_context.model
        except StepContextError:
            is_issue = True

        if model is None or is_issue:
            raise RuntimeError(
                "`link_service_to_model` called without `model` parameter "
                "and configured model context cannot be identified. Consider "
                "passing the `model` explicitly or configuring it in "
                "@step or @pipeline decorator."
            )

    # Prefer an explicitly passed ID; otherwise resolve it from the model.
    if model_version_id is None and model is not None:
        model_version_id = model._get_or_create_model_version().id
    update_service = ServiceUpdate(model_version_id=model_version_id)
    client.zen_store.update_service(
        service_id=service_id, update=update_service
    )
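
When a fallback like this is written as a single expression, operator precedence matters: Python parses `a or b if c else None` as `(a or b) if c else None`, so an explicitly passed value is discarded whenever `c` is falsy. The sketch below contrasts the two forms using plain strings in place of UUIDs. Both helper names are hypothetical and neither is part of ZenML.

```python
from typing import Optional


def resolve_single_expression(
    explicit_id: Optional[str], model_id: Optional[str]
) -> Optional[str]:
    # Groups as: (explicit_id or model_id) if model_id else None
    return explicit_id or model_id if model_id else None


def resolve_explicit_first(
    explicit_id: Optional[str], model_id: Optional[str]
) -> Optional[str]:
    # An explicit ID always wins; the model is only consulted as a fallback.
    if explicit_id is not None:
        return explicit_id
    return model_id


# With only an explicit ID, the single-expression form drops it.
assert resolve_single_expression("mv-1", None) is None
assert resolve_explicit_first("mv-1", None) == "mv-1"

# When a model is available, both forms agree.
assert resolve_single_expression(None, "mv-2") == "mv-2"
assert resolve_explicit_first(None, "mv-2") == "mv-2"
```
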

link_step_artifacts_to_model(artifact_version_ids)

Links the output artifacts of a step to the model.

Parameters:

Name Type Description Default
artifact_version_ids Dict[str, uuid.UUID]

The IDs of the published output artifacts.

required

Exceptions:

Type Description
RuntimeError

If called outside of a step.

Source code in zenml/model/utils.py
def link_step_artifacts_to_model(
    artifact_version_ids: Dict[str, UUID],
) -> None:
    """Links the output artifacts of a step to the model.

    Args:
        artifact_version_ids: The IDs of the published output artifacts.

    Raises:
        RuntimeError: If called outside of a step.
    """
    try:
        step_context = get_step_context()
    except StepContextError:
        raise RuntimeError(
            "`link_step_artifacts_to_model` can only be called from within a "
            "step."
        )
    try:
        model = step_context.model
    except StepContextError:
        model = None
        logger.debug("No model context found, unable to auto-link artifacts.")

    for artifact_name, artifact_version_id in artifact_version_ids.items():
        artifact_config = step_context._get_output(
            artifact_name
        ).artifact_config

        # Implicit linking
        if artifact_config is None and model is not None:
            artifact_config = ArtifactConfig(name=artifact_name)
            logger.info(
                f"Implicitly linking artifact `{artifact_name}` to model "
                f"`{model.name}` version `{model.version}`."
            )

        if artifact_config:
            link_artifact_config_to_model(
                artifact_config=artifact_config,
                artifact_version_id=artifact_version_id,
                model=model,
            )
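
The loop above decides per output whether a link is attempted: an output without an artifact config gets a default one only when a model context exists, and outputs that end up with a config are forwarded for linking. A minimal sketch of that decision, using stand-in types, is shown below. `StubArtifactConfig` and `outputs_forwarded_for_linking` are hypothetical names, and the sketch does not cover what `link_artifact_config_to_model` does afterwards.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StubArtifactConfig:
    """Hypothetical stand-in for an artifact config."""

    name: Optional[str] = None


def outputs_forwarded_for_linking(
    output_configs: Dict[str, Optional[StubArtifactConfig]],
    has_model: bool,
) -> List[str]:
    """Return the output names that would be passed on for linking."""
    forwarded = []
    for artifact_name, config in output_configs.items():
        # Implicit linking: create a default config if a model context exists.
        if config is None and has_model:
            config = StubArtifactConfig(name=artifact_name)
        if config:
            forwarded.append(artifact_name)
    return forwarded


outputs = {"features": None, "model": StubArtifactConfig(name="model")}
print(outputs_forwarded_for_linking(outputs, has_model=True))
# prints: ['features', 'model']
print(outputs_forwarded_for_linking(outputs, has_model=False))
# prints: ['model']
```
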

log_model_metadata(metadata, model_name=None, model_version=None)

Log model version metadata.

This function can be used to log metadata for existing model versions.

Parameters:

Name Type Description Default
metadata Dict[str, MetadataType]

The metadata to log.

required
model_name Optional[str]

The name of the model to log metadata for. Can be omitted when being called inside a step with configured model in decorator.

None
model_version Union[zenml.enums.ModelStages, str, int]

The version of the model to log metadata for. Can be omitted when being called inside a step with configured model in decorator.

None

Exceptions:

Type Description
ValueError

If no model name/version is provided and the function is not called inside a step with configured model in decorator.

Source code in zenml/model/utils.py
def log_model_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.

    Raises:
        ValueError: If no model name/version is provided and the function is not
            called inside a step with configured `model` in decorator.
    """
    if model_name and model_version:
        from zenml import Model

        mv = Model(name=model_name, version=model_version)
    else:
        try:
            step_context = get_step_context()
        except RuntimeError:
            raise ValueError(
                "Model name and version must be provided unless the function is "
                "called inside a step with configured `model` in decorator."
            )
        mv = step_context.model

    mv.log_metadata(metadata)
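
Note that the explicit branch is taken only when both `model_name` and `model_version` are given; passing just one of them falls through to the step context. The sketch below isolates that choice. `metadata_target` is a hypothetical helper that returns a label instead of touching ZenML.

```python
from typing import Optional, Union


def metadata_target(
    model_name: Optional[str] = None,
    model_version: Optional[Union[int, str]] = None,
) -> str:
    """Label which source the model would be taken from."""
    if model_name and model_version:
        return "explicit"
    return "step_context"


print(metadata_target("classifier", "production"))  # prints: explicit
print(metadata_target("classifier"))                # prints: step_context
print(metadata_target(model_version=3))             # prints: step_context
```
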

log_model_version_metadata(metadata, model_name=None, model_version=None)

Log model version metadata.

This function can be used to log metadata for existing model versions.

Parameters:

Name Type Description Default
metadata Dict[str, MetadataType]

The metadata to log.

required
model_name Optional[str]

The name of the model to log metadata for. Can be omitted when being called inside a step with configured model in decorator.

None
model_version Union[zenml.enums.ModelStages, str, int]

The version of the model to log metadata for. Can be omitted when being called inside a step with configured model in decorator.

None
Source code in zenml/model/utils.py
def log_model_version_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.
    """
    logger.warning(
        "`log_model_version_metadata` is deprecated. Please use "
        "`log_model_metadata` instead."
    )
    log_model_metadata(
        metadata=metadata, model_name=model_name, model_version=model_version
    )