Skip to content

Model

zenml.model special

Initialization of the ZenML model package. The ZenML model package supports the Model Control Plane feature.

artifact_config

Artifact Config classes to support Model Control Plane feature.

DataArtifactConfig (BaseModel) pydantic-model

Used to link a data artifact to the model version.

model_name: The name of the model to link the data artifact to. model_version: The identifier of the model version to link the data artifact to. It can be an exact version name ("23"), an exact version number (42), a stage (ModelStages.PRODUCTION), or ModelStages.LATEST for the latest version. model_stage: The stage of the model version to link the artifact to. artifact_name: An override name for the link, used instead of the artifact name. overwrite: Whether to overwrite an existing link or create new versions.

Source code in zenml/model/artifact_config.py
class DataArtifactConfig(BaseModel):
    """Used to link a data artifact to the model version.

    model_name: The name of the model to link data artifact to.
    model_version: The identifier of the model version to link data artifact to.
        It can be exact version ("23"), exact version number (42), stage
        (ModelStages.PRODUCTION) or ModelStages.LATEST for the latest version.
        Must be given whenever `model_name` is given.
    artifact_name: The override name of a link instead of an artifact name.
    overwrite: Whether to overwrite an existing link or create new versions.
    """

    model_name: Optional[str]
    model_version: Optional[Union[ModelStages, str, int]]
    artifact_name: Optional[str]
    overwrite: bool = False

    # Private attributes without a default; they are read when the link is
    # created in `_link_to_model_version`, so presumably the caller assigns
    # them beforehand — TODO confirm where they are set.
    _pipeline_name: str = PrivateAttr()
    _step_name: str = PrivateAttr()
    # Class-level flags that select the link kind; subclasses override them.
    IS_MODEL_ARTIFACT: ClassVar[bool] = False
    IS_ENDPOINT_ARTIFACT: ClassVar[bool] = False

    @root_validator
    def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject configs that name a model without naming a model version.

        Args:
            values: Dict of field values.

        Returns:
            The unchanged dict of values.

        Raises:
            ValueError: If `model_name` is set but `model_version` is None.
        """
        model_name = values.get("model_name")
        if model_name and values.get("model_version") is None:
            raise ValueError(
                f"Creation of new model version from `{cls}` is not allowed. "
                "Please either keep `model_name` and `model_version` both "
                "`None` to get the model version from the step context or "
                "specify both at the same time. You can use `ModelStages.LATEST` "
                "as `model_version` when latest model version is desired."
            )
        return values

    class Config:
        """Config class for ArtifactConfig."""

        smart_union = True

    @property
    def _model_version(self) -> "ModelVersion":
        """Property that returns the model version.

        The model version from the current step context is used, unless
        `model_name` is set and differs from the context's model name. In
        that case a `ModelVersion` is built on the fly from `model_name` and
        `model_version`.

        Returns:
            ModelVersion: The model version.

        Raises:
            RuntimeError: If model version cannot be acquired from @step
                or @pipeline or built on the fly from fields of this class.
        """
        try:
            model_version = get_step_context().model_version
        except (StepContextError, RuntimeError):
            # Not running inside a step: there is no context model version.
            model_version = None
        # Check if a specific model name is provided and it doesn't match the context name.
        # NOTE(review): when `model_name` equals the context model's name, the
        # context model version is returned even if `self.model_version`
        # points at a different version — confirm this is intended.
        if (self.model_name is not None) and (
            model_version is None or model_version.name != self.model_name
        ):
            # Create a new ModelVersion instance with the provided model name and version
            from zenml.model.model_version import ModelVersion

            on_the_fly_config = ModelVersion(
                name=self.model_name,
                version=self.model_version,
            )
            return on_the_fly_config

        if model_version is None:
            raise RuntimeError(
                "No model version configuration found in @step or @pipeline. "
                "You can configure model version inside ArtifactConfig as well, but "
                "`model_name` and `model_version` must be provided."
            )
        # Return the model from the context
        return model_version

    def _link_to_model_version(
        self,
        artifact_uuid: UUID,
        model_version: "ModelVersion",
        is_model_artifact: bool = False,
        is_endpoint_artifact: bool = False,
    ) -> None:
        """Link artifact to the model version.

        This method is used on exit from the step context to link artifact to the model version.

        If links with the same name, kind, pipeline and step already exist,
        they are deleted first when `overwrite` is set. Otherwise the new link
        is added alongside them.

        Args:
            artifact_uuid: The UUID of the artifact to link.
            model_version: The model version from caller.
            is_model_artifact: Whether the artifact is a model artifact. Defaults to False.
            is_endpoint_artifact: Whether the artifact is an endpoint artifact. Defaults to False.
        """
        from zenml.client import Client
        from zenml.models.model_models import (
            ModelVersionArtifactFilterModel,
            ModelVersionArtifactRequestModel,
        )

        # Create a ZenML client
        client = Client()

        # Fall back to the stored artifact's own name when no override is set.
        artifact_name = self.artifact_name
        if artifact_name is None:
            artifact = client.zen_store.get_artifact(artifact_id=artifact_uuid)
            artifact_name = artifact.name

        # Create a request model for the model version artifact link
        request = ModelVersionArtifactRequestModel(
            user=client.active_user.id,
            workspace=client.active_workspace.id,
            name=artifact_name,
            artifact=artifact_uuid,
            model=model_version.model_id,
            model_version=model_version.id,
            is_model_artifact=is_model_artifact,
            is_endpoint_artifact=is_endpoint_artifact,
            overwrite=self.overwrite,
            pipeline_name=self._pipeline_name,
            step_name=self._step_name,
        )

        # Look up existing links of the same name and kind produced by the
        # same pipeline and step.
        existing_links = client.list_model_version_artifact_links(
            model_version_id=model_version.id,
            model_version_artifact_link_filter_model=ModelVersionArtifactFilterModel(
                user_id=client.active_user.id,
                workspace_id=client.active_workspace.id,
                name=artifact_name,
                only_data_artifacts=not (
                    is_model_artifact or is_endpoint_artifact
                ),
                only_endpoint_artifacts=is_endpoint_artifact,
                only_model_artifacts=is_model_artifact,
                pipeline_name=self._pipeline_name,
                step_name=self._step_name,
            ),
        )
        if len(existing_links):
            if self.overwrite:
                # delete all model version artifact links by name
                logger.warning(
                    f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
                )

                client.zen_store.delete_model_version_artifact_link(
                    model_version_id=model_version.id,
                    model_version_artifact_link_name_or_id=artifact_name,
                )
            else:
                logger.info(
                    f"Artifact link `{artifact_name}` already exists, adding new version."
                )
        # Create the model version artifact link.
        client.zen_store.create_model_version_artifact_link(request)

    def link_to_model(
        self, artifact_uuid: UUID, model_version: "ModelVersion"
    ) -> None:
        """Link artifact to the model version.

        The link kind is taken from the class flags `IS_MODEL_ARTIFACT` and
        `IS_ENDPOINT_ARTIFACT`.

        Args:
            artifact_uuid: The UUID of the artifact to link.
            model_version: The model version from caller.
        """
        self._link_to_model_version(
            artifact_uuid,
            model_version=model_version,
            is_model_artifact=self.IS_MODEL_ARTIFACT,
            is_endpoint_artifact=self.IS_ENDPOINT_ARTIFACT,
        )
Config

Config class for ArtifactConfig.

Source code in zenml/model/artifact_config.py
class Config:
    """Pydantic configuration for the ArtifactConfig models.

    Turns on pydantic's ``smart_union`` option for ``Union``-typed fields.
    """

    smart_union = True

Link artifact to the model version.

Parameters:

Name Type Description Default
artifact_uuid UUID

The UUID of the artifact to link.

required
model_version ModelVersion

The model version from caller.

required
Source code in zenml/model/artifact_config.py
def link_to_model(
    self, artifact_uuid: UUID, model_version: "ModelVersion"
) -> None:
    """Link an artifact to a model version using this config's link kind.

    The kind (model / endpoint / data) is read from the class-level flags
    ``IS_MODEL_ARTIFACT`` and ``IS_ENDPOINT_ARTIFACT``. The work itself is
    delegated to ``_link_to_model_version``.

    Args:
        artifact_uuid: The UUID of the artifact to link.
        model_version: The model version from caller.
    """
    model_flag = self.IS_MODEL_ARTIFACT
    endpoint_flag = self.IS_ENDPOINT_ARTIFACT
    self._link_to_model_version(
        artifact_uuid,
        model_version=model_version,
        is_model_artifact=model_flag,
        is_endpoint_artifact=endpoint_flag,
    )

EndpointArtifactConfig (DataArtifactConfig) pydantic-model

Used to link an endpoint artifact to the model version.

Source code in zenml/model/artifact_config.py
class EndpointArtifactConfig(DataArtifactConfig):
    """Artifact config that links an endpoint artifact to a model version.

    Behaves like its parent, except that the class-level flag
    ``IS_ENDPOINT_ARTIFACT`` is switched on.
    """

    IS_ENDPOINT_ARTIFACT = True

ModelArtifactConfig (DataArtifactConfig) pydantic-model

Used to link a model artifact to the model version.

save_to_model_registry: Whether to save the model artifact to the model registry.

Source code in zenml/model/artifact_config.py
class ModelArtifactConfig(DataArtifactConfig):
    """Artifact config that links a model artifact to a model version.

    Behaves like its parent, except that the class-level flag
    ``IS_MODEL_ARTIFACT`` is switched on.

    save_to_model_registry: Whether to save the model artifact to the model registry.
    """

    IS_MODEL_ARTIFACT = True
    save_to_model_registry: bool = True

Utility functions for linking step outputs to model versions.

Link a step output to a model version by registering an artifact config for it in the current step context.

Parameters:

Name Type Description Default
output_name Optional[str]

The output name of the artifact to link. Can be omitted if there is only one output artifact.

None
artifact_config DataArtifactConfig

The ArtifactConfig of how to link this output.

required
Source code in zenml/model/link_output_to_model.py
def link_output_to_model(
    artifact_config: "DataArtifactConfig",
    output_name: Optional[str] = None,
) -> None:
    """Link a step output to a model version via an artifact config.

    The config is registered for the given output on the current step
    context.

    Args:
        artifact_config: The ArtifactConfig describing how to link this output.
        output_name: The output name of the artifact to link. Can be omitted
            if there is only one output artifact.
    """
    from zenml.new.steps.step_context import get_step_context

    step_context = get_step_context()
    step_context._set_artifact_config(
        output_name=output_name, artifact_config=artifact_config
    )

model_version

User-facing `ModelVersion` interface to pass into a pipeline or step.

ModelVersion (BaseModel) pydantic-model

ModelVersion class to pass into pipeline or step to set it into a model context.

name: The name of the model. license: The license under which the model is created. description: The description of the model. audience: The target audience of the model. use_cases: The use cases of the model. limitations: The known limitations of the model. trade_offs: The trade-offs of the model. ethics: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number, or stage. It is optional and points the model context to a specific version/stage; if skipped, a new model version will be created. save_models_to_registry: Whether to save all ModelArtifacts to the Model Registry, if one is available in the active stack.

Source code in zenml/model/model_version.py
class ModelVersion(BaseModel):
    """ModelVersion class to pass into pipeline or step to set it into a model context.

    name: The name of the model.
    license: The license under which the model is created.
    description: The description of the model.
    audience: The target audience of the model.
    use_cases: The use cases of the model.
    limitations: The known limitations of the model.
    trade_offs: The tradeoffs of the model.
    ethics: The ethical implications of the model.
    tags: Tags associated with the model.
    version: The model version name, number or stage is optional and points model context
        to a specific version/stage. If skipped new model version will be created.
    save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
        if available in active stack.
    suppress_class_validation_warnings: Set to True by the class validator
        once the version hints have been logged.
    was_created_in_this_run: Set to True when this object created a new model
        version on the Model Control Plane.
    """

    name: str
    license: Optional[str]
    description: Optional[str]
    audience: Optional[str]
    use_cases: Optional[str]
    limitations: Optional[str]
    trade_offs: Optional[str]
    ethics: Optional[str]
    tags: Optional[List[str]]
    version: Optional[Union[ModelStages, int, str]]
    save_models_to_registry: bool = True

    suppress_class_validation_warnings: bool = False
    was_created_in_this_run: bool = False

    # Server-side identifiers, filled lazily by the get-or-create helpers.
    _model_id: UUID = PrivateAttr(None)
    _id: UUID = PrivateAttr(None)
    _number: int = PrivateAttr(None)

    #########################
    #    Public methods     #
    #########################
    @property
    def id(self) -> Optional[UUID]:
        """Get version id from the Model Control Plane.

        Returns:
            ID of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).
        """
        if self._id is None:
            try:
                self._get_or_create_model_version()
            except RuntimeError:
                logger.info(
                    f"Model version `{self.version}` doesn't exist "
                    "and cannot be fetched from the Model Control Plane."
                )
        return self._id

    @property
    def model_id(self) -> UUID:
        """Get model id from the Model Control Plane.

        Returns:
            The UUID of the model containing this model version.
        """
        if self._model_id is None:
            self._get_or_create_model()
        return self._model_id

    @property
    def number(self) -> Optional[int]:
        """Get version number from the Model Control Plane.

        Returns:
            Number of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).
        """
        if self._number is None:
            try:
                self._get_or_create_model_version()
            except RuntimeError:
                logger.info(
                    f"Model version `{self.version}` doesn't exist "
                    "and cannot be fetched from the Model Control Plane."
                )
        return self._number

    @property
    def stage(self) -> Optional[ModelStages]:
        """Get version stage from the Model Control Plane.

        Returns:
            Stage of the model version or None, if model version
                doesn't exist and can only be read given current
                config (you used stage name or number as
                a version name).
        """
        try:
            stage = self._get_or_create_model_version().stage
            if stage:
                return ModelStages(stage)
        except RuntimeError:
            logger.info(
                f"Model version `{self.version}` doesn't exist "
                "and cannot be fetched from the Model Control Plane."
            )
        return None

    def get_model_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        pipeline_name: Optional[str] = None,
        step_name: Optional[str] = None,
    ) -> Optional["ArtifactResponse"]:
        """Get the model artifact linked to this model version.

        Args:
            name: The name of the model artifact to retrieve.
            version: The version of the model artifact to retrieve (None for latest/non-versioned)
            pipeline_name: The name of the pipeline-generated the model artifact.
            step_name: The name of the step-generated the model artifact.

        Returns:
            Specific version of the model artifact or None
        """
        return self._get_or_create_model_version().get_model_artifact(
            name=name,
            version=version,
            pipeline_name=pipeline_name,
            step_name=step_name,
        )

    def get_data_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        pipeline_name: Optional[str] = None,
        step_name: Optional[str] = None,
    ) -> Optional["ArtifactResponse"]:
        """Get the data artifact linked to this model version.

        Args:
            name: The name of the data artifact to retrieve.
            version: The version of the data artifact to retrieve (None for latest/non-versioned)
            pipeline_name: The name of the pipeline generated the data artifact.
            step_name: The name of the step generated the data artifact.

        Returns:
            Specific version of the data artifact or None
        """
        return self._get_or_create_model_version().get_data_artifact(
            name=name,
            version=version,
            pipeline_name=pipeline_name,
            step_name=step_name,
        )

    def get_endpoint_artifact(
        self,
        name: str,
        version: Optional[str] = None,
        pipeline_name: Optional[str] = None,
        step_name: Optional[str] = None,
    ) -> Optional["ArtifactResponse"]:
        """Get the endpoint artifact linked to this model version.

        Args:
            name: The name of the endpoint artifact to retrieve.
            version: The version of the endpoint artifact to retrieve (None for latest/non-versioned)
            pipeline_name: The name of the pipeline generated the endpoint artifact.
            step_name: The name of the step generated the endpoint artifact.

        Returns:
            Specific version of the endpoint artifact or None
        """
        return self._get_or_create_model_version().get_endpoint_artifact(
            name=name,
            version=version,
            pipeline_name=pipeline_name,
            step_name=step_name,
        )

    def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
        """Get pipeline run linked to this version.

        Args:
            name: The name of the pipeline run to retrieve.

        Returns:
            PipelineRun as PipelineRunResponse
        """
        return self._get_or_create_model_version().get_pipeline_run(name=name)

    def set_stage(
        self, stage: Union[str, ModelStages], force: bool = False
    ) -> "ModelVersion":
        """Sets this Model Version to a desired stage.

        Args:
            stage: the target stage for model version.
            force: whether to force archiving of current model version in target stage or raise.

        Returns:
            Updated Model Version object.
        """
        return self._get_or_create_model_version().set_stage(
            stage=stage, force=force
        )

    #########################
    #   Internal methods    #
    #########################

    class Config:
        """Config class."""

        smart_union = True

    def __eq__(self, other: object) -> bool:
        """Check two ModelVersions for equality.

        Args:
            other: object to compare with

        Returns:
            True, if equal, False otherwise.
        """
        if not isinstance(other, ModelVersion):
            return NotImplemented
        if self.name != other.name:
            return False
        # Names are equal from here on; identical version identifiers are a
        # cheap positive answer that avoids any server round trip.
        if self.version == other.version:
            return True
        # Different identifiers (e.g. a stage vs. a number) may still resolve
        # to the same model version, so compare the resolved IDs.
        # NOTE(review): `__hash__` is built from (name, version), so two
        # instances that are equal via this branch can hash differently.
        self_mv = self._get_or_create_model_version()
        other_mv = other._get_or_create_model_version()
        return self_mv.id == other_mv.id

    @root_validator(pre=True)
    def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate all in one.

        Logs (once) how a stage-like or numeric `version` will be resolved.

        Args:
            values: Dict of values.

        Returns:
            Dict of validated values.
        """
        suppress_class_validation_warnings = values.get(
            "suppress_class_validation_warnings", False
        )
        version = values.get("version", None)

        if (
            version in [stage.value for stage in ModelStages]
            and not suppress_class_validation_warnings
        ):
            logger.info(
                f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage."
            )
        if str(version).isnumeric() and not suppress_class_validation_warnings:
            logger.info(
                f"`version` `{version}` is numeric and will be fetched using version number."
            )
        # Silence the hints above for any later re-validation of these values.
        values["suppress_class_validation_warnings"] = True
        return values

    def _validate_config_in_runtime(self) -> None:
        """Validate that config doesn't conflict with runtime environment."""
        self._get_or_create_model_version()

    def _get_or_create_model(self) -> "ModelResponseModel":
        """This method should get or create a model from Model Control Plane.

        New model is created implicitly, if missing, otherwise fetched.
        The model's ID is cached in `self._model_id`.

        Returns:
            The model based on configuration.
        """
        from zenml.client import Client
        from zenml.models.model_models import ModelRequestModel

        zenml_client = Client()
        try:
            model = zenml_client.zen_store.get_model(
                model_name_or_id=self.name
            )
        except KeyError:
            model_request = ModelRequestModel(
                name=self.name,
                license=self.license,
                description=self.description,
                audience=self.audience,
                use_cases=self.use_cases,
                limitations=self.limitations,
                trade_offs=self.trade_offs,
                ethics=self.ethics,
                tags=self.tags,
                user=zenml_client.active_user.id,
                workspace=zenml_client.active_workspace.id,
            )
            try:
                model = zenml_client.zen_store.create_model(
                    model=model_request
                )
                logger.info(f"New model `{self.name}` was created implicitly.")
            except EntityExistsError:
                # Backup logic: the model was created by someone else between
                # our get and create calls, so fetch the one that now exists.
                # Any other error from `create_model` propagates unchanged.
                model = zenml_client.zen_store.get_model(
                    model_name_or_id=self.name
                )
        self._model_id = model.id
        return model

    def _get_model_version(self) -> "ModelVersionResponseModel":
        """This method gets a model version from Model Control Plane.

        Returns:
            The model version based on configuration.
        """
        from zenml.client import Client

        zenml_client = Client()
        mv = zenml_client._get_model_version(
            model_name_or_id=self.name,
            model_version_name_or_number_or_id=self.version,
        )
        # Cache the ID only if it has not been resolved before.
        if not self._id:
            self._id = mv.id

        return mv

    def _get_or_create_model_version(self) -> "ModelVersionResponseModel":
        """This method should get or create a model and a model version from Model Control Plane.

        A new model is created implicitly if missing, otherwise existing model is fetched. Model
        name is controlled by the `name` parameter.

        Model Version returned by this method is resolved based on model version:
        - If `version` is None, a new model version is created, if not created by other steps in same run.
        - If `version` is not None a model version will be fetched based on the version:
            - If `version` is set to an integer or digit string, the model version with the matching number will be fetched.
            - If `version` is set to a string, the model version with the matching version will be fetched.
            - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

        Returns:
            The model version based on configuration.

        Raises:
            RuntimeError: if the model version needs to be created, but provided name is reserved
        """
        from zenml.client import Client
        from zenml.models.model_models import ModelVersionRequestModel

        model = self._get_or_create_model()

        zenml_client = Client()
        # Request used only if a new model version has to be created below.
        # NOTE(review): `self.version` may be reassigned from the run context
        # further down, but this request keeps the value read here — confirm
        # that is intended.
        mv_request = ModelVersionRequestModel(
            user=zenml_client.active_user.id,
            workspace=zenml_client.active_workspace.id,
            name=self.version,
            description=self.description,
            model=model.id,
        )
        try:
            if not self.version:
                try:
                    from zenml import get_step_context

                    context = get_step_context()
                except RuntimeError:
                    pass
                else:
                    # if inside a step context we loop over all
                    # model version configuration to find, if the
                    # model version for current model was already
                    # created in the current run, not to create
                    # new model versions
                    pipeline_mv = context.pipeline_run.config.model_version
                    if (
                        pipeline_mv
                        and pipeline_mv.was_created_in_this_run
                        and pipeline_mv.name == self.name
                        and pipeline_mv.version is not None
                    ):
                        self.version = pipeline_mv.version
                    else:
                        for step in context.pipeline_run.steps.values():
                            step_mv = step.config.model_version
                            if (
                                step_mv
                                and step_mv.was_created_in_this_run
                                and step_mv.name == self.name
                                and step_mv.version is not None
                            ):
                                self.version = step_mv.version
                                break
            if self.version:
                model_version = self._get_model_version()
            else:
                # Nothing to fetch: fall through to the creation branch.
                raise KeyError
        except KeyError:
            # Stage names and numbers are reserved identifiers: they can only
            # be fetched, never used as the name of a newly created version.
            if (
                self.version
                and str(self.version).lower() in ModelStages.values()
            ):
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "it matches one of the possible model version stages. If you "
                    "are aiming to fetch model version by stage, check if the "
                    "model version in given stage exists. It might be missing, if "
                    "the pipeline promoting model version to this stage failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list {self.name}` CLI command."
                )
            if str(self.version).isnumeric():
                raise RuntimeError(
                    f"Cannot create a model version named {str(self.version)} as "
                    "numeric model version names are reserved. If you "
                    "are aiming to fetch model version by number, check if the "
                    "model version with given number exists. It might be missing, if "
                    "the pipeline creating model version failed,"
                    " as an example. You can explore model versions using "
                    f"`zenml model version list {self.name}` CLI command."
                )
            model_version = zenml_client.zen_store.create_model_version(
                model_version=mv_request
            )
            self.version = model_version.name
            self.was_created_in_this_run = True
            logger.info(f"New model version `{self.version}` was created.")
        self._id = model_version.id
        self._model_id = model_version.model.id
        self._number = model_version.number
        return model_version

    def _merge(self, model_version: "ModelVersion") -> None:
        """Fill this instance's missing model metadata from another one.

        Each metadata field keeps its own value when truthy and otherwise
        takes the value of `model_version`. When `model_version` has tags,
        the result holds the tags of both without duplicates, with this
        instance's tags first.

        Args:
            model_version: The `ModelVersion` to take missing values from.
        """
        self.license = self.license or model_version.license
        self.description = self.description or model_version.description
        self.audience = self.audience or model_version.audience
        self.use_cases = self.use_cases or model_version.use_cases
        self.limitations = self.limitations or model_version.limitations
        self.trade_offs = self.trade_offs or model_version.trade_offs
        self.ethics = self.ethics or model_version.ethics
        if model_version.tags is not None:
            # `dict.fromkeys` de-duplicates while keeping first-seen order,
            # so the merged tag list is deterministic (unlike `list(set)`).
            self.tags = list(
                dict.fromkeys([*(self.tags or []), *model_version.tags])
            )

    def __hash__(self) -> int:
        """Get hash of the `ModelVersion`.

        Returns:
            Hash function results
        """
        # Built from the configured (unresolved) name and version only.
        return hash(
            "::".join(
                (
                    str(v)
                    for v in (
                        self.name,
                        self.version,
                    )
                )
            )
        )
id: UUID property readonly

Get version id from the Model Control Plane.

Returns:

Type Description
UUID

ID of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name).

model_id: UUID property readonly

Get model id from the Model Control Plane.

Returns:

Type Description
UUID

The UUID of the model containing this model version.

number: int property readonly

Get version number from the Model Control Plane.

Returns:

Type Description
int

Number of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name).

stage: Optional[zenml.enums.ModelStages] property readonly

Get version stage from the Model Control Plane.

Returns:

Type Description
Optional[zenml.enums.ModelStages]

Stage of the model version or None, if model version doesn't exist and can only be read given current config (you used stage name or number as a version name).

Config

Config class.

Source code in zenml/model/model_version.py
class Config:
    """Pydantic configuration for `ModelVersion`.

    Turns on pydantic's ``smart_union`` option for ``Union``-typed fields.
    """

    smart_union = True
__eq__(self, other) special

Check two ModelVersions for equality.

Parameters:

Name Type Description Default
other object

object to compare with

required

Returns:

Type Description
bool

True, if equal, False otherwise.

Source code in zenml/model/model_version.py
def __eq__(self, other: object) -> bool:
    """Compare two `ModelVersion` objects.

    Objects with different names are never equal. With equal names, equal
    configured versions short-circuit to True. Otherwise both sides are
    resolved through `_get_or_create_model_version` and compared by ID.

    Args:
        other: object to compare with

    Returns:
        True, if equal, False otherwise; `NotImplemented` when `other` is
        not a `ModelVersion`.
    """
    if not isinstance(other, ModelVersion):
        return NotImplemented
    if self.name != other.name:
        return False
    if self.version == other.version:
        return True
    own_response = self._get_or_create_model_version()
    other_response = other._get_or_create_model_version()
    return own_response.id == other_response.id
__hash__(self) special

Get hash of the ModelVersion.

Returns:

Type Description
int

Hash function results

Source code in zenml/model/model_version.py
def __hash__(self) -> int:
    """Compute a hash for this `ModelVersion`.

    The hash is taken over the string ``"<name>::<version>"``, where both
    parts are converted with ``str()``.

    NOTE(review): `__eq__` can report two instances as equal when their
    `version` values differ but resolve to the same model version ID; such
    pairs hash differently here, which breaks the usual eq/hash contract.
    Confirm whether that is acceptable for callers using sets/dicts.

    Returns:
        Hash of the name/version key.
    """
    key = f"{self.name!s}::{self.version!s}"
    return hash(key)
get_data_artifact(self, name, version=None, pipeline_name=None, step_name=None)

Get the data artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the data artifact to retrieve.

required
version Optional[str]

The version of the data artifact to retrieve (None for latest/non-versioned)

None
pipeline_name Optional[str]

The name of the pipeline that generated the data artifact.

None
step_name Optional[str]

The name of the step that generated the data artifact.

None

Returns:

Type Description
Optional[ArtifactResponse]

Specific version of the data artifact or None

Source code in zenml/model/model_version.py
def get_data_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    pipeline_name: Optional[str] = None,
    step_name: Optional[str] = None,
) -> Optional["ArtifactResponse"]:
    """Look up a data artifact linked to this model version.

    The lookup is delegated to the model version object returned by
    `_get_or_create_model_version()`.

    Args:
        name: Name of the data artifact to fetch.
        version: Version of the data artifact to fetch; leave as None for
            the latest or a non-versioned artifact.
        pipeline_name: Name of the pipeline that generated the data artifact.
        step_name: Name of the step that generated the data artifact.

    Returns:
        The matching version of the data artifact, or None.
    """
    lookup = dict(
        name=name,
        version=version,
        pipeline_name=pipeline_name,
        step_name=step_name,
    )
    resolved = self._get_or_create_model_version()
    return resolved.get_data_artifact(**lookup)
get_endpoint_artifact(self, name, version=None, pipeline_name=None, step_name=None)

Get the endpoint artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the endpoint artifact to retrieve.

required
version Optional[str]

The version of the endpoint artifact to retrieve (None for latest/non-versioned)

None
pipeline_name Optional[str]

The name of the pipeline that generated the endpoint artifact.

None
step_name Optional[str]

The name of the step that generated the endpoint artifact.

None

Returns:

Type Description
Optional[ArtifactResponse]

Specific version of the endpoint artifact or None

Source code in zenml/model/model_version.py
def get_endpoint_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    pipeline_name: Optional[str] = None,
    step_name: Optional[str] = None,
) -> Optional["ArtifactResponse"]:
    """Look up an endpoint artifact linked to this model version.

    The lookup is delegated to the model version object returned by
    `_get_or_create_model_version()`.

    Args:
        name: Name of the endpoint artifact to fetch.
        version: Version of the endpoint artifact to fetch; leave as None
            for the latest or a non-versioned artifact.
        pipeline_name: Name of the pipeline that generated the endpoint
            artifact.
        step_name: Name of the step that generated the endpoint artifact.

    Returns:
        The matching version of the endpoint artifact, or None.
    """
    lookup = dict(
        name=name,
        version=version,
        pipeline_name=pipeline_name,
        step_name=step_name,
    )
    resolved = self._get_or_create_model_version()
    return resolved.get_endpoint_artifact(**lookup)
get_model_artifact(self, name, version=None, pipeline_name=None, step_name=None)

Get the model artifact linked to this model version.

Parameters:

Name Type Description Default
name str

The name of the model artifact to retrieve.

required
version Optional[str]

The version of the model artifact to retrieve (None for latest/non-versioned)

None
pipeline_name Optional[str]

The name of the pipeline that generated the model artifact.

None
step_name Optional[str]

The name of the step that generated the model artifact.

None

Returns:

Type Description
Optional[ArtifactResponse]

Specific version of the model artifact or None

Source code in zenml/model/model_version.py
def get_model_artifact(
    self,
    name: str,
    version: Optional[str] = None,
    pipeline_name: Optional[str] = None,
    step_name: Optional[str] = None,
) -> Optional["ArtifactResponse"]:
    """Look up a model artifact linked to this model version.

    The lookup is delegated to the model version object returned by
    `_get_or_create_model_version()`.

    Args:
        name: Name of the model artifact to fetch.
        version: Version of the model artifact to fetch; leave as None for
            the latest or a non-versioned artifact.
        pipeline_name: Name of the pipeline that generated the model
            artifact.
        step_name: Name of the step that generated the model artifact.

    Returns:
        The matching version of the model artifact, or None.
    """
    lookup = dict(
        name=name,
        version=version,
        pipeline_name=pipeline_name,
        step_name=step_name,
    )
    resolved = self._get_or_create_model_version()
    return resolved.get_model_artifact(**lookup)
get_pipeline_run(self, name)

Get pipeline run linked to this version.

Parameters:

Name Type Description Default
name str

The name of the pipeline run to retrieve.

required

Returns:

Type Description
PipelineRunResponse

PipelineRun as PipelineRunResponse

Source code in zenml/model/model_version.py
def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
    """Fetch a pipeline run linked to this model version.

    Args:
        name: Name of the pipeline run to fetch.

    Returns:
        The pipeline run, as a PipelineRunResponse.
    """
    resolved = self._get_or_create_model_version()
    return resolved.get_pipeline_run(name=name)
set_stage(self, stage, force=False)

Sets this Model Version to a desired stage.

Parameters:

Name Type Description Default
stage Union[str, zenml.enums.ModelStages]

the target stage for model version.

required
force bool

Whether to force archiving of the model version currently in the target stage, or raise instead.

False

Returns:

Type Description
ModelVersion

Updated Model Version object.

Source code in zenml/model/model_version.py
def set_stage(
    self, stage: Union[str, ModelStages], force: bool = False
) -> "ModelVersion":
    """Move this model version into the requested stage.

    The call is delegated to the model version object returned by
    `_get_or_create_model_version()`.

    Args:
        stage: Stage to assign to this model version.
        force: Whether to force archiving of the model version currently
            in the target stage, or raise instead.

    Returns:
        The updated model version object.
    """
    target = self._get_or_create_model_version()
    return target.set_stage(stage=stage, force=force)