Skip to content

Model

zenml.model special

Concepts related to the Model Control Plane feature.

model_version

ModelVersion user facing interface to pass into pipeline or step.

ModelVersion (BaseModel) pydantic-model

ModelVersion class to pass into pipeline or step to set it into a model context.

name: The name of the model. license: The license under which the model is created. description: The description of the model. audience: The target audience of the model. use_cases: The use cases of the model. limitations: The known limitations of the model. trade_offs: The tradeoffs of the model. ethics: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number or stage. It is optional and points the model context to a specific version/stage; if omitted, a new model version will be created. save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if one is available in the active stack.

Source code in zenml/model/model_version.py
class ModelVersion(BaseModel):
    """ModelVersion class to pass into pipeline or step to set it into a model context.

    name: The name of the model.
    license: The license under which the model is created.
    description: The description of the model.
    audience: The target audience of the model.
    use_cases: The use cases of the model.
    limitations: The known limitations of the model.
    trade_offs: The tradeoffs of the model.
    ethics: The ethical implications of the model.
    tags: Tags associated with the model.
    version: The model version name, number or stage. It is optional and points
        the model context to a specific version/stage. If skipped, a new model
        version will be created.
    save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
        if available in active stack.
    suppress_class_validation_warnings: Whether to skip the informational logs
        the root validator emits about how `version` will be resolved. The
        validator sets it to True after the first validation.
    was_created_in_this_run: Set to True once this object has created a new
        model version in `_get_or_create_model_version`.
    """

    name: str
    license: Optional[str] = None
    description: Optional[str] = None
    audience: Optional[str] = None
    use_cases: Optional[str] = None
    limitations: Optional[str] = None
    trade_offs: Optional[str] = None
    ethics: Optional[str] = None
    tags: Optional[List[str]] = None
    version: Optional[Union[ModelStages, int, str]] = None
    save_models_to_registry: bool = True

    suppress_class_validation_warnings: bool = False
    was_created_in_this_run: bool = False

    # Identifiers cached lazily from the Model Control Plane; they stay None
    # until the model / model version has been fetched or created.
    _model_id: Optional[UUID] = PrivateAttr(None)
    _id: Optional[UUID] = PrivateAttr(None)
    _number: Optional[int] = PrivateAttr(None)

    #########################
    #    Public methods     #
    #########################
    @property
    def id(self) -> Optional[UUID]:
        """Get version id from the Model Control Plane.

        On first access the model version is resolved (and created if
        `version` is unset); the ID is cached afterwards.

        Returns:
            ID of the model version, or None if the model version doesn't
                exist and cannot be created from the current config (e.g. a
                stage name or number was used as the version).
        """
        if self._id is None:
            try:
                self._get_or_create_model_version()
            except RuntimeError:
                logger.info(
                    f"Model version `{self.version}` doesn't exist "
                    "and cannot be fetched from the Model Control Plane."
                )
        return self._id

    @property
    def model_id(self) -> UUID:
        """Get model id from the Model Control Plane.

        The model is fetched, or created implicitly if missing, on first
        access and the ID is cached afterwards.

        Returns:
            The UUID of the model containing this model version.
        """
        if self._model_id is None:
            self._get_or_create_model()
        return self._model_id

    @property
    def number(self) -> Optional[int]:
        """Get version number from the Model Control Plane.

        On first access the model version is resolved (and created if
        `version` is unset); the number is cached afterwards.

        Returns:
            Number of the model version, or None if the model version doesn't
                exist and cannot be created from the current config (e.g. a
                stage name or number was used as the version).
        """
        if self._number is None:
            try:
                self._get_or_create_model_version()
            except RuntimeError:
                logger.info(
                    f"Model version `{self.version}` doesn't exist "
                    "and cannot be fetched from the Model Control Plane."
                )
        return self._number

    @property
    def stage(self) -> Optional[ModelStages]:
        """Get version stage from the Model Control Plane.

        Unlike `id`/`number` this is not cached: every access resolves the
        model version again.

        Returns:
            Stage of the model version, or None if the model version has no
                stage, or doesn't exist and cannot be created from the current
                config (e.g. a stage name or number was used as the version).
        """
        try:
            stage = self._get_or_create_model_version().stage
            if stage:
                return ModelStages(stage)
        except RuntimeError:
            logger.info(
                f"Model version `{self.version}` doesn't exist "
                "and cannot be fetched from the Model Control Plane."
            )
        return None

    def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
        """Load artifact from the Model Control Plane.

        Args:
            name: Name of the artifact to load.
            version: Version of the artifact to load.

        Returns:
            The loaded artifact.

        Raises:
            ValueError: if the model version is not linked to any artifact with
                the given name and version.
        """
        from zenml.artifacts.utils import load_artifact
        from zenml.models import ArtifactVersionResponse

        artifact = self.get_artifact(name=name, version=version)

        # Inside a pipeline context `get_artifact` returns a lazy
        # ExternalArtifact, which cannot be loaded here.
        if not isinstance(artifact, ArtifactVersionResponse):
            raise ValueError(
                f"Version {self.version} of model {self.name} does not have "
                f"an artifact with name {name} and version {version}."
            )

        return load_artifact(artifact.id, str(artifact.version))

    def _try_get_as_external_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional["ExternalArtifact"]:
        """Wrap the artifact lookup as an ExternalArtifact inside a pipeline.

        Args:
            name: The name of the artifact.
            version: The version of the artifact.

        Returns:
            An ExternalArtifact bound to this model version when a pipeline
            context is active, otherwise None.
        """
        from zenml import ExternalArtifact, get_pipeline_context

        try:
            get_pipeline_context()
        except RuntimeError:
            return None

        ea = ExternalArtifact(name=name, version=version, model_version=self)
        return ea

    def get_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
        """Get the artifact linked to this model version.

        Args:
            name: The name of the artifact to retrieve.
            version: The version of the artifact to retrieve (None for latest/non-versioned)

        Returns:
            Inside pipeline context: ExternalArtifact object as a lazy loader
            Outside of pipeline context: Specific version of the artifact or None
        """
        if response := self._try_get_as_external_artifact(name, version):
            return response
        return self._get_or_create_model_version().get_artifact(
            name=name,
            version=version,
        )

    def get_model_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
        """Get the model artifact linked to this model version.

        Args:
            name: The name of the model artifact to retrieve.
            version: The version of the model artifact to retrieve (None for latest/non-versioned)

        Returns:
            Inside pipeline context: ExternalArtifact object as a lazy loader
            Outside of pipeline context: Specific version of the model artifact or None
        """
        if response := self._try_get_as_external_artifact(name, version):
            return response
        return self._get_or_create_model_version().get_model_artifact(
            name=name,
            version=version,
        )

    def get_data_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
        """Get the data artifact linked to this model version.

        Args:
            name: The name of the data artifact to retrieve.
            version: The version of the data artifact to retrieve (None for latest/non-versioned)

        Returns:
            Inside pipeline context: ExternalArtifact object as a lazy loader
            Outside of pipeline context: Specific version of the data artifact or None
        """
        if response := self._try_get_as_external_artifact(name, version):
            return response
        return self._get_or_create_model_version().get_data_artifact(
            name=name,
            version=version,
        )

    def get_deployment_artifact(
        self,
        name: str,
        version: Optional[str] = None,
    ) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
        """Get the deployment artifact linked to this model version.

        Args:
            name: The name of the deployment artifact to retrieve.
            version: The version of the deployment artifact to retrieve (None for latest/non-versioned)

        Returns:
            Inside pipeline context: ExternalArtifact object as a lazy loader
            Outside of pipeline context: Specific version of the deployment artifact or None
        """
        if response := self._try_get_as_external_artifact(name, version):
            return response
        return self._get_or_create_model_version().get_deployment_artifact(
            name=name,
            version=version,
        )

    def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
        """Get pipeline run linked to this version.

        Args:
            name: The name of the pipeline run to retrieve.

        Returns:
            PipelineRun as PipelineRunResponse
        """
        return self._get_or_create_model_version().get_pipeline_run(name=name)

    def set_stage(
        self, stage: Union[str, ModelStages], force: bool = False
    ) -> None:
        """Sets this Model Version to a desired stage.

        Args:
            stage: the target stage for model version.
            force: whether to force archiving of current model version in target stage or raise.
        """
        self._get_or_create_model_version().set_stage(stage=stage, force=force)

    def log_metadata(
        self,
        metadata: Dict[str, "MetadataType"],
    ) -> None:
        """Log model version metadata.

        This function can be used to log metadata for current model version.

        Args:
            metadata: The metadata to log.
        """
        from zenml.client import Client

        response = self._get_or_create_model_version()
        Client().create_run_metadata(
            metadata=metadata,
            resource_id=response.id,
            resource_type=MetadataResourceTypes.MODEL_VERSION,
        )

    @property
    def metadata(self) -> Dict[str, "MetadataType"]:
        """Get model version metadata.

        Returns:
            The model version metadata, mapping each metadata key to its value.

        Raises:
            RuntimeError: If the model version metadata cannot be fetched.
        """
        response = self._get_or_create_model_version(hydrate=True)
        if response.run_metadata is None:
            raise RuntimeError(
                "Failed to fetch metadata of this model version."
            )
        # Distinct loop names so the comprehension does not shadow `response`.
        return {
            key: entry.value for key, entry in response.run_metadata.items()
        }

    #########################
    #   Internal methods    #
    #########################

    class Config:
        """Config class."""

        smart_union = True

    def __eq__(self, other: object) -> bool:
        """Check two ModelVersions for equality.

        Instances with different model names are never equal; identical
        (name, version) pairs are always equal. Otherwise both are resolved
        against the Model Control Plane and their model version IDs compared.
        Note that resolving may create a model version when `version` is
        unset (see `_get_or_create_model_version`).

        Args:
            other: object to compare with

        Returns:
            True, if equal, False otherwise.
        """
        if not isinstance(other, ModelVersion):
            return NotImplemented
        if self.name != other.name:
            return False
        if self.version == other.version:
            return True
        self_mv = self._get_or_create_model_version()
        other_mv = other._get_or_create_model_version()
        return self_mv.id == other_mv.id

    @root_validator(pre=True)
    def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        Logs how `version` will be resolved (by stage or by number) unless
        warnings are suppressed, then suppresses them for later validations.

        Args:
            values: Dict of values.

        Returns:
            Dict of validated values.
        """
        suppress_class_validation_warnings = values.get(
            "suppress_class_validation_warnings", False
        )
        version = values.get("version", None)

        if (
            version in [stage.value for stage in ModelStages]
            and not suppress_class_validation_warnings
        ):
            logger.info(
                f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage."
            )
        if str(version).isnumeric() and not suppress_class_validation_warnings:
            logger.info(
                f"`version` `{version}` is numeric and will be fetched using version number."
            )
        values["suppress_class_validation_warnings"] = True
        return values

    def _validate_config_in_runtime(self) -> None:
        """Validate that config doesn't conflict with runtime environment."""
        self._get_or_create_model_version()

    def _get_or_create_model(self) -> "ModelResponse":
        """This method should get or create a model from Model Control Plane.

        New model is created implicitly, if missing, otherwise fetched. When an
        existing model is found, a warning is logged for every configured
        attribute that differs from the stored one (the stored model is not
        updated).

        Returns:
            The model based on configuration.
        """
        from zenml.client import Client
        from zenml.models import ModelRequest

        zenml_client = Client()
        try:
            model = zenml_client.zen_store.get_model(
                model_name_or_id=self.name
            )

            difference = {}
            for key in (
                "license",
                "audience",
                "use_cases",
                "limitations",
                "trade_offs",
                "ethics",
                "save_models_to_registry",
            ):
                if getattr(self, key) != getattr(model, key):
                    difference[key] = {
                        "config": getattr(self, key),
                        "db": getattr(model, key),
                    }

            if difference:
                logger.warning(
                    "Provided model configuration does not match "
                    f"existing model `{self.name}` with the "
                    f"following changes: {difference}. If you want to "
                    "update the model configuration, please use the "
                    "`zenml model update` command."
                )
        except KeyError:
            model_request = ModelRequest(
                name=self.name,
                license=self.license,
                description=self.description,
                audience=self.audience,
                use_cases=self.use_cases,
                limitations=self.limitations,
                trade_offs=self.trade_offs,
                ethics=self.ethics,
                tags=self.tags,
                user=zenml_client.active_user.id,
                workspace=zenml_client.active_workspace.id,
                save_models_to_registry=self.save_models_to_registry,
            )
            model_request = ModelRequest.parse_obj(model_request)
            try:
                zenml_client.zen_store.create_model(model=model_request)
                logger.info(f"New model `{self.name}` was created implicitly.")
            except EntityExistsError:
                # this is backup logic, if model was created somehow in between get and create calls
                pass
            # Re-fetch the stored model in both cases. This is deliberately not
            # in a `finally`: an unexpected create error must propagate instead
            # of being masked by a failing lookup.
            model = zenml_client.zen_store.get_model(
                model_name_or_id=self.name
            )
        self._model_id = model.id
        return model

    def _get_model_version(self) -> "ModelVersionResponse":
        """This method gets a model version from Model Control Plane.

        Logs a warning if the configured description or tags differ from the
        stored model version (the stored version is not updated). Lookup
        errors from the client propagate; the caller treats `KeyError` as
        "not found".

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client

        zenml_client = Client()
        mv = zenml_client.get_model_version(
            model_name_or_id=self.name,
            model_version_name_or_number_or_id=self.version,
        )
        if not self._id:
            self._id = mv.id

        difference: Dict[str, Any] = {}
        if mv.description != self.description:
            difference["description"] = {
                "config": self.description,
                "db": mv.description,
            }
        configured_tags = set(self.tags or [])
        db_tags = {t.name for t in mv.tags}
        if db_tags != configured_tags:
            difference["tags added"] = list(configured_tags - db_tags)
            difference["tags removed"] = list(db_tags - configured_tags)
        if difference:
            logger.warning(
                "Provided model version configuration does not match existing model "
                f"version `{self.name}::{self.version}` with the following "
                f"changes: {difference}. If you want to update the model version "
                "configuration, please use the `zenml model version update` command."
            )

        return mv

    def _get_or_create_model_version(
        self, hydrate: bool = False
    ) -> "ModelVersionResponse":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model is fetched. Model
        name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model version:
        - If `version` is None, a new model version is created, if not created by other steps in same run.
        - If `version` is not None a model version will be fetched based on the version:
            - If `version` is set to an integer or digit string, the model version with the matching number will be fetched.
            - If `version` is set to a string, the model version with the matching version will be fetched.
            - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

        Side effects: may set `self.version`, `self.was_created_in_this_run`
        and the cached `_id`, `_model_id` and `_number`.

        Args:
            hydrate: Whether to return a hydrated version of the model version.
                NOTE(review): this argument is not used anywhere in the body
                below — confirm whether it should be forwarded to the lookup.

        Returns:
            The model version based on configuration.

        Raises:
            RuntimeError: if the model version needs to be created, but provided name is reserved
        """
        from zenml.client import Client
        from zenml.models import ModelVersionRequest

        model = self._get_or_create_model()

        zenml_client = Client()
        model_version_request = ModelVersionRequest(
            user=zenml_client.active_user.id,
            workspace=zenml_client.active_workspace.id,
            name=self.version,
            description=self.description,
            model=model.id,
            tags=self.tags,
        )
        mv_request = ModelVersionRequest.parse_obj(model_version_request)
        try:
            if not self.version:
                try:
                    from zenml import get_step_context

                    context = get_step_context()
                except RuntimeError:
                    pass
                else:
                    # if inside a step context we loop over all
                    # model version configuration to find, if the
                    # model version for current model was already
                    # created in the current run, not to create
                    # new model versions
                    pipeline_mv = context.pipeline_run.config.model_version
                    if (
                        pipeline_mv
                        and pipeline_mv.was_created_in_this_run
                        and pipeline_mv.name == self.name
                        and pipeline_mv.version is not None
                    ):
                        self.version = pipeline_mv.version
                    else:
                        for step in context.pipeline_run.steps.values():
                            step_mv = step.config.model_version
                            if (
                                step_mv
                                and step_mv.was_created_in_this_run
                                and step_mv.name == self.name
                                and step_mv.version is not None
                            ):
                                self.version = step_mv.version
                                break
            if self.version:
                model_version = self._get_model_version()
            else:
                # No version to look up: fall through to creation below.
                raise KeyError
        except KeyError:
            if (
                self.version
                and str(self.version).lower() in ModelStages.values()
            ):
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "it matches one of the possible model version stages. If you "
                    "are aiming to fetch model version by stage, check if the "
                    "model version in given stage exists. It might be missing, if "
                    "the pipeline promoting model version to this stage failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list {self.name}` CLI command."
                )
            if str(self.version).isnumeric():
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "numeric model version names are reserved. If you "
                    "are aiming to fetch model version by number, check if the "
                    "model version with given number exists. It might be missing, if "
                    "the pipeline creating model version failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list {self.name}` CLI command."
                )
            model_version = zenml_client.zen_store.create_model_version(
                model_version=mv_request
            )
            self.version = model_version.name
            self.was_created_in_this_run = True
            logger.info(f"New model version `{self.version}` was created.")
        self._id = model_version.id
        self._model_id = model_version.model.id
        self._number = model_version.number
        return model_version

    def _merge(self, model_version: "ModelVersion") -> None:
        """Fill unset model attributes from another ModelVersion, in place.

        Attributes already set (truthy) on `self` win. Tags are unioned, so
        their order is not preserved. `name`, `version` and
        `save_models_to_registry` are not merged.

        Args:
            model_version: The ModelVersion to take missing values from.
        """
        self.license = self.license or model_version.license
        self.description = self.description or model_version.description
        self.audience = self.audience or model_version.audience
        self.use_cases = self.use_cases or model_version.use_cases
        self.limitations = self.limitations or model_version.limitations
        self.trade_offs = self.trade_offs or model_version.trade_offs
        self.ethics = self.ethics or model_version.ethics
        if model_version.tags is not None:
            self.tags = list(
                {t for t in self.tags or []}.union(set(model_version.tags))
            )

    def __hash__(self) -> int:
        """Get hash of the `ModelVersion`.

        NOTE(review): `__eq__` can report two instances as equal via the
        model version ID fallback even when their `version` values differ
        (e.g. a stage vs. a version name), yet their hashes differ here.
        That breaks the hash/eq contract for sets and dict keys; confirm
        whether callers rely on it.

        Returns:
            Hash function results
        """
        return hash(
            "::".join(
                (
                    str(v)
                    for v in (
                        self.name,
                        self.version,
                    )
                )
            )
        )
id: UUID property readonly

Get version id from the Model Control Plane.

Returns:

Type Description
UUID

ID of the model version, or None if the model version doesn't exist and cannot be created from the current config (for example, when a stage name or number was used as the version).

metadata: Dict[str, MetadataType] property readonly

Get model version metadata.

Returns:

Type Description
Dict[str, MetadataType]

The model version metadata.

Exceptions:

Type Description
RuntimeError

If the model version metadata cannot be fetched.

model_id: UUID property readonly

Get model id from the Model Control Plane.

Returns:

Type Description
UUID

The UUID of the model containing this model version.

number: int property readonly

Get version number from the Model Control Plane.

Returns:

Type Description
int

Number of the model version, or None if the model version doesn't exist and cannot be created from the current config (for example, when a stage name or number was used as the version).

stage: Optional[zenml.enums.ModelStages] property readonly

Get version stage from the Model Control Plane.

Returns:

Type Description
Optional[zenml.enums.ModelStages]

Stage of the model version, or None if the model version has no stage or doesn't exist and cannot be created from the current config (for example, when a stage name or number was used as the version).

Config

Config class.

Source code in zenml/model/model_version.py
class Config:
    """Pydantic model configuration for `ModelVersion`."""

    # pydantic v1 option: prefer an exact type match for Union-typed fields.
    smart_union = True
__eq__(self, other) special

Check two ModelVersions for equality.

Parameters:

Name Type Description Default
other object

object to compare with

required

Returns:

Type Description
bool

True, if equal, False otherwise.

Source code in zenml/model/model_version.py
def __eq__(self, other: object) -> bool:
    """Check two ModelVersions for equality.

    Instances with different model names are never equal; identical
    (name, version) pairs are always equal. Otherwise both are resolved
    against the Model Control Plane and their model version IDs compared,
    which may create a model version when `version` is unset.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise.
    """
    if not isinstance(other, ModelVersion):
        return NotImplemented
    if self.name != other.name:
        return False
    # Names are known to match here, so only the version needs comparing.
    if self.version == other.version:
        return True
    self_mv = self._get_or_create_model_version()
    other_mv = other._get_or_create_model_version()
    return self_mv.id == other_mv.id
__hash__(self) special

Get hash of the ModelVersion.

Returns:

Type Description
int

Hash function results

Source code in zenml/model/model_version.py
def __hash__(self) -> int:
    """Hash a `ModelVersion` by its model name and configured version.

    Returns:
        Hash of the string "<name>::<version>".
    """
    # `!s` forces str() so enum values render exactly as with str(v).
    key = f"{self.name!s}::{self.version!s}"
    return hash(key)
get_artifact(self, name, version=None)

Get the artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the artifact to retrieve.

required
version Optional[str]

The version of the artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Inside pipeline context

ExternalArtifact object as a lazy loader Outside of pipeline context: Specific version of the artifact or None

Source code in zenml/model/model_version.py
def get_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Return the artifact with the given name/version linked to this model version.

    Args:
        name: The name of the artifact to retrieve.
        version: The version of the artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    resolved_version = self._get_or_create_model_version()
    return resolved_version.get_artifact(name=name, version=version)
get_data_artifact(self, name, version=None)

Get the data artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the data artifact to retrieve.

required
version Optional[str]

The version of the data artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Inside pipeline context

ExternalArtifact object as a lazy loader Outside of pipeline context: Specific version of the data artifact or None

Source code in zenml/model/model_version.py
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Return the data artifact with the given name/version linked to this model version.

    Args:
        name: The name of the data artifact to retrieve.
        version: The version of the data artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the data artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    resolved_version = self._get_or_create_model_version()
    return resolved_version.get_data_artifact(name=name, version=version)
get_deployment_artifact(self, name, version=None)

Get the deployment artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the deployment artifact to retrieve.

required
version Optional[str]

The version of the deployment artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Inside pipeline context

ExternalArtifact object as a lazy loader Outside of pipeline context: Specific version of the deployment artifact or None

Source code in zenml/model/model_version.py
def get_deployment_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Return the deployment artifact with the given name/version linked to this model version.

    Args:
        name: The name of the deployment artifact to retrieve.
        version: The version of the deployment artifact to retrieve (None for latest/non-versioned)

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the deployment artifact or None
    """
    lazy_artifact = self._try_get_as_external_artifact(name, version)
    if lazy_artifact:
        return lazy_artifact
    resolved_version = self._get_or_create_model_version()
    return resolved_version.get_deployment_artifact(name=name, version=version)
get_model_artifact(self, name, version=None)

Get the model artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the model artifact to retrieve.

required
version Optional[str]

The version of the model artifact to retrieve (None for latest/non-versioned)

None

Returns:

Type Description
Inside pipeline context

ExternalArtifact object as a lazy loader Outside of pipeline context: Specific version of the model artifact or None

Source code in zenml/model/model_version.py
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]:
    """Get the model artifact linked to this model version.

    Args:
        name: Name of the model artifact to look up.
        version: Version of the model artifact; `None` selects the latest
            (or non-versioned) artifact.

    Returns:
        Inside pipeline context: ExternalArtifact object as a lazy loader
        Outside of pipeline context: Specific version of the model artifact or None
    """
    # Any truthy result from the external-artifact lookup is returned as-is
    # (per the Returns section, this is the in-pipeline lazy-loader case).
    external_artifact = self._try_get_as_external_artifact(name, version)
    if external_artifact:
        return external_artifact

    # Otherwise resolve the model version and query it directly.
    model_version_response = self._get_or_create_model_version()
    return model_version_response.get_model_artifact(
        name=name, version=version
    )
get_pipeline_run(self, name)

Get pipeline run linked to this version.

Parameters:

Name Type Description Default
name str

The name of the pipeline run to retrieve.

required

Returns:

Type Description
PipelineRunResponse

PipelineRun as PipelineRunResponse

Source code in zenml/model/model_version.py
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Get a pipeline run linked to this model version.

    Args:
        name: Name of the pipeline run to look up.

    Returns:
        PipelineRun as PipelineRunResponse
    """
    model_version_response = self._get_or_create_model_version()
    return model_version_response.get_pipeline_run(name=name)
load_artifact(self, name, version=None)

Load artifact from the Model Control Plane.

Parameters:

Name Type Description Default
name str

Name of the artifact to load.

required
version Optional[str]

Version of the artifact to load.

None

Returns:

Type Description
Any

The loaded artifact.

Exceptions:

Type Description
ValueError

if the model version is not linked to any artifact with the given name and version.

Source code in zenml/model/model_version.py
def load_artifact(self, name: str, version: Optional[str] = None) -> Any:
    """Load an artifact linked to this model version from the Model Control Plane.

    Args:
        name: Name of the artifact to load.
        version: Version of the artifact to load.

    Returns:
        The loaded artifact.

    Raises:
        ValueError: if the model version is not linked to any artifact with
            the given name and version.
    """
    from zenml.artifacts.utils import load_artifact
    from zenml.models import ArtifactVersionResponse

    linked = self.get_artifact(name=name, version=version)
    if isinstance(linked, ArtifactVersionResponse):
        return load_artifact(linked.id, str(linked.version))

    # Only a concrete ArtifactVersionResponse can be loaded; any other
    # result from `get_artifact` is reported as a missing artifact.
    raise ValueError(
        f"Version {self.version} of model {self.name} does not have "
        f"an artifact with name {name} and version {version}."
    )
log_metadata(self, metadata)

Log model version metadata.

This function can be used to log metadata for current model version.

Parameters:

Name Type Description Default
metadata Dict[str, MetadataType]

The metadata to log.

required
Source code in zenml/model/model_version.py
def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Attach metadata to the model version behind this object.

    Metadata is recorded as run metadata against the model version
    resource.

    Args:
        metadata: The metadata to log.
    """
    from zenml.client import Client

    version_response = self._get_or_create_model_version()
    zen_client = Client()
    zen_client.create_run_metadata(
        metadata=metadata,
        resource_id=version_response.id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )
set_stage(self, stage, force=False)

Sets this Model Version to a desired stage.

Parameters:

Name Type Description Default
stage Union[str, zenml.enums.ModelStages]

the target stage for model version.

required
force bool

whether to force archiving of current model version in target stage or raise.

False
Source code in zenml/model/model_version.py
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> None:
    """Move this model version to the requested stage.

    Args:
        stage: the target stage for model version.
        force: whether to force archiving of current model version in target stage or raise.
    """
    model_version_response = self._get_or_create_model_version()
    model_version_response.set_stage(stage=stage, force=force)

utils

Utility functions for linking step outputs to model versions.

link_artifact_config_to_model_version(artifact_config, artifact_version_id, model_version=None)

Link an artifact config to its model version.

Parameters:

Name Type Description Default
artifact_config ArtifactConfig

The artifact config to link.

required
artifact_version_id UUID

The ID of the artifact to link.

required
model_version Optional[ModelVersion]

The model version from the step or pipeline context.

None
Source code in zenml/model/utils.py
def link_artifact_config_to_model_version(
    artifact_config: ArtifactConfig,
    artifact_version_id: UUID,
    model_version: Optional["ModelVersion"] = None,
) -> None:
    """Link an artifact version to a model version as described by its config.

    Args:
        artifact_config: The artifact config to link.
        artifact_version_id: The ID of the artifact to link.
        model_version: The model version from the step or pipeline context.
            Ignored when `artifact_config` names a model itself.
    """
    client = Client()

    # A model named explicitly on the artifact config takes precedence over
    # the model version passed in from the step/pipeline context.
    if artifact_config.model_name is not None:
        from zenml.model.model_version import ModelVersion

        model_version = ModelVersion(
            name=artifact_config.model_name,
            version=artifact_config.model_version,
        )

    # Nothing to link against: neither the config nor the context gave a model.
    if not model_version:
        return

    # Presumably ensures the model version exists (per the method name)
    # before it is fetched on the next line.
    model_version._get_or_create_model_version()
    target_version = model_version._get_model_version()
    link_request = ModelVersionArtifactRequest(
        user=client.active_user.id,
        workspace=client.active_workspace.id,
        artifact_version=artifact_version_id,
        model=target_version.model.id,
        model_version=target_version.id,
        is_model_artifact=artifact_config.is_model_artifact,
        is_deployment_artifact=artifact_config.is_deployment_artifact,
    )
    client.zen_store.create_model_version_artifact_link(link_request)

link_step_artifacts_to_model(artifact_version_ids)

Links the output artifacts of a step to the model.

Parameters:

Name Type Description Default
artifact_version_ids Dict[str, uuid.UUID]

The IDs of the published output artifacts.

required

Exceptions:

Type Description
RuntimeError

If called outside of a step.

Source code in zenml/model/utils.py
def link_step_artifacts_to_model(
    artifact_version_ids: Dict[str, UUID],
) -> None:
    """Link the published output artifacts of the current step to a model.

    Outputs with an explicit artifact config are linked according to that
    config. Outputs without one are implicitly linked to the step's model
    version, but only when the step context provides one.

    Args:
        artifact_version_ids: The IDs of the published output artifacts.

    Raises:
        RuntimeError: If called outside of a step.
    """
    try:
        context = get_step_context()
    except StepContextError:
        raise RuntimeError(
            "`link_step_artifacts_to_model` can only be called from within a "
            "step."
        )

    # A step without a configured model version can still link outputs whose
    # artifact config names a model explicitly.
    try:
        context_model_version = context.model_version
    except StepContextError:
        context_model_version = None
        logger.debug("No model context found, unable to auto-link artifacts.")

    for output_name, output_version_id in artifact_version_ids.items():
        output_config = context._get_output(output_name).artifact_config

        # Outputs without a config get a default one only if there is a
        # model version in context to link them to.
        needs_implicit_config = (
            output_config is None and context_model_version is not None
        )
        if needs_implicit_config:
            output_config = ArtifactConfig(name=output_name)
            logger.info(
                f"Implicitly linking artifact `{output_name}` to model "
                f"`{context_model_version.name}` version `{context_model_version.version}`."
            )

        if not output_config:
            continue

        link_artifact_config_to_model_version(
            artifact_config=output_config,
            artifact_version_id=output_version_id,
            model_version=context_model_version,
        )

log_model_metadata(metadata, model_name=None, model_version=None)

Log model version metadata.

This function can be used to log metadata for existing model versions.

Parameters:

Name Type Description Default
metadata Dict[str, MetadataType]

The metadata to log.

required
model_name Optional[str]

The name of the model to log metadata for. Can be omitted when being called inside a step with configured model_version in decorator.

None
model_version Union[zenml.enums.ModelStages, str, int]

The version of the model to log metadata for. Can be omitted when being called inside a step with configured model_version in decorator.

None

Exceptions:

Type Description
ValueError

If no model name/version is provided and the function is not called inside a step with configured model_version in decorator.

Source code in zenml/model/utils.py
def log_model_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.
    When both `model_name` and `model_version` are given, they always select
    the target model version. Otherwise the model version configured for the
    currently running step is used.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.

    Raises:
        ValueError: If no model name/version is provided and the function is not
            called inside a step with configured `model_version` in decorator.
    """
    if model_name and model_version:
        # Explicit arguments take precedence over the step context, so that
        # code running inside a step can target a different model version.
        from zenml import ModelVersion

        mv = ModelVersion(name=model_name, version=model_version)
    else:
        try:
            mv = get_step_context().model_version
        except (RuntimeError, StepContextError) as e:
            # RuntimeError: no step is running. StepContextError: the step has
            # no configured model version (the same error that
            # `link_step_artifacts_to_model` handles for this property).
            raise ValueError(
                "Model name and version must be provided unless the function is "
                "called inside a step with configured `model_version` in decorator."
            ) from e

    mv.log_metadata(metadata)

log_model_version_metadata(metadata, model_name=None, model_version=None)

Log model version metadata.

This function can be used to log metadata for existing model versions.

Parameters:

Name Type Description Default
metadata Dict[str, MetadataType]

The metadata to log.

required
model_name Optional[str]

The name of the model to log metadata for. Can be omitted when being called inside a step with configured model_version in decorator.

None
model_version Union[zenml.enums.ModelStages, str, int]

The version of the model to log metadata for. Can be omitted when being called inside a step with configured model_version in decorator.

None
Source code in zenml/model/utils.py
def log_model_version_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata (deprecated alias of `log_model_metadata`).

    Emits a deprecation warning and then forwards all arguments unchanged to
    `log_model_metadata`.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.
    """
    deprecation_message = (
        "`log_model_version_metadata` is deprecated. Please use "
        "`log_model_metadata` instead."
    )
    logger.warning(deprecation_message)
    log_model_metadata(
        metadata=metadata,
        model_name=model_name,
        model_version=model_version,
    )