Skip to content

Model

zenml.model special

Concepts related to the Model Control Plane feature.

lazy_load

Model Version Data Lazy Loader definition.

ModelVersionDataLazyLoader (BaseModel)

Model Version Data Lazy Loader helper class.

It helps internal code fetch the proper artifact, model version metadata or artifact metadata from the model version during the runtime of the step.

Source code in zenml/model/lazy_load.py
class ModelVersionDataLazyLoader(BaseModel):
    """Model Version Data Lazy Loader helper class.

    It helps internal code fetch the proper artifact, model version
    metadata or artifact metadata from the model version during the
    runtime of the step.

    Attributes:
        model_name: Name of the model to load data from.
        model_version: Version of the model to load data from. If `None`,
            the model version of the pipeline run is used.
        artifact_name: Optional name of an artifact.
        artifact_version: Optional version of an artifact.
        metadata_name: Optional name of a metadata entry.
    """

    model_name: str
    model_version: Optional[str] = None
    artifact_name: Optional[str] = None
    artifact_version: Optional[str] = None
    metadata_name: Optional[str] = None

    # TODO: In Pydantic v2, `model_` is a protected namespace for all
    #  fields defined under base models. If not handled, this raises a warning.
    #  It is possible to suppress this warning message with the following
    #  configuration, however the ultimate solution is to rename these fields.
    #  Even though they do not cause any problems right now, if we are not
    #  careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        Args:
            data: Dict of values.

        Returns:
            Dict of validated values.

        Raises:
            ValueError: If no `model_version` is given while a pipeline
                context is active whose model is missing or has a
                different name than `model_name`.
        """
        # A missing `model_name` is left to the regular field validation,
        # so the caller gets a proper validation error, not a `KeyError`.
        if data.get("model_version") is None and "model_name" in data:
            try:
                context = get_pipeline_context()
                if (
                    not context.model
                    or context.model.name != data["model_name"]
                ):
                    raise ValueError(
                        "`version` must be set if you use the `Model` class "
                        "directly in the pipeline body, otherwise, you can use "
                        "`get_pipeline_context().model` to lazy load the current "
                        "Model Version from the pipeline context."
                    )
            except RuntimeError:
                # presumably raised by `get_pipeline_context()` when there
                # is no active pipeline context; nothing to check then
                pass
        return data

    def _get_model_response(
        self, pipeline_run: "PipelineRunResponse"
    ) -> "ModelVersionResponse":
        """Resolve the model version this lazy loader points to.

        Args:
            pipeline_run: The pipeline run whose model version is used
                when no explicit `model_version` is configured.

        Returns:
            The resolved model version.

        Raises:
            RuntimeError: If no version is configured and the pipeline run
                has no model version or one of a different model, or if
                the configured version cannot be found.
        """
        # an explicit version is fetched from the client by name and version
        if self.model_version is not None:
            from zenml.client import Client

            try:
                return Client().get_model_version(
                    model_name_or_id=self.model_name,
                    model_version_name_or_number_or_id=self.model_version,
                )
            except KeyError as e:
                raise RuntimeError(
                    "Lazy loading of the model version failed: "
                    f"no model `{self.model_name}` with version "
                    f"`{self.model_version}` could be found."
                ) from e

        # no version configured -> return the model version of the run
        mv = pipeline_run.model_version
        if not mv:
            raise RuntimeError(
                "Lazy loading of the model failed, since the model version "
                "is not set in the pipeline context."
            )
        if mv.model.name != self.model_name:
            raise RuntimeError(
                "Lazy loading of the model failed, since given name "
                f"`{self.model_name}` does not match the model name "
                f"in the pipeline context: `{mv.model.name}`."
            )
        return mv

model

Model user facing interface to pass into pipeline or step.

Model (BaseModel)

Model class to pass into pipeline or step to set it into a model context.

Attributes:

- name: The name of the model.
- license: The license under which the model is created.
- description: The description of the model.
- audience: The target audience of the model.
- use_cases: The use cases of the model.
- limitations: The known limitations of the model.
- trade_offs: The tradeoffs of the model.
- ethics: The ethical implications of the model.
- tags: Tags associated with the model.
- version: The version name, version number or stage. It is optional and points the model context to a specific version/stage. If skipped, a new version will be created.
- save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if available in the active stack.

Source code in zenml/model/model.py
class Model(BaseModel):
    """Model class to pass into pipeline or step to set it into a model context.

    Attributes:
        name: The name of the model.
        license: The license under which the model is created.
        description: The description of the model.
        audience: The target audience of the model.
        use_cases: The use cases of the model.
        limitations: The known limitations of the model.
        trade_offs: The tradeoffs of the model.
        ethics: The ethical implications of the model.
        tags: Tags associated with the model.
        version: The version name, version number or stage is optional and
            points model context to a specific version/stage. If skipped
            new version will be created.
        save_models_to_registry: Whether to save all ModelArtifacts to
            Model Registry, if available in active stack.
    """

    name: str
    license: Optional[str] = None
    description: Optional[str] = None
    audience: Optional[str] = None
    use_cases: Optional[str] = None
    limitations: Optional[str] = None
    trade_offs: Optional[str] = None
    ethics: Optional[str] = None
    tags: Optional[List[str]] = None
    version: Optional[Union[ModelStages, int, str]] = Field(
        default=None, union_mode="smart"
    )
    save_models_to_registry: bool = True

    # technical attributes
    # id of the resolved model version; rejected by the root validator when
    # passed in by a caller that did not suppress validation warnings
    model_version_id: Optional[UUID] = None
    # read and then forced to True by the root validator
    suppress_class_validation_warnings: bool = False
    # set to True once this object has created a new model version
    was_created_in_this_run: bool = False
    # caches filled while resolving the model / model version
    _model_id: UUID = PrivateAttr(None)
    _number: Optional[int] = PrivateAttr(None)
    _created_model_version: bool = PrivateAttr(False)

    # TODO: In Pydantic v2, `model_` is a protected namespace for all
    #  fields defined under base models. If not handled, this raises a warning.
    #  It is possible to suppress this warning message with the following
    #  configuration, however the ultimate solution is to rename these fields.
    #  Even though they do not cause any problems right now, if we are not
    #  careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())

    #########################
    #    Public methods     #
    #########################
    @property
    def id(self) -> UUID:
        """Get version id from the Model Control Plane.

        The id is resolved on first access and then kept in
        `model_version_id`.

        Returns:
            ID of the model version.

        Raises:
            RuntimeError: if model version doesn't exist and
                cannot be fetched from the Model Control Plane.
        """
        if self.model_version_id is not None:
            return self.model_version_id

        try:
            resolved = self._get_or_create_model_version()
            self.model_version_id = resolved.id
        except RuntimeError as e:
            raise RuntimeError(
                f"Version `{self.version}` of `{self.name}` model doesn't "
                "exist and cannot be fetched from the Model Control Plane."
            ) from e
        return self.model_version_id

    @property
    def model_id(self) -> UUID:
        """Get model id from the Model Control Plane.

        The model is fetched (or created) on first access if no id has
        been cached yet.

        Returns:
            The UUID of the model containing this model version.
        """
        needs_lookup = self._model_id is None
        if needs_lookup:
            # stores the id in `_model_id` as a side effect
            self._get_or_create_model()
        return self._model_id

    @property
    def number(self) -> int:
        """Get version number from the Model Control Plane.

        The number is resolved on first access and then kept in `_number`.

        Returns:
            Number of the model version.

        Raises:
            KeyError: if model version doesn't exist and
                cannot be fetched from the Model Control Plane.
        """
        if self._number is not None:
            return self._number

        try:
            resolved = self._get_or_create_model_version()
            self._number = resolved.number
        except RuntimeError as e:
            raise KeyError(
                f"Version `{self.version}` of `{self.name}` model doesn't "
                "exist and cannot be fetched from the Model Control Plane."
            ) from e
        return self._number

    @property
    def stage(self) -> Optional[ModelStages]:
        """Get version stage from the Model Control Plane.

        Returns:
            Stage of the model version, or None if the model version has
                no stage or cannot be fetched (the latter is logged).
        """
        try:
            raw_stage = self._get_or_create_model_version().stage
            if not raw_stage:
                return None
            return ModelStages(raw_stage)
        except RuntimeError:
            logger.info(
                f"Version `{self.version}` of `{self.name}` model doesn't "
                "exist and cannot be fetched from the Model Control Plane."
            )
            return None

    def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
        """Load artifact from the Model Control Plane.

        Args:
            name: Name of the artifact to load.
            version: Version of the artifact to load.

        Returns:
            The loaded artifact.

        Raises:
            ValueError: if the model version is not linked to any artifact with
                the given name and version.
        """
        from zenml.artifacts.utils import load_artifact
        from zenml.models import ArtifactVersionResponse

        linked = self.get_artifact(name=name, version=version)

        # only a real artifact version response can be materialized
        if isinstance(linked, ArtifactVersionResponse):
            return load_artifact(linked.id, str(linked.version))

        raise ValueError(
            f"Version {self.version} of model {self.name} does not have "
            f"an artifact with name {name} and version {version}."
        )

    def get_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the artifact linked to this model version.

        Args:
            name: The name of the artifact to retrieve.
            version: The version of the artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the artifact or placeholder in the design time
                of the pipeline.
        """
        placeholder = self._lazy_artifact_get(name, version)
        if placeholder:
            return placeholder

        model_version = self._get_or_create_model_version()
        return model_version.get_artifact(name=name, version=version)

    def get_model_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the model artifact linked to this model version.

        Args:
            name: The name of the model artifact to retrieve.
            version: The version of the model artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the model artifact or placeholder in the design
                time of the pipeline.
        """
        placeholder = self._lazy_artifact_get(name, version)
        if placeholder:
            return placeholder

        model_version = self._get_or_create_model_version()
        return model_version.get_model_artifact(name=name, version=version)

    def get_data_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the data artifact linked to this model version.

        Args:
            name: The name of the data artifact to retrieve.
            version: The version of the data artifact to retrieve (None for
                latest/non-versioned)

        Returns:
            Specific version of the data artifact or placeholder in the design
            time of the pipeline.
        """
        placeholder = self._lazy_artifact_get(name, version)
        if placeholder:
            return placeholder

        model_version = self._get_or_create_model_version()
        return model_version.get_data_artifact(name=name, version=version)

    def get_deployment_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Get the deployment artifact linked to this model version.

        Args:
            name: The name of the deployment artifact to retrieve.
            version: The version of the deployment artifact to retrieve (None
                for latest/non-versioned)

        Returns:
            Specific version of the deployment artifact or placeholder in the
            design time of the pipeline.
        """
        placeholder = self._lazy_artifact_get(name, version)
        if placeholder:
            return placeholder

        model_version = self._get_or_create_model_version()
        return model_version.get_deployment_artifact(
            name=name, version=version
        )

    def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
        """Get pipeline run linked to this version.

        Args:
            name: The name of the pipeline run to retrieve.

        Returns:
            PipelineRun as PipelineRunResponse
        """
        model_version = self._get_or_create_model_version()
        return model_version.get_pipeline_run(name=name)

    def set_stage(
        self, stage: Union[str, ModelStages], force: bool = False
    ) -> None:
        """Sets this Model to a desired stage.

        Args:
            stage: the target stage for model version.
            force: whether to force archiving of current model version in
                target stage or raise.
        """
        model_version = self._get_or_create_model_version()
        model_version.set_stage(stage=stage, force=force)

    def log_metadata(
        self,
        metadata: Dict[str, "MetadataType"],
    ) -> None:
        """Log model version metadata.

        Attaches the given metadata to the model version resolved from
        this configuration.

        Args:
            metadata: The metadata to log.
        """
        from zenml.client import Client

        # resolve the model version first, then hand its id to the client
        model_version_id = self._get_or_create_model_version().id
        Client().create_run_metadata(
            resource_type=MetadataResourceTypes.MODEL_VERSION,
            resource_id=model_version_id,
            metadata=metadata,
        )

    @property
    def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
        """Get model version run metadata.

        If `get_pipeline_context()` succeeds, a lazy getter is returned in
        place of the actual metadata; otherwise the metadata is read from
        the resolved model version.

        Returns:
            The model version run metadata.

        Raises:
            RuntimeError: If the model version run metadata cannot be fetched.
        """
        from zenml.metadata.lazy_load import RunMetadataLazyGetter

        try:
            get_pipeline_context()
            lazy_getter = RunMetadataLazyGetter(
                self.name,
                self._lazy_version,
            )
        except RuntimeError:
            lazy_getter = None

        if lazy_getter is not None:
            # avoid exposing too much of internal details by keeping the return type
            return lazy_getter  # type: ignore[return-value]

        model_version = self._get_or_create_model_version(hydrate=True)
        metadata = model_version.run_metadata
        if metadata is None:
            raise RuntimeError(
                "Failed to fetch metadata of this model version."
            )
        return metadata

    def delete_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        only_link: bool = True,
        delete_metadata: bool = True,
        delete_from_artifact_store: bool = False,
    ) -> None:
        """Delete the artifact linked to this model version.

        Nothing happens if the lookup does not yield an actual artifact
        version response.

        Args:
            name: The name of the artifact to delete.
            version: The version of the artifact to delete (None for
                latest/non-versioned)
            only_link: Whether to only delete the link to the artifact.
            delete_metadata: Whether to delete the metadata of the artifact.
            delete_from_artifact_store: Whether to delete the artifact from the
                artifact store.
        """
        from zenml.client import Client
        from zenml.models import ArtifactVersionResponse

        target = self.get_artifact(name, version)
        if not isinstance(target, ArtifactVersionResponse):
            return

        zenml_client = Client()
        # the link is always removed
        zenml_client.delete_model_version_artifact_link(
            model_version_id=self.id,
            artifact_version_id=target.id,
        )
        if only_link:
            return

        # the artifact version itself only goes when explicitly requested
        zenml_client.delete_artifact_version(
            name_id_or_prefix=target.id,
            delete_metadata=delete_metadata,
            delete_from_artifact_store=delete_from_artifact_store,
        )

    def delete_all_artifacts(
        self,
        only_link: bool = True,
        delete_from_artifact_store: bool = False,
    ) -> None:
        """Delete all artifacts linked to this model version.

        Args:
            only_link: Whether to only delete the link to the artifact.
            delete_from_artifact_store: Whether to delete the artifact from
                the artifact store.
        """
        from itertools import chain

        from zenml.client import Client

        client = Client()

        if not only_link and delete_from_artifact_store:
            mv = self._get_model_version()
            # Walk the three categories one after another instead of merging
            # them into one dict: a merge would mutate `mv.data_artifacts`
            # and drop versions whenever the same artifact name shows up in
            # more than one category.
            versions_by_artifact = chain(
                mv.data_artifacts.values(),
                mv.model_artifacts.values(),
                mv.deployment_artifacts.values(),
            )
            for versions in versions_by_artifact:
                for artifact_response in versions.values():
                    client._delete_artifact_from_artifact_store(
                        artifact_version=artifact_response
                    )

        client.delete_all_model_version_artifact_links(self.id, only_link)

    def _lazy_artifact_get(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ArtifactVersionResponse"]:
        """Build a lazy artifact placeholder if a pipeline context exists.

        Args:
            name: The name of the artifact to retrieve.
            version: The version of the artifact to retrieve.

        Returns:
            A lazy artifact version response, or None if a `RuntimeError`
            is raised while building it (presumably because there is no
            active pipeline context).
        """
        from zenml.models.v2.core.artifact_version import (
            LazyArtifactVersionResponse,
        )

        try:
            get_pipeline_context()
            placeholder = LazyArtifactVersionResponse(
                lazy_load_name=name,
                lazy_load_version=version,
                lazy_load_model_name=self.name,
                lazy_load_model_version=self._lazy_version,
            )
        except RuntimeError:
            return None

        return placeholder

    def __eq__(self, other: object) -> bool:
        """Check two Models for equality.

        Models with different names are never equal. With equal names they
        are equal if their `version` matches or, failing that, if both
        resolve to the same model version id.

        Args:
            other: object to compare with

        Returns:
            True, if equal, False otherwise.
        """
        if not isinstance(other, Model):
            return NotImplemented
        if self.name != other.name:
            return False
        if self.version == other.version:
            return True
        # same model, different version spec: compare the resolved versions
        own_id = self._get_or_create_model_version().id
        other_id = other._get_or_create_model_version().id
        return own_id == other_id

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        Args:
            data: Dict of values.

        Returns:
            Dict of validated values, with
            `suppress_class_validation_warnings` forced to True.

        Raises:
            ValueError: If `model_version_id` is passed while
                `suppress_class_validation_warnings` is not set.
        """
        suppressed = data.get(
            "suppress_class_validation_warnings",
            False,
        )
        if not suppressed:
            if data.get("model_version_id", None):
                raise ValueError(
                    "`model_version_id` field is for internal use only"
                )

            version = data.get("version", None)
            stage_values = [stage.value for stage in ModelStages]

            if version in stage_values:
                logger.info(
                    f"Version `{version}` matches one of the possible "
                    "`ModelStages` and will be fetched using stage."
                )
            if str(version).isnumeric():
                logger.info(
                    f"`version` `{version}` is numeric and will be fetched "
                    "using version number."
                )

        data["suppress_class_validation_warnings"] = True
        return data

    def _validate_config_in_runtime(self) -> "ModelVersionResponse":
        """Validate that config doesn't conflict with runtime environment.

        Delegates to `_get_or_create_model_version`.

        Returns:
            The model version based on configuration.
        """
        resolved = self._get_or_create_model_version()
        return resolved

    def _get_or_create_model(self) -> "ModelResponse":
        """This method should get or create a model from Model Control Plane.

        New model is created implicitly, if missing, otherwise fetched.
        The id of the resulting model is stored in `_model_id`.

        Returns:
            The model based on configuration.
        """
        from zenml.client import Client
        from zenml.models import ModelRequest

        client = Client()
        if self.model_version_id:
            # a known model version already points at its model
            model = client.get_model_version(
                model_version_name_or_number_or_id=self.model_version_id,
            ).model
        else:
            try:
                model = client.zen_store.get_model(model_name_or_id=self.name)
            except KeyError:
                # the lookup by name failed: create the model from this config
                request = ModelRequest.model_validate(
                    ModelRequest(
                        name=self.name,
                        license=self.license,
                        description=self.description,
                        audience=self.audience,
                        use_cases=self.use_cases,
                        limitations=self.limitations,
                        trade_offs=self.trade_offs,
                        ethics=self.ethics,
                        user=client.active_user.id,
                        workspace=client.active_workspace.id,
                        save_models_to_registry=self.save_models_to_registry,
                    )
                )
                try:
                    model = client.zen_store.create_model(model=request)
                except EntityExistsError:
                    # the model exists after all, so fetch it instead
                    model = client.zen_store.get_model(
                        model_name_or_id=self.name
                    )
                else:
                    logger.info(
                        f"New model `{self.name}` was created implicitly."
                    )

        self._model_id = model.id
        return model

    def _get_model_version(
        self, hydrate: bool = True
    ) -> "ModelVersionResponse":
        """This method gets a model version from Model Control Plane.

        The version is looked up by `model_version_id` when set, otherwise
        by model name and `version`. A warning is logged if the configured
        description or tags differ from the fetched model version.

        Args:
            hydrate: Flag deciding whether to hydrate the output model(s)
                by including metadata fields in the response.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client

        client = Client()
        if self.model_version_id:
            mv = client.get_model_version(
                model_version_name_or_number_or_id=self.model_version_id,
                hydrate=hydrate,
            )
        else:
            mv = client.get_model_version(
                model_name_or_id=self.name,
                model_version_name_or_number_or_id=self.version,
                hydrate=hydrate,
            )
            # remember the id so later lookups go by id
            self.model_version_id = mv.id

        mismatches: Dict[str, Any] = {}

        # the description is only compared when the response has metadata
        if (
            mv.metadata
            and self.description
            and mv.description != self.description
        ):
            mismatches["description"] = {
                "config": self.description,
                "db": mv.description,
            }

        if self.tags:
            wanted_tags = set(self.tags)
            stored_tags = {tag.name for tag in mv.tags}
            if stored_tags != wanted_tags:
                mismatches["tags added"] = list(wanted_tags - stored_tags)
                mismatches["tags removed"] = list(stored_tags - wanted_tags)

        if mismatches:
            logger.warning(
                "Provided model version configuration does not match existing model "
                f"version `{self.name}::{self.version}` with the following "
                f"changes: {mismatches}. If you want to update the model version "
                "configuration, please use the `zenml model version update` command."
            )

        return mv

    def _get_or_create_model_version(
        self, hydrate: bool = False
    ) -> "ModelVersionResponse":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model
        is fetched. Model name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model version:
        - If `version` is None, a new model version is created, if not created
            by other steps in same run.
        - If `version` is not None a model version will be fetched based on the
            version:
            - If `version` is set to an integer or digit string, the model
                version with the matching number will be fetched.
            - If `version` is set to a string, the model version with the
                matching version will be fetched.
            - If `version` is set to a `ModelStage`, the model version with the
                matching stage will be fetched.

        On return, `model_version_id`, `_model_id` and `_number` are updated
        from the resolved model version. When a new version is created,
        `version` is set to its name and `was_created_in_this_run` is set.

        Args:
            hydrate: Whether to return a hydrated version of the model version.
                NOTE(review): this flag is not forwarded to
                `_get_model_version` in the body below - confirm whether
                that is intended.

        Returns:
            The model version based on configuration.

        Raises:
            RuntimeError: if the model version needs to be created, but
                provided name is reserved.
            RuntimeError: if the model version cannot be created.
        """
        from zenml.client import Client
        from zenml.models import ModelVersionRequest

        model = self._get_or_create_model()

        # backup logic, if the Model class is used directly from the code
        if isinstance(self.version, str):
            self.version = format_name_template(self.version)

        zenml_client = Client()
        # the request is prepared up front, but only sent in the `KeyError`
        # handler below, i.e. when no existing version could be resolved
        model_version_request = ModelVersionRequest(
            user=zenml_client.active_user.id,
            workspace=zenml_client.active_workspace.id,
            name=str(self.version) if self.version else None,
            description=self.description,
            model=model.id,
            tags=self.tags,
        )
        mv_request = ModelVersionRequest.model_validate(model_version_request)
        try:
            if not self.version:
                try:
                    from zenml import get_step_context

                    context = get_step_context()
                except RuntimeError:
                    # presumably raised outside of a step; nothing to reuse
                    pass
                else:
                    # if inside a step context we loop over all
                    # model version configuration to find, if the
                    # model version for current model was already
                    # created in the current run, not to create
                    # new model versions
                    pipeline_mv = context.pipeline_run.config.model
                    if (
                        pipeline_mv
                        and pipeline_mv.was_created_in_this_run
                        and pipeline_mv.name == self.name
                        and pipeline_mv.version is not None
                    ):
                        self.version = pipeline_mv.version
                        self.model_version_id = pipeline_mv.model_version_id
                    else:
                        for step in context.pipeline_run.steps.values():
                            step_mv = step.config.model
                            if (
                                step_mv
                                and step_mv.was_created_in_this_run
                                and step_mv.name == self.name
                                and step_mv.version is not None
                            ):
                                self.version = step_mv.version
                                self.model_version_id = (
                                    step_mv.model_version_id
                                )
                                break
            if self.version or self.model_version_id:
                model_version = self._get_model_version()
            else:
                # nothing to look up: reuse the `KeyError` handler below,
                # which creates a new model version
                raise KeyError
        except KeyError:
            # reached via the explicit `raise KeyError` above, or presumably
            # when the lookup itself fails with `KeyError` - the version has
            # to be created, unless its name is reserved
            if (
                self.version
                and str(self.version).lower() in ModelStages.values()
            ):
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "it matches one of the possible model version stages. If you "
                    "are aiming to fetch model version by stage, check if the "
                    "model version in given stage exists. It might be missing, if "
                    "the pipeline promoting model version to this stage failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list -n {self.name}` CLI command."
                )
            if str(self.version).isnumeric():
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "numeric model version names are reserved. If you "
                    "are aiming to fetch model version by number, check if the "
                    "model version with given number exists. It might be missing, if "
                    "the pipeline creating model version failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list -n {self.name}` CLI command."
                )
            # creation is retried on `EntityExistsError`, sleeping between
            # attempts; the last failed attempt raises instead
            retries_made = 0
            for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
                try:
                    model_version = (
                        zenml_client.zen_store.create_model_version(
                            model_version=mv_request
                        )
                    )
                    break
                except EntityExistsError as e:
                    if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
                        raise RuntimeError(
                            f"Failed to create model version "
                            f"`{self.version if self.version else 'new'}` "
                            f"in model `{self.name}`. Retried {retries_made} times. "
                            "This could be driven by exceptionally high concurrency of "
                            "pipeline runs. Please, reach out to us on ZenML Slack for support."
                        ) from e
                    # smoothed exponential back-off, it will go as 0.2, 0.3,
                    # 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
                    sleep = 0.2 * 1.5**i
                    logger.debug(
                        f"Failed to create new model version for "
                        f"model `{self.name}`. Retrying in {sleep}..."
                    )
                    time.sleep(sleep)
                    retries_made += 1
            # record the name of the created version and mark this object
            # as the creator
            self.version = model_version.name
            self.was_created_in_this_run = True
            self._created_model_version = True

            logger.info(
                "Created new model version `%s` for model `%s`.",
                self.version,
                self.name,
            )

        # cache the identifiers of the resolved (fetched or created) version
        self.model_version_id = model_version.id
        self._model_id = model_version.model.id
        self._number = model_version.number
        return model_version

    def _merge(self, model: "Model") -> None:
        self.license = self.license or model.license
        self.description = self.description or model.description
        self.audience = self.audience or model.audience
        self.use_cases = self.use_cases or model.use_cases
        self.limitations = self.limitations or model.limitations
        self.trade_offs = self.trade_offs or model.trade_offs
        self.ethics = self.ethics or model.ethics
        if model.tags is not None:
            self.tags = list(
                {t for t in self.tags or []}.union(set(model.tags))
            )

    def __hash__(self) -> int:
        """Get hash of the `Model`.

        Returns:
            Hash function results
        """
        return hash(
            "::".join(
                (
                    str(v)
                    for v in (
                        self.name,
                        self.version,
                    )
                )
            )
        )

    def _prepare_model_version_before_step_launch(
        self,
        pipeline_run: "PipelineRunResponse",
        step_run: Optional["StepRunResponse"],
        return_logs: bool,
    ) -> Tuple[str, "PipelineRunResponse", Optional["StepRunResponse"]]:
        """Prepares model version inside pipeline run.

        Works on a copy of this `Model`: either reuses the model version
        already attached to the pipeline run or resolves the (possibly
        templated) version name, gets or creates the model version and
        records its id on the step run or pipeline run.

        Args:
            pipeline_run: The pipeline run the model version is prepared for.
            step_run: The step run (passed only if the model version is
                defined in a step explicitly).
            return_logs: Whether to build the dashboard-related log message.

        Returns:
            A tuple of three items: the logs related to the Dashboard URL to
            show later (an empty string if none were produced), the pipeline
            run and the step run. A run that was updated with the model
            version id is replaced by the response of the update call.
        """
        from zenml.client import Client
        from zenml.models import PipelineRunUpdate, StepRunUpdate

        logs = ""

        # copy Model instance to prevent corrupting configs of the
        # subsequent runs, if they share the same config object
        self_copy = self.model_copy()

        # in case request is within the step and no self-configuration is provided
        # try reuse what's in the pipeline run first
        if step_run is None and pipeline_run.model_version is not None:
            self_copy.version = pipeline_run.model_version.name
            self_copy.model_version_id = pipeline_run.model_version.id
        # otherwise try to fill the templated name, if needed
        elif isinstance(self_copy.version, str):
            # prefer the run's start time so the substituted date/time values
            # are derived from the run; fall back to the current UTC time
            if pipeline_run.start_time:
                start_time = pipeline_run.start_time
            else:
                start_time = datetime.datetime.now(datetime.timezone.utc)
            self_copy.version = format_name_template(
                self_copy.version,
                date=start_time.strftime("%Y_%m_%d"),
                time=start_time.strftime("%H_%M_%S_%f"),
            )

        # if exact model not yet defined - try to get/create and update it
        # back to the run accordingly
        if self_copy.model_version_id is None:
            model_version_response = self_copy._get_or_create_model_version()

            client = Client()
            # update the configured model version id in runs accordingly:
            # on the step run when one was passed, otherwise on the pipeline run
            if step_run:
                step_run = client.zen_store.update_run_step(
                    step_run_id=step_run.id,
                    step_run_update=StepRunUpdate(
                        model_version_id=model_version_response.id
                    ),
                )
            else:
                pipeline_run = client.zen_store.update_run(
                    run_id=pipeline_run.id,
                    run_update=PipelineRunUpdate(
                        model_version_id=model_version_response.id
                    ),
                )

            # logs are only produced on this path, i.e. when the model version
            # had to be resolved here
            if return_logs:
                from zenml.utils.cloud_utils import try_get_model_version_url

                if logs_to_show := try_get_model_version_url(
                    model_version_response
                ):
                    logs = logs_to_show
                else:
                    logs = (
                        "Models can be viewed in the dashboard using ZenML Pro. Sign up "
                        "for a free trial at https://www.zenml.io/pro/"
                    )
        # only the model version id is written back to the original instance;
        # the version name set on the copy is not
        # NOTE(review): presumably `_get_or_create_model_version` stores the id
        # on the copy - confirm
        self.model_version_id = self_copy.model_version_id
        return logs, pipeline_run, step_run

    @property
    def _lazy_version(self) -> Optional[str]:
        """Get the version name to be used by the lazy loader.

        Only values already stored on this object are read, so new model
        version creation is never triggered here.

        Returns:
            The stored version number as a string if it is known, otherwise
            the configured version (the stage value for `ModelStages`), or
            None if no version was set.
        """
        if self._number is not None:
            return str(self._number)
        if self.version is None:
            return None
        if isinstance(self.version, ModelStages):
            return self.version.value
        return str(self.version)
id: UUID property readonly

Get version id from the Model Control Plane.

Returns:

Type Description
UUID

ID of the model version, or None if the model version doesn't exist and can only be read given the current config (i.e. you used a stage name or number as the version name).

Exceptions:

Type Description
RuntimeError

if model version doesn't exist and cannot be fetched from the Model Control Plane.

model_id: UUID property readonly

Get model id from the Model Control Plane.

Returns:

Type Description
UUID

The UUID of the model containing this model version.

number: int property readonly

Get version number from the Model Control Plane.

Returns:

Type Description
int

Number of the model version, or None if the model version doesn't exist and can only be read given the current config (i.e. you used a stage name or number as the version name).

Exceptions:

Type Description
KeyError

if model version doesn't exist and cannot be fetched from the Model Control Plane.

run_metadata: Dict[str, RunMetadataResponse] property readonly

Get model version run metadata.

Returns:

Type Description
Dict[str, RunMetadataResponse]

The model version run metadata.

Exceptions:

Type Description
RuntimeError

If the model version run metadata cannot be fetched.

stage: Optional[zenml.enums.ModelStages] property readonly

Get version stage from the Model Control Plane.

Returns:

Type Description
Optional[zenml.enums.ModelStages]

Stage of the model version, or None if the model version doesn't exist and can only be read given the current config (i.e. you used a stage name or number as the version name).

__eq__(self, other) special

Check two Models for equality.

Parameters:

Name Type Description Default
other object

object to compare with

required

Returns:

Type Description
bool

True, if equal, False otherwise.

Source code in zenml/model/model.py
def __eq__(self, other: object) -> bool:
    """Check two Models for equality.

    Models with different names are never equal. Models with the same name
    and the same configured version are equal. Otherwise the ids of the
    model versions both objects resolve to are compared.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, Model):
        return NotImplemented
    if self.name != other.name:
        return False
    # names are known to match here, so only the versions are left to compare
    if self.version == other.version:
        return True
    own_id = self._get_or_create_model_version().id
    other_id = other._get_or_create_model_version().id
    return own_id == other_id
__hash__(self) special

Get hash of the Model.

Returns:

Type Description
int

Hash function results

Source code in zenml/model/model.py
def __hash__(self) -> int:
    """Get hash of the `Model`.

    Returns:
        Hash function results
    """
    return hash(
        "::".join(
            (
                str(v)
                for v in (
                    self.name,
                    self.version,
                )
            )
        )
    )
delete_all_artifacts(self, only_link=True, delete_from_artifact_store=False)

Delete all artifacts linked to this model version.

Parameters:

Name Type Description Default
only_link bool

Whether to only delete the link to the artifact.

True
delete_from_artifact_store bool

Whether to delete the artifact from the artifact store.

False
Source code in zenml/model/model.py
def delete_all_artifacts(
    self,
    only_link: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete all artifacts linked to this model version.

    Artifacts are removed from the artifact store only when `only_link` is
    False and `delete_from_artifact_store` is True. The links to the model
    version are deleted in every case.

    Args:
        only_link: Whether to only delete the link to the artifact.
        delete_from_artifact_store: Whether to delete the artifact from
            the artifact store.
    """
    from zenml.client import Client

    client = Client()

    if not only_link and delete_from_artifact_store:
        mv = self._get_model_version()
        # Walk every category on its own. Merging them into one dict with
        # `dict.update` would mutate the dict returned by `data_artifacts`
        # and replace the whole per-name entry whenever the same artifact
        # name occurs in more than one category, skipping those versions.
        artifact_groups = (
            mv.data_artifacts,
            mv.model_artifacts,
            mv.deployment_artifacts,
        )
        for artifacts_by_name in artifact_groups:
            for artifact_versions in artifacts_by_name.values():
                for artifact_version in artifact_versions.values():
                    client._delete_artifact_from_artifact_store(
                        artifact_version=artifact_version
                    )

    client.delete_all_model_version_artifact_links(self.id, only_link)
delete_artifact(self, name, version=None, only_link=True, delete_metadata=True, delete_from_artifact_store=False)

Delete the artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the artifact to delete.

required
version Optional[str]

The version of the artifact to delete (None for latest/non-versioned)

None
only_link bool

Whether to only delete the link to the artifact.

True
delete_metadata bool

Whether to delete the metadata of the artifact.

True
delete_from_artifact_store bool

Whether to delete the artifact from the artifact store.

False
Source code in zenml/model/model.py
def delete_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    only_link: bool = True,
    delete_metadata: bool = True,
    delete_from_artifact_store: bool = False,
) -> None:
    """Delete the artifact linked to this model version.

    Nothing happens if the lookup does not return an artifact version
    response. Otherwise the link is deleted, and the artifact version
    itself is deleted as well unless `only_link` is set.

    Args:
        name: The name of the artifact to delete.
        version: The version of the artifact to delete (None for
            latest/non-versioned)
        only_link: Whether to only delete the link to the artifact.
        delete_metadata: Whether to delete the metadata of the artifact.
        delete_from_artifact_store: Whether to delete the artifact from the
            artifact store.
    """
    from zenml.client import Client
    from zenml.models import ArtifactVersionResponse

    linked_artifact = self.get_artifact(name, version)
    if not isinstance(linked_artifact, ArtifactVersionResponse):
        return

    zenml_client = Client()
    zenml_client.delete_model_version_artifact_link(
        model_version_id=self.id,
        artifact_version_id=linked_artifact.id,
    )
    if only_link:
        return

    zenml_client.delete_artifact_version(
        name_id_or_prefix=linked_artifact.id,
        delete_metadata=delete_metadata,
        delete_from_artifact_store=delete_from_artifact_store,
    )
get_artifact(self, name, version=None)

Get the artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the artifact to retrieve.

required
version Optional[str]

The version of the artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Optional[ArtifactVersionResponse]

Specific version of the artifact or placeholder in the design time of the pipeline.

Source code in zenml/model/model.py
def get_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Get the artifact linked to this model version.

    A truthy result of the lazy lookup is returned as is; otherwise the
    artifact is requested from the model version.

    Args:
        name: The name of the artifact to retrieve.
        version: The version of the artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the artifact or placeholder in the design time
            of the pipeline.
    """
    lazy_placeholder = self._lazy_artifact_get(name, version)
    if lazy_placeholder:
        return lazy_placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_artifact(name=name, version=version)
get_data_artifact(self, name, version=None)

Get the data artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the data artifact to retrieve.

required
version Optional[str]

The version of the data artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Optional[ArtifactVersionResponse]

Specific version of the data artifact or placeholder in the design time of the pipeline.

Source code in zenml/model/model.py
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Get the data artifact linked to this model version.

    A truthy result of the lazy lookup is returned as is; otherwise the
    data artifact is requested from the model version.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the data artifact or placeholder in the design
        time of the pipeline.
    """
    lazy_placeholder = self._lazy_artifact_get(name, version)
    if lazy_placeholder:
        return lazy_placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_data_artifact(name=name, version=version)
get_deployment_artifact(self, name, version=None)

Get the deployment artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the deployment artifact to retrieve.

required
version Optional[str]

The version of the deployment artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Optional[ArtifactVersionResponse]

Specific version of the deployment artifact or placeholder in the design time of the pipeline.

Source code in zenml/model/model.py
def get_deployment_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Get the deployment artifact linked to this model version.

    A truthy result of the lazy lookup is returned as is; otherwise the
    deployment artifact is requested from the model version.

    Args:
        name: The name of the deployment artifact to retrieve.
        version: The version of the deployment artifact to retrieve (None
            for latest/non-versioned)

    Returns:
        Specific version of the deployment artifact or placeholder in the
        design time of the pipeline.
    """
    lazy_placeholder = self._lazy_artifact_get(name, version)
    if lazy_placeholder:
        return lazy_placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_deployment_artifact(name=name, version=version)
get_model_artifact(self, name, version=None)

Get the model artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the model artifact to retrieve.

required
version Optional[str]

The version of the model artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Optional[ArtifactVersionResponse]

Specific version of the model artifact or placeholder in the design time of the pipeline.

Source code in zenml/model/model.py
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional["ArtifactVersionResponse"]:
    """Get the model artifact linked to this model version.

    A truthy result of the lazy lookup is returned as is; otherwise the
    model artifact is requested from the model version.

    Args:
        name: The name of the model artifact to retrieve.
        version: The version of the model artifact to retrieve (None for
            latest/non-versioned)

    Returns:
        Specific version of the model artifact or placeholder in the design
            time of the pipeline.
    """
    lazy_placeholder = self._lazy_artifact_get(name, version)
    if lazy_placeholder:
        return lazy_placeholder

    model_version = self._get_or_create_model_version()
    return model_version.get_model_artifact(name=name, version=version)
get_pipeline_run(self, name)

Get pipeline run linked to this version.

Parameters:

Name Type Description Default
name str

The name of the pipeline run to retrieve.

required

Returns:

Type Description
PipelineRunResponse

PipelineRun as PipelineRunResponse

Source code in zenml/model/model.py
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Get pipeline run linked to this version.

    Args:
        name: The name of the pipeline run to retrieve.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    model_version = self._get_or_create_model_version()
    return model_version.get_pipeline_run(name=name)
load_artifact(self, name, version=None)

Load artifact from the Model Control Plane.

Parameters:

Name Type Description Default
name str

Name of the artifact to load.

required
version Optional[str]

Version of the artifact to load.

None

Returns:

Type Description
Any

The loaded artifact.

Exceptions:

Type Description
ValueError

if the model version is not linked to any artifact with the given name and version.

Source code in zenml/model/model.py
def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
    """Load artifact from the Model Control Plane.

    Args:
        name: Name of the artifact to load.
        version: Version of the artifact to load.

    Returns:
        The loaded artifact.

    Raises:
        ValueError: if the model version is not linked to any artifact with
            the given name and version.
    """
    from zenml.artifacts.utils import load_artifact as load_artifact_by_id
    from zenml.models import ArtifactVersionResponse

    linked_artifact = self.get_artifact(name=name, version=version)

    if isinstance(linked_artifact, ArtifactVersionResponse):
        return load_artifact_by_id(
            linked_artifact.id, str(linked_artifact.version)
        )

    raise ValueError(
        f"Version {self.version} of model {self.name} does not have "
        f"an artifact with name {name} and version {version}."
    )
log_metadata(self, metadata)

Log model version metadata.

This function can be used to log metadata for current model version.

Parameters:

Name Type Description Default
metadata Dict[str, MetadataType]

The metadata to log.

required
Source code in zenml/model/model.py
def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for current model version.
    The metadata is attached to the model version returned by
    `_get_or_create_model_version`.

    Args:
        metadata: The metadata to log.
    """
    from zenml.client import Client

    model_version_id = self._get_or_create_model_version().id
    zenml_client = Client()
    zenml_client.create_run_metadata(
        metadata=metadata,
        resource_id=model_version_id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )
model_post_init(/, self, context)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Parameters:

Name Type Description Default
self BaseModel

The BaseModel instance.

required
context Any

The context.

required
Source code in zenml/model/model.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.
    Nothing is done when `__pydantic_private__` is already set on the instance.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is not None:
        return

    private_defaults = {}
    for attr_name, attr_spec in self.__private_attributes__.items():
        default_value = attr_spec.get_default()
        # private attributes without a default are left out
        if default_value is PydanticUndefined:
            continue
        private_defaults[attr_name] = default_value
    object_setattr(self, '__pydantic_private__', private_defaults)
set_stage(self, stage, force=False)

Sets this Model to a desired stage.

Parameters:

Name Type Description Default
stage Union[str, zenml.enums.ModelStages]

the target stage for model version.

required
force bool

whether to force archiving of current model version in target stage or raise.

False
Source code in zenml/model/model.py
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> None:
    """Sets this Model to a desired stage.

    The call is forwarded to the model version returned by
    `_get_or_create_model_version`.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in
            target stage or raise.
    """
    model_version = self._get_or_create_model_version()
    model_version.set_stage(stage=stage, force=force)

utils

Utility functions for linking step outputs to model versions.

Link the artifact to the model.

Parameters:

Name Type Description Default
artifact_version ArtifactVersionResponse

The artifact version to link.

required
model Optional[Model]

The model to link to.

None
is_model_artifact bool

Whether the artifact is a model artifact.

False
is_deployment_artifact bool

Whether the artifact is a deployment artifact.

False

Exceptions:

Type Description
RuntimeError

If called outside of a step.

Source code in zenml/model/utils.py
def link_artifact_to_model(
    artifact_version: ArtifactVersionResponse,
    model: Optional["Model"] = None,
    is_model_artifact: bool = False,
    is_deployment_artifact: bool = False,
) -> None:
    """Link the artifact to the model.

    When no model is passed, the model of the current step context is used.

    Args:
        artifact_version: The artifact version to link.
        model: The model to link to.
        is_model_artifact: Whether the artifact is a model artifact.
        is_deployment_artifact: Whether the artifact is a deployment artifact.

    Raises:
        RuntimeError: If called outside of a step.
    """
    if not model:
        try:
            model = get_step_context().model
        except StepContextError:
            # treated the same way as a context without a model
            model = None

        if model is None:
            raise RuntimeError(
                "`link_artifact_to_model` called without `model` parameter "
                "and configured model context cannot be identified. Consider "
                "passing the `model` explicitly or configuring it in "
                "@step or @pipeline decorator."
            )

    target_model_version = model._get_or_create_model_version()
    link_artifact_version_to_model_version(
        artifact_version=artifact_version,
        model_version=target_model_version,
        artifact_config=ArtifactConfig(
            is_model_artifact=is_model_artifact,
            is_deployment_artifact=is_deployment_artifact,
        ),
    )

Link an artifact version to a model version.

Parameters:

Name Type Description Default
artifact_version ArtifactVersionResponse

The artifact version to link.

required
model_version ModelVersionResponse

The model version to link.

required
artifact_config Optional[zenml.artifacts.artifact_config.ArtifactConfig]

Output artifact configuration.

None
Source code in zenml/model/utils.py
def link_artifact_version_to_model_version(
    artifact_version: ArtifactVersionResponse,
    model_version: ModelVersionResponse,
    artifact_config: Optional[ArtifactConfig] = None,
) -> None:
    """Link an artifact version to a model version.

    Without an artifact configuration the artifact is linked as neither a
    model artifact nor a deployment artifact.

    Args:
        artifact_version: The artifact version to link.
        model_version: The model version to link.
        artifact_config: Output artifact configuration.
    """
    is_model_artifact = (
        artifact_config.is_model_artifact if artifact_config else False
    )
    is_deployment_artifact = (
        artifact_config.is_deployment_artifact if artifact_config else False
    )

    client = Client()
    zen_store = client.zen_store
    link_request = ModelVersionArtifactRequest(
        user=client.active_user.id,
        workspace=client.active_workspace.id,
        artifact_version=artifact_version.id,
        model=model_version.model.id,
        model_version=model_version.id,
        is_model_artifact=is_model_artifact,
        is_deployment_artifact=is_deployment_artifact,
    )
    zen_store.create_model_version_artifact_link(link_request)

Links a service to a model.

Parameters:

Name Type Description Default
service_id UUID

The ID of the service to link to the model.

required
model Optional[Model]

The model to link the service to.

None
model_version_id Optional[uuid.UUID]

The ID of the model version to link the service to.

None

Exceptions:

Type Description
RuntimeError

If no model is provided and the model context cannot be identified.

Source code in zenml/model/utils.py
def link_service_to_model(
    service_id: UUID,
    model: Optional["Model"] = None,
    model_version_id: Optional[UUID] = None,
) -> None:
    """Links a service to a model.

    An explicitly passed `model_version_id` takes precedence. Otherwise the
    id is taken from the model version of `model`, which itself falls back
    to the model of the current step context when neither argument is given.

    Args:
        service_id: The ID of the service to link to the model.
        model: The model to link the service to.
        model_version_id: The ID of the model version to link the service to.

    Raises:
        RuntimeError: If no model is provided and the model context cannot be
            identified.
    """
    client = Client()

    # If no model is provided, try to get it from the context
    if not model and not model_version_id:
        try:
            model = get_step_context().model
        except StepContextError:
            # treated the same way as a context without a model
            model = None

        if model is None:
            raise RuntimeError(
                "`link_service_to_model` called without `model` parameter "
                "and configured model context cannot be identified. Consider "
                "passing the `model` explicitly or configuring it in "
                "@step or @pipeline decorator."
            )

    # Resolve the id from the model only when no id was passed explicitly.
    # The former one-liner `a or b if model else None` parsed as
    # `(a or b) if model else None` and dropped an explicitly passed
    # `model_version_id` whenever `model` was not given.
    if not model_version_id and model:
        model_version_id = model._get_or_create_model_version().id

    update_service = ServiceUpdate(model_version_id=model_version_id)
    client.zen_store.update_service(
        service_id=service_id, update=update_service
    )

log_model_metadata(metadata, model_name=None, model_version=None)

Log model version metadata.

This function can be used to log metadata for existing model versions.

Parameters:

Name Type Description Default
metadata Dict[str, MetadataType]

The metadata to log.

required
model_name Optional[str]

The name of the model to log metadata for. Can be omitted when being called inside a step with configured model in decorator.

None
model_version Union[zenml.enums.ModelStages, int, str]

The version of the model to log metadata for. Can be omitted when being called inside a step with configured model in decorator.

None

Exceptions:

Type Description
ValueError

If no model name/version is provided and the function is not called inside a step with configured model in decorator.

Source code in zenml/model/utils.py
def log_model_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.
    The explicit model is used only when both `model_name` and
    `model_version` are given; otherwise the model of the current step
    context is used.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model` in decorator.

    Raises:
        ValueError: If no model name/version is provided and the function is not
            called inside a step with configured `model` in decorator.
    """
    if model_name and model_version:
        from zenml import Model

        mv = Model(name=model_name, version=model_version)
    else:
        try:
            step_context = get_step_context()
        except RuntimeError as e:
            # chain explicitly so the original failure stays attached as the
            # cause of the ValueError
            raise ValueError(
                "Model name and version must be provided unless the function is "
                "called inside a step with configured `model` in decorator."
            ) from e
        mv = step_context.model

    mv.log_metadata(metadata)