Steps

zenml.steps

Initializer for ZenML steps.

A step is a single piece or stage of a ZenML pipeline. Think of each step as one of the nodes of a Directed Acyclic Graph (DAG). Each step is responsible for one aspect of processing or interacting with the data and artifacts in the pipeline.
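
To make the DAG picture concrete, here is a minimal sketch using only Python's standard-library `graphlib` (Python 3.9+). It shows how each step's set of upstream steps determines a valid execution order. The step names and the dependency mapping are hypothetical and mirror the `after(...)` example in the `BaseStep` docstring. ZenML's orchestrators perform their own scheduling, so this is an illustration of the concept, not ZenML code.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set

# Hypothetical pipeline: each step name maps to the set of steps that must
# finish before it can start (its upstream steps).
upstream_steps: Dict[str, Set[str]] = {
    "step_1": {"step_2"},            # step_1.after(step_2)
    "step_2": set(),                 # no upstream steps
    "step_3": {"step_1", "step_2"},  # consumes the outputs of step_1 and step_2
}


def execution_order(dependencies: Dict[str, Set[str]]) -> List[str]:
    """Return one valid run order; raises graphlib.CycleError if it is not a DAG."""
    return list(TopologicalSorter(dependencies).static_order())


print(execution_order(upstream_steps))  # ['step_2', 'step_1', 'step_3']
```

Because every step here has exactly one ready predecessor chain, the order is unique; in graphs with independent branches, several orders are valid and an orchestrator may also run independent steps in parallel.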

ZenML currently implements a basic step interface; more customized interfaces, layered in a hierarchy, are planned for specialized implementations. Conceptually, a step is a discrete, independent part of a pipeline that handles one particular aspect of data manipulation.

Steps can be defined either by subclassing the `BaseStep` class or by applying the `@step` decorator to a function.
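
The two styles can be illustrated with a self-contained toy model. Everything in it (`ToyBaseStep`, `toy_step`, `LoaderStep`, `loader`) is hypothetical and only mimics the pattern visible in the `BaseStep` source: an abstract `entrypoint` method that subclasses implement, and a flag recording whether a step class was generated from a plain function by the functional API. It is not ZenML's implementation of `@step`.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Type


class ToyBaseStep(ABC):
    """Hypothetical stand-in for BaseStep: subclasses implement entrypoint()."""

    # Class-level flags, loosely modeled on BaseStep.INSTANCE_CONFIGURATION.
    INSTANCE_CONFIGURATION: Dict[str, Any] = {}

    @abstractmethod
    def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
        """Core step logic."""

    @classmethod
    def created_by_functional_api(cls) -> bool:
        return cls.INSTANCE_CONFIGURATION.get("created_by_functional_api", False)


# Style 1: subclass the base class and implement entrypoint() directly.
class LoaderStep(ToyBaseStep):
    def entrypoint(self) -> list:
        return [1, 2, 3]


# Style 2: a decorator that generates a subclass wrapping a plain function.
def toy_step(func: Callable[..., Any]) -> Type[ToyBaseStep]:
    class _FunctionStep(ToyBaseStep):
        INSTANCE_CONFIGURATION = {"created_by_functional_api": True}
        # staticmethod so the wrapped function is called without `self`
        entrypoint = staticmethod(func)

    _FunctionStep.__name__ = func.__name__
    _FunctionStep.__qualname__ = func.__qualname__
    return _FunctionStep


@toy_step
def loader() -> list:
    return [1, 2, 3]


print(LoaderStep().entrypoint(), LoaderStep.created_by_functional_api())  # [1, 2, 3] False
print(loader().entrypoint(), loader.created_by_functional_api())          # [1, 2, 3] True
```

Note that after decoration, `loader` is a class rather than a function, so it is instantiated before use; the toy does this to keep both styles behind one base-class interface.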

base_parameters

Step parameters.

BaseParameters (pydantic BaseModel)

Base class to pass parameters into a step.

Source code in zenml/steps/base_parameters.py
class BaseParameters(BaseModel):
    """Base class to pass parameters into a step."""
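
Steps that accept parameters receive a single parameters object when they are created. The following stdlib-only sketch models the checks that `BaseStep._verify_and_apply_init_params` performs in the `base_step` source: at most one argument, passed either positionally or under the step's parameter name, and it must be an instance of the declared parameters class. To keep the sketch runnable without pydantic, a `dataclass` stands in for a `BaseParameters` subclass; `TrainerParameters`, `verify_init_params` and the local `StepInterfaceError` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Optional, Type


class StepInterfaceError(Exception):
    """Local stand-in for ZenML's StepInterfaceError."""


@dataclass
class TrainerParameters:
    """Hypothetical parameters class; in ZenML this would subclass BaseParameters."""

    learning_rate: float = 0.001
    epochs: int = 5


def verify_init_params(
    parameters_class: Optional[Type[Any]],
    parameter_name: Optional[str],
    *args: Any,
    **kwargs: Any,
) -> Optional[Any]:
    """Simplified model of the step init checks; returns the parameters object, if any."""
    maximum_arg_count = 1 if parameters_class else 0
    arg_count = len(args) + len(kwargs)
    if arg_count > maximum_arg_count:
        raise StepInterfaceError(
            f"Too many arguments ({arg_count}, expected: {maximum_arg_count})."
        )
    if not (parameters_class and parameter_name):
        return None
    if args:
        config = args[0]
    elif kwargs:
        key, config = kwargs.popitem()
        if key != parameter_name:
            raise StepInterfaceError(f"Unknown keyword argument '{key}'.")
    else:
        # Values may still come from class defaults or a config file later on.
        return None
    if not isinstance(config, parameters_class):
        raise StepInterfaceError(
            f"{config!r} is not a {parameters_class.__name__} instance."
        )
    return config


params = verify_init_params(TrainerParameters, "params", TrainerParameters(epochs=3))
print(params.epochs)  # 3
```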

base_step

Base Step for ZenML.

BaseStep

Abstract base class for all ZenML steps.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `name` | `str` | The name of this step. |
| `pipeline_parameter_name` | `Optional[str]` | The name of the pipeline parameter for which this step was passed as an argument. |
| `enable_cache` | `bool` | A boolean indicating if caching is enabled for this step. |
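
The caching behavior implemented in `BaseStep.__init__` and `_internal_execution_parameters` can be summarized in a small stdlib-only sketch. An explicit `enable_cache` value always wins; otherwise caching defaults to on, unless the step requires a step context. When caching is on, a hash of the step source goes into the execution parameters, so code changes invalidate cached results; when it is off, a random 128-bit value is added so the parameters never match a previous run. The function names, the `PREFIX` value and the use of SHA-256 are assumptions for illustration; ZenML's `source_utils.get_hashed_source`, `INTERNAL_EXECUTION_PARAMETER_PREFIX` and the hashing of materializer sources are not reproduced here.

```python
import hashlib
import random
from typing import Dict, Optional

PREFIX = "internal_"  # hypothetical; stands in for INTERNAL_EXECUTION_PARAMETER_PREFIX


def resolve_enable_cache(enable_cache: Optional[bool], requires_context: bool) -> bool:
    """Explicit value wins; otherwise cache unless the step requires a context."""
    if enable_cache is not None:
        return enable_cache
    return not requires_context


def cache_parameters(enable_cache: bool, step_source: str) -> Dict[str, str]:
    """Build the cache-relevant execution parameters for a step."""
    if enable_cache:
        # Same source text -> same hash -> the run is eligible for a cache hit.
        parameters = {"step_source": hashlib.sha256(step_source.encode()).hexdigest()}
    else:
        # A fresh random value makes every run look new, which disables caching.
        parameters = {"disable_cache": f"{random.getrandbits(128):032x}"}
    return {PREFIX + key: value for key, value in parameters.items()}


print(resolve_enable_cache(None, requires_context=True))  # False
print(resolve_enable_cache(True, requires_context=True))  # True
```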

Source code in zenml/steps/base_step.py
class BaseStep(metaclass=BaseStepMeta):
    """Abstract base class for all ZenML steps.

    Attributes:
        name: The name of this step.
        pipeline_parameter_name: The name of the pipeline parameter for which
            this step was passed as an argument.
        enable_cache: A boolean indicating if caching is enabled for this step.
    """

    INPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None  # type: ignore[assignment] # noqa
    OUTPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None  # type: ignore[assignment] # noqa
    PARAMETERS_FUNCTION_PARAMETER_NAME: ClassVar[Optional[str]] = None
    PARAMETERS_CLASS: ClassVar[Optional[Type["BaseParameters"]]] = None
    CONTEXT_PARAMETER_NAME: ClassVar[Optional[str]] = None

    INSTANCE_CONFIGURATION: Dict[str, Any] = {}

    class _OutputArtifact(NamedTuple):
        """Internal step output artifact.

        This class is used for inputs/outputs of the __call__ method of
        BaseStep. It passes all the information about step outputs so downstream
        steps can finalize their configuration.

        Attributes:
            channel: TFX channel that defines the artifact class and id of the
                step that produced the output.
            materializer_source: The source of the materializer used to
                write the output.
        """

        channel: Channel
        materializer_source: str

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initializes a step.

        Args:
            *args: Positional arguments passed to the step.
            **kwargs: Keyword arguments passed to the step.
        """
        self.pipeline_parameter_name: Optional[str] = None
        self._component: Optional["_SimpleComponent"] = None
        self._has_been_called = False
        self._upstream_steps: Set[str] = set()

        kwargs = {**self.INSTANCE_CONFIGURATION, **kwargs}
        name = kwargs.pop(PARAM_STEP_NAME, None) or self.__class__.__name__

        # This value is only used in `BaseStep._created_by_functional_api()`
        kwargs.pop(PARAM_CREATED_BY_FUNCTIONAL_API, None)

        requires_context = bool(self.CONTEXT_PARAMETER_NAME)
        enable_cache = kwargs.pop(PARAM_ENABLE_CACHE, None)
        if enable_cache is None:
            if requires_context:
                # Using the StepContext inside a step provides access to
                # external resources which might influence the step execution.
                # We therefore disable caching unless it is explicitly enabled
                enable_cache = False
                logger.debug(
                    "Step '%s': Step context required and caching not "
                    "explicitly enabled.",
                    name,
                )
            else:
                # Default to cache enabled if not explicitly set
                enable_cache = True

        logger.debug(
            "Step '%s': Caching %s.",
            name,
            "enabled" if enable_cache else "disabled",
        )

        self._configuration = PartialStepConfiguration(
            name=name,
            enable_cache=enable_cache,
        )
        self._apply_class_configuration(kwargs)
        self._verify_and_apply_init_params(*args, **kwargs)

    @abstractmethod
    def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
        """Abstract method for core step logic.

        Args:
            *args: Positional arguments passed to the step.
            **kwargs: Keyword arguments passed to the step.

        Returns:
            The output of the step.
        """

    @classmethod
    def _created_by_functional_api(cls) -> bool:
        """Returns if the step class was created by the functional API.

        Returns:
            `True` if the class was created by the functional API,
            `False` otherwise.
        """
        return cls.INSTANCE_CONFIGURATION.get(
            PARAM_CREATED_BY_FUNCTIONAL_API, False
        )

    @property
    def upstream_steps(self) -> Set[str]:
        """Names of the upstream steps of this step.

        This property will only contain the full set of upstream steps once
        its parent pipeline's `connect(...)` method has been called.

        Returns:
            Set of upstream step names.
        """
        return self._upstream_steps

    def after(self, step: "BaseStep") -> None:
        """Adds an upstream step to this step.

        Calling this method makes sure this step only starts running once the
        given step has successfully finished executing.

        **Note**: This can only be called inside the pipeline connect function
        which is decorated with the `@pipeline` decorator. Any calls outside
        this function will be ignored.

        Example:
        The following pipeline will run its steps sequentially in the following
        order: step_2 -> step_1 -> step_3

        ```python
        @pipeline
        def example_pipeline(step_1, step_2, step_3):
            step_1.after(step_2)
            step_3(step_1(), step_2())
        ```

        Args:
            step: A step which should finish executing before this step is
                started.
        """
        self._upstream_steps.add(step.name)

    @property
    def _internal_execution_parameters(self) -> Dict[str, Any]:
        """Internal ZenML execution parameters for this step.

        Returns:
            A dictionary containing the ZenML internal execution parameters
        """
        parameters = {
            PARAM_PIPELINE_PARAMETER_NAME: self.pipeline_parameter_name,
        }

        if self.enable_cache:
            # Caching is enabled so we compute a hash of the step function code
            # and materializers to catch changes in the step behavior

            # If the step was defined using the functional api, only track
            # changes to the entrypoint function. Otherwise track changes to
            # the entire step class.
            source_object = (
                self.entrypoint
                if self._created_by_functional_api()
                else self.__class__
            )
            parameters["step_source"] = source_utils.get_hashed_source(
                source_object
            )

            for name, output in self.configuration.outputs.items():
                if output.materializer_source:
                    key = f"{name}_materializer_source"
                    materializer_class = source_utils.load_source_path_class(
                        output.materializer_source
                    )
                    parameters[key] = source_utils.get_hashed_source(
                        materializer_class
                    )
        else:
            # Add a random string to the execution properties to disable caching
            random_string = f"{random.getrandbits(128):032x}"
            parameters["disable_cache"] = random_string

        return {
            INTERNAL_EXECUTION_PARAMETER_PREFIX + key: value
            for key, value in parameters.items()
        }

    def _apply_class_configuration(self, options: Dict[str, Any]) -> None:
        """Applies the configurations specified on the step class.

        Args:
            options: Class configurations.
        """
        step_operator = options.pop(PARAM_STEP_OPERATOR, None)
        settings = options.pop(PARAM_SETTINGS, None) or {}
        output_materializers = options.pop(PARAM_OUTPUT_MATERIALIZERS, None)
        output_artifacts = options.pop(PARAM_OUTPUT_ARTIFACTS, None)
        extra = options.pop(PARAM_EXTRA_OPTIONS, None)
        experiment_tracker = options.pop(PARAM_EXPERIMENT_TRACKER, None)

        self.configure(
            experiment_tracker=experiment_tracker,
            step_operator=step_operator,
            output_artifacts=output_artifacts,
            output_materializers=output_materializers,
            settings=settings,
            extra=extra,
        )

    def _verify_and_apply_init_params(self, *args: Any, **kwargs: Any) -> None:
        """Verifies the initialization args and kwargs of this step.

        This method makes sure that there is only one parameters object passed
        at initialization and that it was passed using the correct name and
        type specified in the step declaration.

        Args:
            *args: The args passed to the init method of this step.
            **kwargs: The kwargs passed to the init method of this step.

        Raises:
            StepInterfaceError: If there are too many arguments or arguments
                with a wrong name/type.
        """
        maximum_arg_count = 1 if self.PARAMETERS_CLASS else 0
        arg_count = len(args) + len(kwargs)
        if arg_count > maximum_arg_count:
            raise StepInterfaceError(
                f"Too many arguments ({arg_count}, expected: "
                f"{maximum_arg_count}) passed when creating a "
                f"'{self.name}' step."
            )

        if self.PARAMETERS_FUNCTION_PARAMETER_NAME and self.PARAMETERS_CLASS:
            if args:
                config = args[0]
            elif kwargs:
                key, config = kwargs.popitem()

                if key != self.PARAMETERS_FUNCTION_PARAMETER_NAME:
                    raise StepInterfaceError(
                        f"Unknown keyword argument '{key}' when creating a "
                        f"'{self.name}' step, only expected a single "
                        "argument with key "
                        f"'{self.PARAMETERS_FUNCTION_PARAMETER_NAME}'."
                    )
            else:
                # This step requires configuration parameters but no parameters
                # object was passed as an argument. The parameters might be
                # set via default values in the parameters class or in a
                # configuration file, so we continue for now and verify
                # that all parameters are set before running the step
                return

            if not isinstance(config, self.PARAMETERS_CLASS):
                raise StepInterfaceError(
                    f"`{config}` object passed when creating a "
                    f"'{self.name}' step is not a "
                    f"`{self.PARAMETERS_CLASS.__name__}` instance."
                )

            self.configure(parameters=config)

    def _validate_input_artifacts(
        self, *artifacts: _OutputArtifact, **kw_artifacts: _OutputArtifact
    ) -> Dict[str, _OutputArtifact]:
        """Verifies and prepares the input artifacts for running this step.

        Args:
            *artifacts: Positional input artifacts passed to
                the __call__ method.
            **kw_artifacts: Keyword input artifacts passed to
                the __call__ method.

        Returns:
            Dictionary containing both the positional and keyword input
            artifacts.

        Raises:
            StepInterfaceError: If there are too many or too few artifacts.
        """
        input_artifact_keys = list(self.INPUT_SIGNATURE.keys())
        if len(artifacts) > len(input_artifact_keys):
            raise StepInterfaceError(
                f"Too many input artifacts for step '{self.name}'. "
                f"This step expects {len(input_artifact_keys)} artifact(s) "
                f"but got {len(artifacts) + len(kw_artifacts)}."
            )

        combined_artifacts = {}

        for i, artifact in enumerate(artifacts):
            if not isinstance(artifact, BaseStep._OutputArtifact):
                raise StepInterfaceError(
                    f"Wrong argument type (`{type(artifact)}`) for positional "
                    f"argument {i} of step '{self.name}'. Only outputs "
                    f"from previous steps can be used as arguments when "
                    f"connecting steps."
                )

            key = input_artifact_keys[i]
            combined_artifacts[key] = artifact

        for key, artifact in kw_artifacts.items():
            if key in combined_artifacts:
                # an artifact for this key was already set by
                # the positional input artifacts
                raise StepInterfaceError(
                    f"Unexpected keyword argument '{key}' for step "
                    f"'{self.name}'. An artifact for this key was "
                    f"already passed as a positional argument."
                )

            if not isinstance(artifact, BaseStep._OutputArtifact):
                raise StepInterfaceError(
                    f"Wrong argument type (`{type(artifact)}`) for argument "
                    f"'{key}' of step '{self.name}'. Only outputs from "
                    f"previous steps can be used as arguments when "
                    f"connecting steps."
                )

            combined_artifacts[key] = artifact

        # check if there are any missing or unexpected artifacts
        expected_artifacts = set(self.INPUT_SIGNATURE.keys())
        actual_artifacts = set(combined_artifacts.keys())
        missing_artifacts = expected_artifacts - actual_artifacts
        unexpected_artifacts = actual_artifacts - expected_artifacts

        if missing_artifacts:
            raise StepInterfaceError(
                f"Missing input artifact(s) for step "
                f"'{self.name}': {missing_artifacts}."
            )

        if unexpected_artifacts:
            raise StepInterfaceError(
                f"Unexpected input artifact(s) for step "
                f"'{self.name}': {unexpected_artifacts}. This step "
                f"only requires the following artifacts: {expected_artifacts}."
            )

        return combined_artifacts

    def __call__(
        self, *artifacts: _OutputArtifact, **kw_artifacts: _OutputArtifact
    ) -> Union[_OutputArtifact, List[_OutputArtifact]]:
        """Generates a component when called.

        Args:
            *artifacts: Positional input artifacts passed to
                the __call__ method.
            **kw_artifacts: Keyword input artifacts passed to
                the __call__ method.

        Returns:
            A single output artifact or a list of output artifacts.

        Raises:
            StepInterfaceError: If the step has already been called.
        """
        if self._has_been_called:
            raise StepInterfaceError(
                f"Step {self.name} has already been called. A ZenML step "
                f"instance can only be called once per pipeline run."
            )
        self._has_been_called = True

        # Prepare the input artifacts and spec
        input_artifacts = self._validate_input_artifacts(
            *artifacts, **kw_artifacts
        )
        input_channels = {
            name: artifact.channel for name, artifact in input_artifacts.items()
        }
        for input_ in input_artifacts.values():
            self._upstream_steps.add(input_.channel.producer_component_id)

        config = self._finalize_configuration(input_artifacts=input_artifacts)

        execution_parameters = {
            **self.configuration.parameters,
            **self._internal_execution_parameters,
        }

        # Convert execution parameter values to strings
        try:
            execution_parameters = {
                k: json.dumps(v) for k, v in execution_parameters.items()
            }
        except TypeError as e:
            raise StepInterfaceError(
                f"Failed to serialize execution parameters for step "
                f"'{self.name}'. Please make sure to only use "
                f"json serializable parameter values."
            ) from e

        component_class = create_component_class(step=self)
        self._component = component_class(
            **input_channels, **execution_parameters
        )

        # Resolve the returns in the right order.
        returns = []
        for key in self.OUTPUT_SIGNATURE:
            materializer_source = config.outputs[key].materializer_source
            output_artifact = BaseStep._OutputArtifact(
                channel=cast(Channel, self.component.outputs[key]),
                materializer_source=materializer_source,
            )
            returns.append(output_artifact)

        # If there is only one output, return it directly instead of a list
        if len(returns) == 1:
            return returns[0]
        else:
            return returns

    @property
    def component(self) -> "_SimpleComponent":
        """Returns a TFX component.

        Returns:
            A TFX component.

        Raises:
            StepInterfaceError: If you are trying to access the step component
                before creating it.
        """
        if not self._component:
            raise StepInterfaceError(
                "Trying to access the step component "
                "before creating it via calling the step."
            )
        return self._component

    def with_return_materializers(
        self: T,
        materializers: Union[
            Type[BaseMaterializer], Dict[str, Type[BaseMaterializer]]
        ],
    ) -> T:
        """DEPRECATED: Register materializers for step outputs.

        If a single materializer is passed, it will be used for all step
        outputs. Otherwise, the dictionary keys specify the output names
        for which the materializers will be used.

        Args:
            materializers: The materializers for the outputs of this step.

        Returns:
            The step that this method was called on.
        """
        logger.warning(
            "The `with_return_materializers(...)` method is deprecated. "
            "Use `step.configure(output_materializers=...)` instead."
        )

        self.configure(output_materializers=materializers)
        return self

    @property
    def name(self) -> str:
        """The name of the step.

        Returns:
            The name of the step.
        """
        return self.configuration.name

    @property
    def enable_cache(self) -> bool:
        """If caching is enabled for the step.

        Returns:
            If caching is enabled for the step.
        """
        return self.configuration.enable_cache

    @property
    def configuration(self) -> PartialStepConfiguration:
        """The configuration of the step.

        Returns:
            The configuration of the step.
        """
        return self._configuration

    def configure(
        self: T,
        name: Optional[str] = None,
        enable_cache: Optional[bool] = None,
        experiment_tracker: Optional[str] = None,
        step_operator: Optional[str] = None,
        parameters: Optional["ParametersOrDict"] = None,
        output_materializers: Optional[
            "OutputMaterializersSpecification"
        ] = None,
        output_artifacts: Optional["OutputArtifactsSpecification"] = None,
        settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
        extra: Optional[Dict[str, Any]] = None,
        merge: bool = True,
    ) -> T:
        """Configures the step.

        Configuration merging example:
        * `merge==True`:
            step.configure(extra={"key1": 1})
            step.configure(extra={"key2": 2}, merge=True)
            step.configuration.extra # {"key1": 1, "key2": 2}
        * `merge==False`:
            step.configure(extra={"key1": 1})
            step.configure(extra={"key2": 2}, merge=False)
            step.configuration.extra # {"key2": 2}

        Args:
            name: The name of the step.
            enable_cache: If caching should be enabled for this step.
            experiment_tracker: The experiment tracker to use for this step.
            step_operator: The step operator to use for this step.
            parameters: Function parameters for this step.
            output_materializers: Output materializers for this step. If
                given as a dict, the keys must be a subset of the output names
                of this step. If a single value (type or string) is given, the
                materializer will be used for all outputs.
            output_artifacts: Output artifacts for this step. If
                given as a dict, the keys must be a subset of the output names
                of this step. If a single value (type or string) is given, the
                artifact class will be used for all outputs.
            settings: Settings for this step.
            extra: Extra configurations for this step.
            merge: If `True`, will merge the given dictionary configurations
                like `parameters` and `settings` with existing
                configurations. If `False` the given configurations will
                overwrite all existing ones. See the general description of this
                method for an example.

        Returns:
            The step instance that this method was called on.

        Raises:
            StepInterfaceError: If a materializer or artifact is configured for
                a non-existent output name.
        """

        def _resolve_if_necessary(value: Union[str, Type[Any]]) -> str:
            return (
                value
                if isinstance(value, str)
                else source_utils.resolve_class(value)
            )

        outputs: Dict[str, Dict[str, str]] = defaultdict(dict)
        allowed_output_names = set(self.OUTPUT_SIGNATURE)

        if output_materializers:
            if not isinstance(output_materializers, Mapping):
                # A single materializer (class or source string) is used for all outputs
                source = _resolve_if_necessary(output_materializers)
                output_materializers = {
                    output_name: source for output_name in allowed_output_names
                }

            for output_name, materializer in output_materializers.items():
                if output_name not in allowed_output_names:
                    raise StepInterfaceError(
                        f"Got unexpected materializers for non-existent "
                        f"output '{output_name}' in step '{self.name}'. "
                        f"Only materializers for the outputs "
                        f"{allowed_output_names} of this step can"
                        f" be registered."
                    )

                source = _resolve_if_necessary(materializer)
                outputs[output_name]["materializer_source"] = source

        if output_artifacts:
            if not isinstance(output_artifacts, Mapping):
                # A single artifact class (class or source string) is used for all outputs
                source = _resolve_if_necessary(output_artifacts)
                output_artifacts = {
                    output_name: source for output_name in allowed_output_names
                }

            for output_name, artifact in output_artifacts.items():
                if output_name not in allowed_output_names:
                    raise StepInterfaceError(
                        f"Got unexpected artifact for non-existent "
                        f"output '{output_name}' in step '{self.name}'. "
                        f"Only artifacts for the outputs "
                        f"{allowed_output_names} of this step can"
                        f" be registered."
                    )

                source = _resolve_if_necessary(artifact)
                outputs[output_name]["artifact_source"] = source

        values = dict_utils.remove_none_values(
            {
                "name": name,
                "enable_cache": enable_cache,
                "experiment_tracker": experiment_tracker,
                "step_operator": step_operator,
                "parameters": parameters,
                "settings": settings,
                "outputs": outputs or None,
                "extra": extra,
            }
        )
        config = StepConfigurationUpdate(**values)
        self._apply_configuration(config, merge=merge)
        return self

    def _apply_configuration(
        self,
        config: StepConfigurationUpdate,
        merge: bool = True,
    ) -> None:
        """Applies an update to the step configuration.

        Args:
            config: The configuration update.
            merge: Whether to merge the updates with the existing configuration
                or not. See the `BaseStep.configure(...)` method for a detailed
                explanation.
        """
        self._validate_configuration(config)

        self._configuration = pydantic_utils.update_model(
            self._configuration, update=config, recursive=merge
        )

        logger.debug("Updated step configuration:")
        logger.debug(self._configuration)

    def _validate_configuration(self, config: StepConfigurationUpdate) -> None:
        """Validates a configuration update.

        Args:
            config: The configuration update to validate.
        """
        settings_utils.validate_setting_keys(list(config.settings))
        self._validate_function_parameters(parameters=config.parameters)
        self._validate_outputs(outputs=config.outputs)

    def _validate_function_parameters(self, parameters: Dict[str, Any]) -> None:
        """Validates step function parameters.

        Args:
            parameters: The parameters to validate.

        Raises:
            StepInterfaceError: If the step requires no function parameters or
                invalid function parameters were given.
        """
        if not parameters:
            # No parameters set (yet), defer validation to a later point
            return

        if not self.PARAMETERS_CLASS:
            raise StepInterfaceError(
                f"Function parameters configured for step {self.name} which "
                "does not accept any function parameters."
            )

        try:
            self.PARAMETERS_CLASS(**parameters)
        except ValidationError:
            raise StepInterfaceError("Failed to validate function parameters.")

    def _validate_inputs(
        self, inputs: Mapping[str, ArtifactConfiguration]
    ) -> None:
        """Validates the step input configuration.

        Args:
            inputs: The configured step inputs.

        Raises:
            StepInterfaceError: If an input for a non-existent name is
                configured or an input artifact source does not resolve to a
                BaseArtifact subclass.
        """
        allowed_input_names = set(self.INPUT_SIGNATURE)
        for input_name, input_ in inputs.items():
            if input_name not in allowed_input_names:
                raise StepInterfaceError(
                    f"Got unexpected artifact for non-existent "
                    f"input '{input_name}' in step '{self.name}'. "
                    f"Only artifacts for the inputs "
                    f"{allowed_input_names} of this step can"
                    f" be registered."
                )

            if not source_utils.validate_source_class(
                input_.artifact_source, expected_class=BaseArtifact
            ):
                raise StepInterfaceError(
                    f"Artifact source `{input_.artifact_source}` "
                    f"for input '{input_name}' of step '{self.name}' "
                    "does not resolve to a `BaseArtifact` subclass."
                )

    def _validate_outputs(
        self, outputs: Mapping[str, PartialArtifactConfiguration]
    ) -> None:
        """Validates the step output configuration.

        Args:
            outputs: The configured step outputs.

        Raises:
            StepInterfaceError: If an output for a non-existent name is
                configured or an output artifact/materializer source does not
                resolve to the correct class.
        """
        allowed_output_names = set(self.OUTPUT_SIGNATURE)
        for output_name, output in outputs.items():
            if output_name not in allowed_output_names:
                raise StepInterfaceError(
                    f"Found explicit artifact type for unrecognized output "
                    f"'{output_name}' in step '{self.name}'. Output "
                    f"artifact types can only be specified for the outputs "
                    f"of this step: {set(self.OUTPUT_SIGNATURE)}."
                )

            if output.materializer_source:
                if not source_utils.validate_source_class(
                    output.materializer_source, expected_class=BaseMaterializer
                ):
                    raise StepInterfaceError(
                        f"Materializer source `{output.materializer_source}` "
                        f"for output '{output_name}' of step '{self.name}' "
                        "does not resolve to a `BaseMaterializer` subclass."
                    )

            if output.artifact_source:
                try:
                    artifact_class: Type[
                        BaseArtifact
                    ] = source_utils.load_and_validate_class(
                        output.artifact_source, expected_class=BaseArtifact
                    )
                except TypeError:
                    raise StepInterfaceError(
                        f"Artifact source `{output.artifact_source}` "
                        f"for output '{output_name}' of step '{self.name}' "
                        "does not point to a `BaseArtifact` subclass."
                    )
                # TODO: Can we get rid of this check? Why do we limit artifact
                # types to registered materializers?
                output_type = self.OUTPUT_SIGNATURE[output_name]
                allowed_artifact_types = set(
                    type_registry.get_artifact_type(output_type)
                )

                if artifact_class not in allowed_artifact_types:
                    raise StepInterfaceError(
                        f"Artifact type `{artifact_class}` for output "
                        f"'{output_name}' of step '{self.name}' is not an "
                        f"allowed artifact type for the defined output type "
                        f"`{output_type}`. Allowed artifact types: "
                        f"{allowed_artifact_types}. If you want to extend the "
                        f"allowed artifact types, implement a custom "
                        f"`BaseMaterializer` subclass and set its "
                        f"`ASSOCIATED_ARTIFACT_TYPES` and `ASSOCIATED_TYPES` "
                        f"accordingly."
                    )

    def _finalize_configuration(
        self, input_artifacts: Dict[str, _OutputArtifact]
    ) -> StepConfiguration:
        """Finalizes the configuration after the step was called.

        Once the step was called, we know the outputs of previous steps
        and that no additional user configurations will be made. That means
        we can now collect the remaining artifact and materializer types
        as well as check for the completeness of the step function parameters.

        Args:
            input_artifacts: The input artifacts of this step.

        Returns:
            The finalized step configuration.

        Raises:
            StepInterfaceError: If an output does not have an explicit
                materializer assigned to it and there is no default
                materializer registered for the output type.
        """
        outputs: Dict[str, Dict[str, str]] = defaultdict(dict)

        for output_name, output_class in self.OUTPUT_SIGNATURE.items():
            output = self._configuration.outputs.get(
                output_name, PartialArtifactConfiguration()
            )

            if not output.artifact_source:
                artifact_class = type_registry.get_artifact_type(output_class)[
                    0
                ]
                outputs[output_name][
                    "artifact_source"
                ] = source_utils.resolve_class(artifact_class)

            if not output.materializer_source:
                if default_materializer_registry.is_registered(output_class):
                    materializer_class = default_materializer_registry[
                        output_class
                    ]
                else:
                    raise StepInterfaceError(
                        f"Unable to find materializer for output "
                        f"'{output_name}' of type `{output_class}` in step "
                        f"'{self.name}'. Please make sure to either "
                        f"explicitly set a materializer for step outputs "
                        f"using `step.configure(output_materializers=...)` or "
                        f"register a default materializer for specific "
                        f"types by subclassing `BaseMaterializer` and setting "
                        f"its `ASSOCIATED_TYPES` class variable.",
                        url="https://docs.zenml.io/advanced-guide/pipelines/materializers",
                    )
                outputs[output_name][
                    "materializer_source"
                ] = source_utils.resolve_class(materializer_class)

        function_parameters = self._finalize_function_parameters()
        values = dict_utils.remove_none_values(
            {
                "outputs": outputs or None,
                "parameters": function_parameters,
            }
        )
        config = StepConfigurationUpdate(**values)
        self._apply_configuration(config)

        inputs = {}
        for input_name, artifact in input_artifacts.items():
            artifact_source = source_utils.resolve_class(artifact.channel.type)
            inputs[input_name] = ArtifactConfiguration(
                artifact_source=artifact_source,
                materializer_source=artifact.materializer_source,
            )
        self._validate_inputs(inputs)

        self._configuration = self._configuration.copy(
            update={"inputs": inputs}
        )

        complete_configuration = StepConfiguration.parse_obj(
            self._configuration
        )
        return complete_configuration

    def _finalize_function_parameters(self) -> Dict[str, Any]:
        """Verifies and prepares the config parameters for running this step.

        When the step requires config parameters, this method:
            - checks if config parameters were set via a config object or file
            - tries to set missing config parameters from default values of the
              config class

        Returns:
            Values for the previously unconfigured function parameters.

        Raises:
            MissingStepParameterError: If no value could be found for one or
                more config parameters.
        """
        if not self.PARAMETERS_CLASS:
            return {}

        # we need to store a value for all config keys inside the
        # metadata store to make sure caching works as expected
        missing_keys = []
        values = {}
        for name, field in self.PARAMETERS_CLASS.__fields__.items():
            if name in self.configuration.parameters:
                # a value for this parameter has been set already
                continue

            if field.required:
                # this field has no default value set and therefore needs
                # to be passed via an initialized config object
                missing_keys.append(name)
            else:
                # use default value from the pydantic config class
                values[name] = field.default

        if missing_keys:
            raise MissingStepParameterError(
                self.name, missing_keys, self.PARAMETERS_CLASS
            )

        return values
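The default-filling rule in `_finalize_function_parameters` (skip values that are already configured, fall back to field defaults, and report every field that has neither) can be sketched without pydantic. This is a minimal stdlib-only analogue, assuming a `dataclasses` class in place of the pydantic `BaseParameters` subclass; `TrainerParameters` and `finalize_parameters` are hypothetical names, and `ValueError` stands in for `MissingStepParameterError`.

```python
from dataclasses import MISSING, dataclass, fields
from typing import Any, Dict, List


@dataclass
class TrainerParameters:
    # Hypothetical parameters class; ZenML itself uses a pydantic model.
    learning_rate: float  # required: no default value
    epochs: int = 10  # optional: has a default value


def finalize_parameters(params_cls: type, configured: Dict[str, Any]) -> Dict[str, Any]:
    """Return defaults for unconfigured fields; fail if a required one is missing."""
    missing: List[str] = []
    values: Dict[str, Any] = {}
    for fld in fields(params_cls):
        if fld.name in configured:
            continue  # a value was already set, e.g. via configure(parameters=...)
        if fld.default is not MISSING:
            values[fld.name] = fld.default
        elif fld.default_factory is not MISSING:
            values[fld.name] = fld.default_factory()
        else:
            missing.append(fld.name)  # required field without a value
    if missing:
        raise ValueError(f"Missing values for parameters: {missing}")
    return values


print(finalize_parameters(TrainerParameters, {"learning_rate": 0.1}))  # {'epochs': 10}
```

Storing every default explicitly, rather than leaving it implicit, matches the source comment that all keys must be recorded so that caching behaves as expected.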
component: _SimpleComponent property readonly

Returns a TFX component.

Returns:

Type Description
_SimpleComponent

A TFX component.

Exceptions:

Type Description
StepInterfaceError

If you are trying to access the step component before creating it.

configuration: PartialStepConfiguration property readonly

The configuration of the step.

Returns:

Type Description
PartialStepConfiguration

The configuration of the step.

enable_cache: bool property readonly

If caching is enabled for the step.

Returns:

Type Description
bool

If caching is enabled for the step.

name: str property readonly

The name of the step.

Returns:

Type Description
str

The name of the step.

upstream_steps: Set[str] property readonly

Names of the upstream steps of this step.

This property will only contain the full set of upstream steps once its parent pipeline's connect(...) method has been called.

Returns:

Type Description
Set[str]

Set of upstream step names.
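The upstream set is filled from two places: the producers of a step's input artifacts (collected in `__call__`) and any step passed to `after(...)`. The sketch below models that bookkeeping with hypothetical stand-ins `Channel` and `Step`; it is an illustration of the idea, not ZenML's classes.

```python
from dataclasses import dataclass, field
from typing import Set


@dataclass
class Channel:
    # Hypothetical stand-in for the channel attached to an output artifact.
    producer_component_id: str


@dataclass
class Step:
    name: str
    upstream_steps: Set[str] = field(default_factory=set)

    def call(self, **inputs: Channel) -> Channel:
        # Every step that produced one of our inputs becomes an upstream step.
        for channel in inputs.values():
            self.upstream_steps.add(channel.producer_component_id)
        return Channel(producer_component_id=self.name)

    def after(self, other: "Step") -> None:
        # Explicit ordering constraint without a data dependency.
        self.upstream_steps.add(other.name)


loader, trainer, evaluator = Step("loader"), Step("trainer"), Step("evaluator")
data = loader.call()
model = trainer.call(data=data)
evaluator.call(model=model, data=data)
report = Step("report")
report.after(evaluator)
print(sorted(evaluator.upstream_steps))  # ['loader', 'trainer']
```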

__call__(self, *artifacts, **kw_artifacts) special

Generates a component when called.

Parameters:

Name Type Description Default
*artifacts _OutputArtifact

Positional input artifacts passed to the call method.

()
**kw_artifacts _OutputArtifact

Keyword input artifacts passed to the call method.

{}

Returns:

Type Description
Union[zenml.steps.base_step.BaseStep._OutputArtifact, List[zenml.steps.base_step.BaseStep._OutputArtifact]]

A single output artifact or a list of output artifacts.

Exceptions:

Type Description
StepInterfaceError

If the step has already been called.

Source code in zenml/steps/base_step.py
def __call__(
    self, *artifacts: _OutputArtifact, **kw_artifacts: _OutputArtifact
) -> Union[_OutputArtifact, List[_OutputArtifact]]:
    """Generates a component when called.

    Args:
        *artifacts: Positional input artifacts passed to
            the __call__ method.
        **kw_artifacts: Keyword input artifacts passed to
            the __call__ method.

    Returns:
        A single output artifact or a list of output artifacts.

    Raises:
        StepInterfaceError: If the step has already been called.
    """
    if self._has_been_called:
        raise StepInterfaceError(
            f"Step {self.name} has already been called. A ZenML step "
            f"instance can only be called once per pipeline run."
        )
    self._has_been_called = True

    # Prepare the input artifacts and spec
    input_artifacts = self._validate_input_artifacts(
        *artifacts, **kw_artifacts
    )
    input_channels = {
        name: artifact.channel for name, artifact in input_artifacts.items()
    }
    for input_ in input_artifacts.values():
        self._upstream_steps.add(input_.channel.producer_component_id)

    config = self._finalize_configuration(input_artifacts=input_artifacts)

    execution_parameters = {
        **self.configuration.parameters,
        **self._internal_execution_parameters,
    }

    # Convert execution parameter values to strings
    try:
        execution_parameters = {
            k: json.dumps(v) for k, v in execution_parameters.items()
        }
    except TypeError as e:
        raise StepInterfaceError(
            f"Failed to serialize execution parameters for step "
            f"'{self.name}'. Please make sure to only use "
            f"json serializable parameter values."
        ) from e

    component_class = create_component_class(step=self)
    self._component = component_class(
        **input_channels, **execution_parameters
    )

    # Resolve the returns in the right order.
    returns = []
    for key in self.OUTPUT_SIGNATURE:
        materializer_source = config.outputs[key].materializer_source
        output_artifact = BaseStep._OutputArtifact(
            channel=cast(Channel, self.component.outputs[key]),
            materializer_source=materializer_source,
        )
        returns.append(output_artifact)

    # If the step has a single output, return it directly instead of a list
    if len(returns) == 1:
        return returns[0]
    else:
        return returns
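Two details of `__call__` are easy to miss: every execution parameter is JSON-encoded before the component is built, so non-serializable values fail early, and a step with exactly one output returns that output directly rather than a one-element list. A minimal sketch of both rules, with hypothetical helper names and `ValueError` standing in for `StepInterfaceError`:

```python
import json
from typing import Any, Dict, List, Union


def serialize_parameters(parameters: Dict[str, Any]) -> Dict[str, str]:
    """JSON-encode every parameter value, failing on non-serializable ones."""
    try:
        return {key: json.dumps(value) for key, value in parameters.items()}
    except TypeError as e:
        raise ValueError("Only JSON-serializable parameter values are supported.") from e


def unwrap_single(outputs: List[Any]) -> Union[Any, List[Any]]:
    """Return a lone output directly and multiple outputs as a list."""
    return outputs[0] if len(outputs) == 1 else outputs


print(serialize_parameters({"lr": 0.1, "layers": [64, 32]}))
# {'lr': '0.1', 'layers': '[64, 32]'}
```

A value such as a `set` raises here because `json.dumps` cannot encode it, which mirrors the "json serializable parameter values" error in the source.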
__init__(self, *args, **kwargs) special

Initializes a step.

Parameters:

Name Type Description Default
*args Any

Positional arguments passed to the step.

()
**kwargs Any

Keyword arguments passed to the step.

{}
Source code in zenml/steps/base_step.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initializes a step.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.
    """
    self.pipeline_parameter_name: Optional[str] = None
    self._component: Optional["_SimpleComponent"] = None
    self._has_been_called = False
    self._upstream_steps: Set[str] = set()

    kwargs = {**self.INSTANCE_CONFIGURATION, **kwargs}
    name = kwargs.pop(PARAM_STEP_NAME, None) or self.__class__.__name__

    # This value is only used in `BaseStep.__created_by_functional_api()`
    kwargs.pop(PARAM_CREATED_BY_FUNCTIONAL_API, None)

    requires_context = bool(self.CONTEXT_PARAMETER_NAME)
    enable_cache = kwargs.pop(PARAM_ENABLE_CACHE, None)
    if enable_cache is None:
        if requires_context:
            # Using the StepContext inside a step provides access to
            # external resources which might influence the step execution.
            # We therefore disable caching unless it is explicitly enabled
            enable_cache = False
            logger.debug(
                "Step '%s': Step context required and caching not "
                "explicitly enabled.",
                name,
            )
        else:
            # Default to cache enabled if not explicitly set
            enable_cache = True

    logger.debug(
        "Step '%s': Caching %s.",
        name,
        "enabled" if enable_cache else "disabled",
    )

    self._configuration = PartialStepConfiguration(
        name=name,
        enable_cache=enable_cache,
    )
    self._apply_class_configuration(kwargs)
    self._verify_and_apply_init_params(*args, **kwargs)
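The caching default chosen in `__init__` follows a simple precedence: an explicit `enable_cache` value always wins; otherwise caching is disabled for steps that request a `StepContext` and enabled for all others. The function below is a hypothetical restatement of that rule, not part of ZenML:

```python
from typing import Optional


def resolve_enable_cache(explicit: Optional[bool], requires_context: bool) -> bool:
    """Mirror the caching default described for `BaseStep.__init__`."""
    if explicit is not None:
        return explicit  # an explicit user setting always wins
    # A StepContext gives access to external resources that may influence
    # the result, so caching is only on by default when no context is used.
    return not requires_context


print(resolve_enable_cache(None, requires_context=True))  # False
```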
after(self, step)

Adds an upstream step to this step.

Calling this method makes sure this step only starts running once the given step has successfully finished executing.

Note: This can only be called inside the pipeline connect function which is decorated with the @pipeline decorator. Any calls outside this function will be ignored.

Examples:

The following pipeline will run its steps sequentially in the following order: step_2 -> step_1 -> step_3

@pipeline
def example_pipeline(step_1, step_2, step_3):
    step_1.after(step_2)
    step_3(step_1(), step_2())

Parameters:

Name Type Description Default
step BaseStep

A step which should finish executing before this step is started.

required
Source code in zenml/steps/base_step.py
def after(self, step: "BaseStep") -> None:
    """Adds an upstream step to this step.

    Calling this method makes sure this step only starts running once the
    given step has successfully finished executing.

    **Note**: This can only be called inside the pipeline connect function
    which is decorated with the `@pipeline` decorator. Any calls outside
    this function will be ignored.

    Example:
    The following pipeline will run its steps sequentially in the following
    order: step_2 -> step_1 -> step_3

    ```python
    @pipeline
    def example_pipeline(step_1, step_2, step_3):
        step_1.after(step_2)
        step_3(step_1(), step_2())
    ```

    Args:
        step: A step which should finish executing before this step is
            started.
    """
    self._upstream_steps.add(step.name)
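Given the upstream sets that `after(...)` and the step calls produce, any topological order of the step graph respects the constraints. The sketch below uses the standard library's `graphlib.TopologicalSorter` (Python 3.9+) to reproduce the `step_2 -> step_1 -> step_3` order from the example; it only illustrates the ordering semantics and is not how a ZenML orchestrator schedules steps.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set


def execution_order(upstream: Dict[str, Set[str]]) -> List[str]:
    """Order steps so that each one runs after all of its upstream steps."""
    return list(TopologicalSorter(upstream).static_order())


# Upstream sets for the example pipeline:
# step_1.after(step_2) and step_3(step_1(), step_2()).
upstream_steps = {
    "step_1": {"step_2"},
    "step_2": set(),
    "step_3": {"step_1", "step_2"},
}
print(execution_order(upstream_steps))  # ['step_2', 'step_1', 'step_3']
```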
configure(self, name=None, enable_cache=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, output_artifacts=None, settings=None, extra=None, merge=True)

Configures the step.

Configuration merging example:

* merge==True:
    step.configure(extra={"key1": 1})
    step.configure(extra={"key2": 2}, merge=True)
    step.configuration.extra # {"key1": 1, "key2": 2}
* merge==False:
    step.configure(extra={"key1": 1})
    step.configure(extra={"key2": 2}, merge=False)
    step.configuration.extra # {"key2": 2}
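The difference between the two `merge` modes for a dictionary option such as `extra` can be sketched as a shallow dictionary merge versus a replacement. This is an assumption-laden simplification: `apply_extra` is a hypothetical helper, and ZenML's real configuration update may merge nested values (for example `settings`) more deeply.

```python
from typing import Any, Dict


def apply_extra(current: Dict[str, Any], update: Dict[str, Any], merge: bool = True) -> Dict[str, Any]:
    """Shallow-merge `update` into `current`, or replace it when merge=False."""
    if merge:
        return {**current, **update}  # later keys override earlier ones
    return dict(update)


print(apply_extra({"key1": 1}, {"key2": 2}))  # {'key1': 1, 'key2': 2}
print(apply_extra({"key1": 1}, {"key2": 2}, merge=False))  # {'key2': 2}
```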

Parameters:

Name Type Description Default
name Optional[str]

The name of the step.

None
enable_cache Optional[bool]

If caching should be enabled for this step.

None
experiment_tracker Optional[str]

The experiment tracker to use for this step.

None
step_operator Optional[str]

The step operator to use for this step.

None
parameters Optional[ParametersOrDict]

Function parameters for this step.

None
output_materializers Optional[OutputMaterializersSpecification]

Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs.

None
output_artifacts Optional[OutputArtifactsSpecification]

Output artifacts for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the artifact class will be used for all outputs.

None
settings Optional[Mapping[str, SettingsOrDict]]

Settings for this step.

None
extra Optional[Dict[str, Any]]

Extra configurations for this step.

None
merge bool

If True, the given dictionary configurations like parameters and settings are merged with existing configurations. If False, the given configurations overwrite all existing ones. See the general description of this method for an example.

True

Returns:

Type Description
~T

The step instance that this method was called on.

Exceptions:

Type Description
StepInterfaceError

If a materializer or artifact is configured for a non-existent output name.

Source code in zenml/steps/base_step.py
def configure(
    self: T,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    output_artifacts: Optional["OutputArtifactsSpecification"] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    merge: bool = True,
) -> T:
    """Configures the step.

    Configuration merging example:
    * `merge==True`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=True)
        step.configuration.extra # {"key1": 1, "key2": 2}
    * `merge==False`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=False)
        step.configuration.extra # {"key2": 2}

    Args:
        name: The name of the step.
        enable_cache: If caching should be enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        output_artifacts: Output artifacts for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            artifact class will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False`, the given configurations will
            overwrite all existing ones. See the general description of this
            method for an example.

    Returns:
        The step instance that this method was called on.

    Raises:
        StepInterfaceError: If a materializer or artifact is configured for
            a non-existent output name.
    """

    def _resolve_if_necessary(value: Union[str, Type[Any]]) -> str:
        return (
            value
            if isinstance(value, str)
            else source_utils.resolve_class(value)
        )

    outputs: Dict[str, Dict[str, str]] = defaultdict(dict)
    allowed_output_names = set(self.OUTPUT_SIGNATURE)

    if output_materializers:
        if not isinstance(output_materializers, Mapping):
            # single materializer (class or source string) used for all outputs
            source = _resolve_if_necessary(output_materializers)
            output_materializers = {
                output_name: source for output_name in allowed_output_names
            }

        for output_name, materializer in output_materializers.items():
            if output_name not in allowed_output_names:
                raise StepInterfaceError(
                    f"Got unexpected materializers for non-existent "
                    f"output '{output_name}' in step '{self.name}'. "
                    f"Only materializers for the outputs "
                    f"{allowed_output_names} of this step can"
                    f" be registered."
                )

            source = _resolve_if_necessary(materializer)
            outputs[output_name]["materializer_source"] = source

    if output_artifacts:
        if not isinstance(output_artifacts, Mapping):
            # single artifact class (class or source string) used for all outputs
            source = _resolve_if_necessary(output_artifacts)
            output_artifacts = {
                output_name: source for output_name in allowed_output_names
            }

        for output_name, artifact in output_artifacts.items():
            if output_name not in allowed_output_names:
                raise StepInterfaceError(
                    f"Got unexpected artifact for non-existent "
                    f"output '{output_name}' in step '{self.name}'. "
                    f"Only artifacts for the outputs "
                    f"{allowed_output_names} of this step can"
                    f" be registered."
                )

            source = _resolve_if_necessary(artifact)
            outputs[output_name]["artifact_source"] = source

    values = dict_utils.remove_none_values(
        {
            "name": name,
            "enable_cache": enable_cache,
            "experiment_tracker": experiment_tracker,
            "step_operator": step_operator,
            "parameters": parameters,
            "settings": settings,
            "outputs": outputs or None,
            "extra": extra,
        }
    )
    config = StepConfigurationUpdate(**values)
    self._apply_configuration(config, merge=merge)
    return self
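`configure` accepts either one materializer or artifact for every output or a per-output dictionary, and rejects dictionary keys that are not output names. The normalization can be sketched as follows, assuming values are already source strings (the real method resolves classes to source strings first); `normalize_per_output` and the `my_pkg...` sources are hypothetical, and `ValueError` stands in for `StepInterfaceError`.

```python
from typing import Dict, Mapping, Set, Union


def normalize_per_output(
    spec: Union[str, Mapping[str, str]], output_names: Set[str]
) -> Dict[str, str]:
    """Expand a single source to all outputs and reject unknown output names."""
    if not isinstance(spec, Mapping):
        # One value applies to every output of the step.
        return {name: spec for name in output_names}
    unknown = set(spec) - output_names
    if unknown:
        raise ValueError(f"Unknown output names: {sorted(unknown)}")
    return dict(spec)


print(normalize_per_output({"model": "my_pkg.ModelMaterializer"}, {"model", "metrics"}))
# {'model': 'my_pkg.ModelMaterializer'}
```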
entrypoint(self, *args, **kwargs)

Abstract method for core step logic.

Parameters:

Name Type Description Default
*args Any

Positional arguments passed to the step.

()
**kwargs Any

Keyword arguments passed to the step.

{}

Returns:

Type Description
Any

The output of the step.

Source code in zenml/steps/base_step.py
@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Abstract method for core step logic.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.

    Returns:
        The output of the step.
    """
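Concrete steps supply the logic by overriding `entrypoint` with a fixed, fully annotated signature. The sketch below shows the same abstract-method pattern with the standard `abc` module; `MiniStep` and `AddOne` are hypothetical, and unlike ZenML's `BaseStepMeta` this simplified base class does not validate the subclass signature.

```python
from abc import ABC, abstractmethod
from typing import Any


class MiniStep(ABC):
    """Simplified stand-in for `BaseStep`: subclasses provide `entrypoint`."""

    @abstractmethod
    def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
        """Core step logic."""


class AddOne(MiniStep):
    # In ZenML, the metaclass would reject *args/**kwargs here, so the
    # concrete signature lists each annotated input explicitly.
    def entrypoint(self, value: int) -> int:
        return value + 1


print(AddOne().entrypoint(41))  # 42
```

Instantiating `MiniStep` directly raises `TypeError`, because the abstract method has no implementation.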
with_return_materializers(self, materializers)

DEPRECATED: Register materializers for step outputs.

If a single materializer is passed, it will be used for all step outputs. Otherwise, the dictionary keys specify the output names for which the materializers will be used.

Parameters:

Name Type Description Default
materializers Union[Type[zenml.materializers.base_materializer.BaseMaterializer], Dict[str, Type[zenml.materializers.base_materializer.BaseMaterializer]]]

The materializers for the outputs of this step.

required

Returns:

Type Description
~T

The step that this method was called on.

Source code in zenml/steps/base_step.py
def with_return_materializers(
    self: T,
    materializers: Union[
        Type[BaseMaterializer], Dict[str, Type[BaseMaterializer]]
    ],
) -> T:
    """DEPRECATED: Register materializers for step outputs.

    If a single materializer is passed, it will be used for all step
    outputs. Otherwise, the dictionary keys specify the output names
    for which the materializers will be used.

    Args:
        materializers: The materializers for the outputs of this step.

    Returns:
        The step that this method was called on.
    """
    logger.warning(
        "The `with_return_materializers(...)` method is deprecated. "
        "Use `step.configure(output_materializers=...)` instead."
    )

    self.configure(output_materializers=materializers)
    return self
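`with_return_materializers` is a thin deprecated alias: it emits a warning and delegates to `configure(output_materializers=...)`, returning the step so calls can be chained. The sketch below shows that shim pattern on a hypothetical `SketchStep` class; it swaps the source's `logger.warning` for Python's `warnings` module with `DeprecationWarning`, which is a design choice of this example, not ZenML's.

```python
import warnings
from typing import Any, Dict, TypeVar

T = TypeVar("T", bound="SketchStep")


class SketchStep:
    """Hypothetical step exposing a new API and a deprecated alias."""

    def __init__(self) -> None:
        self.materializers: Dict[str, Any] = {}

    def configure(self: T, output_materializers: Dict[str, Any]) -> T:
        self.materializers.update(output_materializers)
        return self  # returning self allows chained calls

    def with_return_materializers(self: T, materializers: Dict[str, Any]) -> T:
        # Deprecated alias: warn, then delegate to the replacement API.
        warnings.warn(
            "`with_return_materializers(...)` is deprecated. "
            "Use `configure(output_materializers=...)` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.configure(output_materializers=materializers)
```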

BaseStepMeta (type)

Metaclass for BaseStep.

Checks whether everything passed in:

* Has a matching materializer,
* Is a subclass of the Config class,
* Is typed correctly.

Source code in zenml/steps/base_step.py
class BaseStepMeta(type):
    """Metaclass for `BaseStep`.

    Checks whether everything passed in:
    * Has a matching materializer,
    * Is a subclass of the Config class,
    * Is typed correctly.
    """

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "BaseStepMeta":
        """Set up a new class with a qualified spec.

        Args:
            name: The name of the class.
            bases: The base classes of the class.
            dct: The attributes of the class.

        Returns:
            The new class.

        Raises:
            StepInterfaceError: When unable to create the step.
        """
        from zenml.steps.base_parameters import BaseParameters

        dct.setdefault(INSTANCE_CONFIGURATION, {})
        cls = cast(Type["BaseStep"], super().__new__(mcs, name, bases, dct))

        cls.INPUT_SIGNATURE = {}
        cls.OUTPUT_SIGNATURE = {}
        cls.PARAMETERS_FUNCTION_PARAMETER_NAME = None
        cls.PARAMETERS_CLASS = None
        cls.CONTEXT_PARAMETER_NAME = None

        # Get the signature of the step function
        step_function_signature = inspect.getfullargspec(
            inspect.unwrap(cls.entrypoint)
        )

        if bases:
            # We're not creating the abstract `BaseStep` class
            # but a concrete implementation. Make sure the step function
            # signature does not contain variable *args or **kwargs
            variable_arguments = None
            if step_function_signature.varargs:
                variable_arguments = f"*{step_function_signature.varargs}"
            elif step_function_signature.varkw:
                variable_arguments = f"**{step_function_signature.varkw}"

            if variable_arguments:
                raise StepInterfaceError(
                    f"Unable to create step '{name}' with variable arguments "
                    f"'{variable_arguments}'. Please make sure your step "
                    f"functions are defined with a fixed amount of arguments."
                )

        step_function_args = (
            step_function_signature.args + step_function_signature.kwonlyargs
        )

        # Remove 'self' from the signature if it exists
        if step_function_args and step_function_args[0] == "self":
            step_function_args.pop(0)

        # Verify the input arguments of the step function
        for arg in step_function_args:
            arg_type = step_function_signature.annotations.get(arg, None)
            arg_type = resolve_type_annotation(arg_type)

            if not arg_type:
                raise StepInterfaceError(
                    f"Missing type annotation for argument '{arg}' when "
                    f"trying to create step '{name}'. Please make sure to "
                    f"include type annotations for all your step inputs "
                    f"and outputs."
                )

            if issubclass(arg_type, BaseParameters):
                # Raise an error if we already found a config in the signature
                if cls.PARAMETERS_CLASS is not None:
                    raise StepInterfaceError(
                        f"Found multiple parameter arguments "
                        f"('{cls.PARAMETERS_FUNCTION_PARAMETER_NAME}' and '{arg}') when "
                        f"trying to create step '{name}'. Please make sure to "
                        f"only have one `Parameters` subclass as input "
                        f"argument for a step."
                    )
                cls.PARAMETERS_FUNCTION_PARAMETER_NAME = arg
                cls.PARAMETERS_CLASS = arg_type

            elif issubclass(arg_type, StepContext):
                if cls.CONTEXT_PARAMETER_NAME is not None:
                    raise StepInterfaceError(
                        f"Found multiple context arguments "
                        f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
                        f"trying to create step '{name}'. Please make sure to "
                        f"only have one `StepContext` as input "
                        f"argument for a step."
                    )
                cls.CONTEXT_PARAMETER_NAME = arg
            else:
                # Can't do any check for existing materializers right now
                # as they might be defined later, so we simply store the
                # argument name and type for later use.
                cls.INPUT_SIGNATURE.update({arg: arg_type})

        # Parse the returns of the step function
        if "return" not in step_function_signature.annotations:
            raise StepInterfaceError(
                "Missing return type annotation when trying to create step "
                f"'{name}'. Please make sure to include type annotations for "
                "all your step inputs and outputs. If your step returns "
                "nothing, please annotate it with `-> None`."
            )
        cls.OUTPUT_SIGNATURE = parse_return_type_annotations(
            step_function_signature.annotations,
        )

        # Raise an exception if input and output names of a step overlap as
        # tfx requires them to be unique
        # TODO [ENG-155]: Can we prefix inputs and outputs to avoid this
        #  restriction?
        counter: Counter[str] = collections.Counter()
        counter.update(list(cls.INPUT_SIGNATURE))
        counter.update(list(cls.OUTPUT_SIGNATURE))
        if cls.PARAMETERS_CLASS:
            counter.update(list(cls.PARAMETERS_CLASS.__fields__.keys()))

        shared_keys = {k for k in counter.elements() if counter[k] > 1}
        if shared_keys:
            raise StepInterfaceError(
                f"The following keys are overlapping in the input, output and "
                f"config parameter names of step '{name}': {shared_keys}. "
                f"Please make sure that your input, output and config "
                f"parameter names are unique."
            )

        return cls
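The signature checks in `BaseStepMeta.__new__` rely on `inspect.getfullargspec`: variadic arguments are rejected, every input needs an annotation, and a return annotation is mandatory. A standalone sketch of those checks, assuming plain (non-string) annotations since it skips the source's `resolve_type_annotation` step; `input_signature` is a hypothetical helper and `TypeError` stands in for `StepInterfaceError`.

```python
import inspect
from typing import Any, Callable, Dict


def input_signature(func: Callable[..., Any]) -> Dict[str, Any]:
    """Collect annotated inputs, rejecting variadic or unannotated arguments."""
    spec = inspect.getfullargspec(inspect.unwrap(func))
    if spec.varargs or spec.varkw:
        raise TypeError("Step functions must have a fixed number of arguments.")
    args = [a for a in spec.args + spec.kwonlyargs if a != "self"]
    missing = [a for a in args if a not in spec.annotations]
    if missing:
        raise TypeError(f"Missing type annotations for: {missing}")
    if "return" not in spec.annotations:
        raise TypeError("Missing return type annotation (use `-> None`).")
    return {a: spec.annotations[a] for a in args}


def train(data: list, epochs: int = 3) -> float:
    return float(len(data) * epochs)


print(input_signature(train))  # {'data': <class 'list'>, 'epochs': <class 'int'>}
```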
__new__(mcs, name, bases, dct) special staticmethod

Set up a new class with a qualified spec.

Parameters:

Name Type Description Default
name str

The name of the class.

required
bases Tuple[Type[Any], ...]

The base classes of the class.

required
dct Dict[str, Any]

The attributes of the class.

required

Returns:

Type Description
BaseStepMeta

The new class.

Exceptions:

Type Description
StepInterfaceError

When unable to create the step.
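The final check in `__new__` ensures that input, output, and parameter names do not collide, since TFX requires them to be unique. A minimal sketch of the same `collections.Counter` idea; `overlapping_names` is a hypothetical helper that counts each name once per group.

```python
from collections import Counter
from typing import Iterable, Set


def overlapping_names(*groups: Iterable[str]) -> Set[str]:
    """Return the names that appear in more than one of the given groups."""
    counter: Counter[str] = Counter()
    for group in groups:
        counter.update(set(group))  # count each name at most once per group
    return {name for name, count in counter.items() if count > 1}


inputs, outputs, params = ["data"], ["model", "data"], ["lr"]
print(overlapping_names(inputs, outputs, params))  # {'data'}
```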

Source code in zenml/steps/base_step.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseStepMeta":
    """Set up a new class with a qualified spec.

    Args:
        name: The name of the class.
        bases: The base classes of the class.
        dct: The attributes of the class.

    Returns:
        The new class.

    Raises:
        StepInterfaceError: When unable to create the step.
    """
    from zenml.steps.base_parameters import BaseParameters

    dct.setdefault(INSTANCE_CONFIGURATION, {})
    cls = cast(Type["BaseStep"], super().__new__(mcs, name, bases, dct))

    cls.INPUT_SIGNATURE = {}
    cls.OUTPUT_SIGNATURE = {}
    cls.PARAMETERS_FUNCTION_PARAMETER_NAME = None
    cls.PARAMETERS_CLASS = None
    cls.CONTEXT_PARAMETER_NAME = None

    # Get the signature of the step function
    step_function_signature = inspect.getfullargspec(
        inspect.unwrap(cls.entrypoint)
    )

    if bases:
        # We're not creating the abstract `BaseStep` class
        # but a concrete implementation. Make sure the step function
        # signature does not contain variable *args or **kwargs
        variable_arguments = None
        if step_function_signature.varargs:
            variable_arguments = f"*{step_function_signature.varargs}"
        elif step_function_signature.varkw:
            variable_arguments = f"**{step_function_signature.varkw}"

        if variable_arguments:
            raise StepInterfaceError(
                f"Unable to create step '{name}' with variable arguments "
                f"'{variable_arguments}'. Please make sure your step "
                f"functions are defined with a fixed amount of arguments."
            )

    step_function_args = (
        step_function_signature.args + step_function_signature.kwonlyargs
    )

    # Remove 'self' from the signature if it exists
    if step_function_args and step_function_args[0] == "self":
        step_function_args.pop(0)

    # Verify the input arguments of the step function
    for arg in step_function_args:
        arg_type = step_function_signature.annotations.get(arg, None)
        arg_type = resolve_type_annotation(arg_type)

        if not arg_type:
            raise StepInterfaceError(
                f"Missing type annotation for argument '{arg}' when "
                f"trying to create step '{name}'. Please make sure to "
                f"include type annotations for all your step inputs "
                f"and outputs."
            )

        if issubclass(arg_type, BaseParameters):
            # Raise an error if we already found a config in the signature
            if cls.PARAMETERS_CLASS is not None:
                raise StepInterfaceError(
                    f"Found multiple parameter arguments "
                    f"('{cls.PARAMETERS_FUNCTION_PARAMETER_NAME}' and '{arg}') when "
                    f"trying to create step '{name}'. Please make sure to "
                    f"only have one `Parameters` subclass as input "
                    f"argument for a step."
                )
            cls.PARAMETERS_FUNCTION_PARAMETER_NAME = arg
            cls.PARAMETERS_CLASS = arg_type

        elif issubclass(arg_type, StepContext):
            if cls.CONTEXT_PARAMETER_NAME is not None:
                raise StepInterfaceError(
                    f"Found multiple context arguments "
                    f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
                    f"trying to create step '{name}'. Please make sure to "
                    f"only have one `StepContext` as input "
                    f"argument for a step."
                )
            cls.CONTEXT_PARAMETER_NAME = arg
        else:
            # Can't do any check for existing materializers right now
            # as they might be defined later, so we simply store the
            # argument name and type for later use.
            cls.INPUT_SIGNATURE.update({arg: arg_type})

    # Parse the returns of the step function
    if "return" not in step_function_signature.annotations:
        raise StepInterfaceError(
            "Missing return type annotation when trying to create step "
            f"'{name}'. Please make sure to include type annotations for "
            "all your step inputs and outputs. If your step returns "
            "nothing, please annotate it with `-> None`."
        )
    cls.OUTPUT_SIGNATURE = parse_return_type_annotations(
        step_function_signature.annotations,
    )

    # Raise an exception if input and output names of a step overlap as
    # tfx requires them to be unique
    # TODO [ENG-155]: Can we prefix inputs and outputs to avoid this
    #  restriction?
    counter: Counter[str] = collections.Counter()
    counter.update(list(cls.INPUT_SIGNATURE))
    counter.update(list(cls.OUTPUT_SIGNATURE))
    if cls.PARAMETERS_CLASS:
        counter.update(list(cls.PARAMETERS_CLASS.__fields__.keys()))

    shared_keys = {k for k in counter.elements() if counter[k] > 1}
    if shared_keys:
        raise StepInterfaceError(
            f"The following keys are overlapping in the input, output and "
            f"config parameter names of step '{name}': {shared_keys}. "
            f"Please make sure that your input, output and config "
            f"parameter names are unique."
        )

    return cls
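The overlap check at the end of this validation counts every input, output and parameter name in a single `collections.Counter` and reports any name seen more than once. A minimal stand-alone sketch of the same idea follows. `find_overlapping_names` is a hypothetical helper, not part of ZenML, and it assumes names are unique within each group (as they are for the dict keys used above). It iterates `counter.items()` rather than `counter.elements()`, which yields the same set of shared names.

```python
import collections
from typing import Counter, Iterable, Set


def find_overlapping_names(
    input_names: Iterable[str],
    output_names: Iterable[str],
    parameter_names: Iterable[str] = (),
) -> Set[str]:
    """Return every name that appears in more than one group (hypothetical helper)."""
    counter: Counter[str] = collections.Counter()
    # Count all names together, exactly like the step validation does.
    counter.update(list(input_names))
    counter.update(list(output_names))
    counter.update(list(parameter_names))
    # Any name counted more than once collides across groups.
    return {name for name, count in counter.items() if count > 1}


print(find_overlapping_names(["data"], ["model"], ["learning_rate"]))  # set()
print(find_overlapping_names(["data"], ["data"], []))  # {'data'}
```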

step_context

Step context class.

StepContext

Provides additional context inside a step function.

This class gives a step function access to the active stack, to the materializers of the step's outputs and to the URIs of its output artifacts. To use it, add an argument annotated with StepContext to the signature of your step function like this:

from zenml.steps import StepContext, step

@step
def my_step(context: StepContext) -> int:
    uri = context.get_output_artifact_uri()  # single output, so no name needed
    return 42

You do not need to create a StepContext object yourself or pass it when creating the step. As long as you specify it in the signature, ZenML creates the StepContext and passes it automatically when executing your step.

Note: When using a StepContext inside a step, ZenML disables caching for this step by default as the context provides access to external resources which might influence the result of your step execution. To enable caching anyway, explicitly enable it in the @step decorator or when initializing your custom step class.
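The caching default described in this note can be summarised as a small decision rule. The sketch below is a hypothetical helper (`resolve_enable_cache` is not part of ZenML), written only to make the precedence explicit: an explicit `enable_cache` value always wins; otherwise caching is disabled exactly when the step function takes a StepContext.

```python
from typing import Optional


def resolve_enable_cache(enable_cache: Optional[bool], uses_step_context: bool) -> bool:
    """Hypothetical helper modelling the documented caching default."""
    # An explicit True/False from the decorator or step class takes precedence.
    if enable_cache is not None:
        return enable_cache
    # Otherwise caching is on, unless the step requests a StepContext.
    return not uses_step_context
```

For example, `resolve_enable_cache(None, True)` is `False`, while `resolve_enable_cache(True, True)` is `True`, matching "to enable caching anyway, explicitly enable it".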

Source code in zenml/steps/step_context.py
class StepContext:
    """Provides additional context inside a step function.

    This class is used to access pipelines, materializers, and artifacts
    inside a step function. To use it, add a `StepContext` object
    to the signature of your step function like this:

    ```python
    @step
    def my_step(context: StepContext, ...)
        context.get_output_materializer(...)
    ```

    You do not need to create a `StepContext` object yourself and pass it
    when creating the step, as long as you specify it in the signature ZenML
    will create the `StepContext` and automatically pass it when executing your
    step.

    **Note**: When using a `StepContext` inside a step, ZenML disables caching
    for this step by default as the context provides access to external
    resources which might influence the result of your step execution. To
    enable caching anyway, explicitly enable it in the `@step` decorator or when
    initializing your custom step class.
    """

    def __init__(
        self,
        step_name: str,
        output_materializers: Dict[str, Type["BaseMaterializer"]],
        output_artifacts: Dict[str, "BaseArtifact"],
    ):
        """Initializes a StepContext instance.

        Args:
            step_name: The name of the step that this context is used in.
            output_materializers: The output materializers of the step that
                                  this context is used in.
            output_artifacts: The output artifacts of the step that this
                              context is used in.

        Raises:
             StepContextError: If the keys of the output materializers and
                               output artifacts do not match.
        """
        if output_materializers.keys() != output_artifacts.keys():
            raise StepContextError(
                f"Mismatched keys in output materializers and output "
                f"artifacts for step '{step_name}'. Output materializer "
                f"keys: {set(output_materializers)}, output artifact "
                f"keys: {set(output_artifacts)}"
            )

        self.step_name = step_name
        self._outputs = {
            key: StepContextOutput(
                output_materializers[key], output_artifacts[key]
            )
            for key in output_materializers.keys()
        }
        self._stack = Client().active_stack

    def _get_output(
        self, output_name: Optional[str] = None
    ) -> StepContextOutput:
        """Returns the materializer and artifact URI for a given step output.

        Args:
            output_name: Optional name of the output for which to get the
                materializer and URI.

        Returns:
            Tuple containing the materializer and artifact URI for the
            given output.

        Raises:
            StepContextError: If the step has no outputs, no output for
                              the given `output_name` or if no `output_name`
                              was given but the step has multiple outputs.
        """
        output_count = len(self._outputs)
        if output_count == 0:
            raise StepContextError(
                f"Unable to get step output for step '{self.step_name}': "
                f"This step does not have any outputs."
            )

        if not output_name and output_count > 1:
            raise StepContextError(
                f"Unable to get step output for step '{self.step_name}': "
                f"This step has multiple outputs ({set(self._outputs)}), "
                f"please specify which output to return."
            )

        if output_name:
            if output_name not in self._outputs:
                raise StepContextError(
                    f"Unable to get step output '{output_name}' for "
                    f"step '{self.step_name}'. This step does not have an "
                    f"output with the given name, please specify one of the "
                    f"available outputs: {set(self._outputs)}."
                )
            return self._outputs[output_name]
        else:
            return next(iter(self._outputs.values()))

    @property
    def stack(self) -> Optional["Stack"]:
        """Returns the current active stack.

        Returns:
            The current active stack or None.
        """
        return self._stack

    def get_output_materializer(
        self,
        output_name: Optional[str] = None,
        custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
    ) -> "BaseMaterializer":
        """Returns a materializer for a given step output.

        Args:
            output_name: Optional name of the output for which to get the
                materializer. If no name is given and the step only has a
                single output, the materializer of this output will be
                returned. If the step has multiple outputs, an exception
                will be raised.
            custom_materializer_class: If given, this `BaseMaterializer`
                subclass will be initialized with the output artifact instead
                of the materializer that was registered for this step output.

        Returns:
            A materializer initialized with the output artifact for
            the given output.
        """
        materializer_class, artifact = self._get_output(output_name)
        # use custom materializer class if provided or fallback to default
        # materializer for output
        materializer_class = custom_materializer_class or materializer_class
        return materializer_class(artifact)

    def get_output_artifact_uri(self, output_name: Optional[str] = None) -> str:
        """Returns the artifact URI for a given step output.

        Args:
            output_name: Optional name of the output for which to get the URI.
                If no name is given and the step only has a single output,
                the URI of this output will be returned. If the step has
                multiple outputs, an exception will be raised.

        Returns:
            Artifact URI for the given output.
        """
        return cast(str, self._get_output(output_name).artifact.uri)
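The rules that `_get_output` applies when choosing an output can be reproduced without ZenML. The sketch below is hypothetical: `resolve_output` and `OutputLookupError` are illustrative names standing in for the private method and `StepContextError`, and plain strings stand in for `StepContextOutput` entries.

```python
from typing import Dict, Optional


class OutputLookupError(Exception):
    """Hypothetical stand-in for StepContextError."""


def resolve_output(outputs: Dict[str, str], output_name: Optional[str] = None) -> str:
    """Pick one entry from an output-name mapping using the StepContext rules."""
    # A step without outputs has nothing to return.
    if not outputs:
        raise OutputLookupError("This step does not have any outputs.")
    # Without a name, the choice is only unambiguous for a single output.
    if not output_name and len(outputs) > 1:
        raise OutputLookupError(f"Multiple outputs {set(outputs)}, please specify one.")
    if output_name:
        if output_name not in outputs:
            raise OutputLookupError(
                f"No output named '{output_name}', available: {set(outputs)}."
            )
        return outputs[output_name]
    # Exactly one output and no name given: return that output.
    return next(iter(outputs.values()))
```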
stack: Optional[Stack] property readonly

Returns the current active stack.

Returns:

Optional[Stack]: The current active stack or None.

__init__(self, step_name, output_materializers, output_artifacts) special

Initializes a StepContext instance.

Parameters:

step_name (str, required): The name of the step that this context is used in.
output_materializers (Dict[str, Type[BaseMaterializer]], required): The output materializers of the step that this context is used in.
output_artifacts (Dict[str, BaseArtifact], required): The output artifacts of the step that this context is used in.

Exceptions:

StepContextError: If the keys of the output materializers and output artifacts do not match.

get_output_artifact_uri(self, output_name=None)

Returns the artifact URI for a given step output.

Parameters:

output_name (Optional[str], default None): Optional name of the output for which to get the URI. If no name is given, the step must have exactly one output, and the URI of that output is returned; if the step has multiple outputs, an exception is raised.

Returns:

str: Artifact URI for the given output.

get_output_materializer(self, output_name=None, custom_materializer_class=None)

Returns a materializer for a given step output.

Parameters:

output_name (Optional[str], default None): Optional name of the output for which to get the materializer. If no name is given, the step must have exactly one output, and the materializer of that output is returned; if the step has multiple outputs, an exception is raised.
custom_materializer_class (Optional[Type[BaseMaterializer]], default None): If given, this BaseMaterializer subclass is initialized with the output artifact instead of the materializer that was registered for this step output.

Returns:

BaseMaterializer: A materializer initialized with the output artifact for the given output.


StepContextOutput (tuple)

Tuple containing materializer class and artifact for a step output.

Source code in zenml/steps/step_context.py
class StepContextOutput(NamedTuple):
    """Tuple containing materializer class and artifact for a step output."""

    materializer_class: Type["BaseMaterializer"]
    artifact: "BaseArtifact"
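Because StepContextOutput is a NamedTuple, `get_output_materializer` can unpack it directly into `(materializer_class, artifact)` and fall back to the registered class when no custom one is given. The sketch below reproduces that pattern with hypothetical stand-in classes (`FakeArtifact`, `FakeMaterializer`, `JsonMaterializer`, `OutputEntry`); it does not use the real BaseMaterializer or BaseArtifact.

```python
from typing import NamedTuple, Optional, Type


class FakeArtifact:
    """Hypothetical stand-in for BaseArtifact: it only carries a URI."""

    def __init__(self, uri: str) -> None:
        self.uri = uri


class FakeMaterializer:
    """Hypothetical stand-in for BaseMaterializer: it wraps an artifact."""

    def __init__(self, artifact: FakeArtifact) -> None:
        self.artifact = artifact


class JsonMaterializer(FakeMaterializer):
    """Hypothetical custom materializer subclass."""


class OutputEntry(NamedTuple):
    """Same shape as StepContextOutput: (materializer_class, artifact)."""

    materializer_class: Type[FakeMaterializer]
    artifact: FakeArtifact


def build_materializer(
    entry: OutputEntry, custom_class: Optional[Type[FakeMaterializer]] = None
) -> FakeMaterializer:
    # NamedTuple instances are real tuples, so plain unpacking works.
    materializer_class, artifact = entry
    # Prefer the custom class if one was passed, else the registered one.
    return (custom_class or materializer_class)(artifact)
```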

step_decorator

Step decorator function.

step(_func=None, *, name=None, enable_cache=None, experiment_tracker=None, step_operator=None, output_artifacts=None, output_materializers=None, settings=None, extra=None)

Outer decorator function for the creation of a ZenML step.

To support parameters such as name, it uses a nested decorator structure, so it can be applied either directly (@step) or with keyword arguments (@step(name=...)).

Parameters:

_func (Optional[~F], default None): The decorated function.
name (Optional[str], default None): The name of the step. If left empty, the name of the decorated function is used as a fallback.
enable_cache (Optional[bool], default None): Whether caching is enabled for this step. If no value is passed, caching is enabled by default unless the step requires a StepContext (see zenml.steps.step_context.StepContext for more information).
experiment_tracker (Optional[str], default None): The experiment tracker to use for this step.
step_operator (Optional[str], default None): The step operator to use for this step.
output_materializers (Optional[OutputMaterializersSpecification], default None): Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer is used for all outputs.
output_artifacts (Optional[OutputArtifactsSpecification], default None): Output artifacts for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the artifact class is used for all outputs.
settings (Optional[Dict[str, SettingsOrDict]], default None): Settings for this step.
extra (Optional[Dict[str, Any]], default None): Extra configurations for this step.

Returns:

Union[Type[zenml.steps.base_step.BaseStep], Callable[[~F], Type[zenml.steps.base_step.BaseStep]]]: The inner decorator, which creates a step class derived from the ZenML BaseStep. If the decorator is applied directly to a function (without arguments), the generated step class is returned instead.
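The dict-or-single-value rule for `output_materializers` and `output_artifacts` amounts to expanding the argument into one entry per output. The sketch below is a hypothetical helper, `normalize_output_spec`, not ZenML's own code: it raises `ValueError` as a stand-in for whatever error ZenML reports, and outputs not named in a dict are simply left out of the result, since the description above does not say how they are handled.

```python
import collections.abc
from typing import Dict, Iterable, Mapping, TypeVar, Union

T = TypeVar("T")


def normalize_output_spec(
    spec: Union[T, Mapping[str, T]], output_names: Iterable[str]
) -> Dict[str, T]:
    """Expand a per-step specification into a per-output mapping (hypothetical)."""
    names = list(output_names)
    if isinstance(spec, collections.abc.Mapping):
        # Dict form: every key must be one of the step's output names.
        unknown = set(spec) - set(names)
        if unknown:
            raise ValueError(f"Unknown output names: {sorted(unknown)}")
        return dict(spec)
    # Single value (a class or an import-path string): applies to all outputs.
    return {name: spec for name in names}
```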

Source code in zenml/steps/step_decorator.py
def step(
    _func: Optional[F] = None,
    *,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    output_artifacts: Optional["OutputArtifactsSpecification"] = None,
    output_materializers: Optional["OutputMaterializersSpecification"] = None,
    settings: Optional[Dict[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
) -> Union[Type[BaseStep], Callable[[F], Type[BaseStep]]]:
    """Outer decorator function for the creation of a ZenML step.

    In order to be able to work with parameters such as `name`, it features a
    nested decorator structure.

    Args:
        _func: The decorated function.
        name: The name of the step. If left empty, the name of the decorated
            function will be used as a fallback.
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default unless the step
            requires a `StepContext` (see
            `zenml.steps.step_context.StepContext` for more information).
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        output_artifacts: Output artifacts for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            artifact class will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.

    Returns:
        the inner decorator which creates the step class based on the
        ZenML BaseStep
    """

    def inner_decorator(func: F) -> Type[BaseStep]:
        """Inner decorator function for the creation of a ZenML Step.

        Args:
            func: types.FunctionType, this function will be used as the
                "process" method of the generated Step.

        Returns:
            The class of a newly generated ZenML Step.
        """
        return type(  # noqa
            func.__name__,
            (BaseStep,),
            {
                STEP_INNER_FUNC_NAME: staticmethod(func),
                INSTANCE_CONFIGURATION: {
                    PARAM_STEP_NAME: name,
                    PARAM_CREATED_BY_FUNCTIONAL_API: True,
                    PARAM_ENABLE_CACHE: enable_cache,
                    PARAM_EXPERIMENT_TRACKER: experiment_tracker,
                    PARAM_STEP_OPERATOR: step_operator,
                    PARAM_OUTPUT_ARTIFACTS: output_artifacts,
                    PARAM_OUTPUT_MATERIALIZERS: output_materializers,
                    PARAM_SETTINGS: settings,
                    PARAM_EXTRA_OPTIONS: extra,
                },
                "__module__": func.__module__,
                "__doc__": func.__doc__,
            },
        )

    if _func is None:
        return inner_decorator
    else:
        return inner_decorator(_func)
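The decorator above supports both `@step` and `@step(...)` by checking whether `_func` was passed, and it builds the step class dynamically with `type()`. A minimal sketch of the same two-level pattern follows, using a hypothetical `StepBase` class in place of BaseStep and a single `name` option. Unlike the real decorator, which stores the `name` argument as given, this sketch fills in the function name as a fallback immediately.

```python
from typing import Any, Callable, Dict, Optional, Type, Union


class StepBase:
    """Hypothetical stand-in for BaseStep."""

    entrypoint: Callable[..., Any]
    CONFIG: Dict[str, Any] = {}

    def run(self, *args: Any, **kwargs: Any) -> Any:
        return self.entrypoint(*args, **kwargs)


def simple_step(
    _func: Optional[Callable[..., Any]] = None,
    *,
    name: Optional[str] = None,
) -> Union[Type[StepBase], Callable[[Callable[..., Any]], Type[StepBase]]]:
    def inner(func: Callable[..., Any]) -> Type[StepBase]:
        # Create a StepBase subclass named after the function; storing the
        # function as a staticmethod means it is called without `self`.
        return type(
            func.__name__,
            (StepBase,),
            {
                "entrypoint": staticmethod(func),
                "CONFIG": {"name": name or func.__name__},
                "__module__": func.__module__,
                "__doc__": func.__doc__,
            },
        )

    # `@simple_step` passes the function directly; `@simple_step(...)` does not.
    if _func is None:
        return inner
    return inner(_func)


@simple_step
def add_one(x: int) -> int:
    return x + 1


@simple_step(name="doubler")
def double(x: int) -> int:
    return 2 * x
```

After decoration, `add_one` and `double` are classes rather than functions, which is why the step has to be instantiated before it can run.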

step_environment

Step environment class.

StepEnvironment (BaseEnvironmentComponent)

Additional information about the step runtime, available inside a step function.

StepEnvironment is an Environment component. Use it from within a pipeline step implementation to access runtime information about the step, such as the pipeline name, the pipeline run ID and the step name. To use it, access it inside your step function like this:

from zenml.environment import Environment
from zenml.steps import step

@step
def my_step() -> None:
    env = Environment().step_environment
    print(env.pipeline_name, env.pipeline_run_id, env.step_name)
Source code in zenml/steps/step_environment.py
class StepEnvironment(BaseEnvironmentComponent):
    """Added information about a step runtime inside a step function.

    This takes the form of an Environment component. This class can be used from
    within a pipeline step implementation to access additional information about
    the runtime parameters of a pipeline step, such as the pipeline name,
    pipeline run ID and other pipeline runtime information. To use it, access it
    inside your step function like this:

    ```python
    from zenml.environment import Environment

    @step
    def my_step(...)
        env = Environment().step_environment
        do_something_with(env.pipeline_name, env.pipeline_run_id, env.step_name)
    ```
    """

    NAME = STEP_ENVIRONMENT_NAME

    def __init__(
        self,
        pipeline_name: str,
        pipeline_run_id: str,
        step_name: str,
        cache_enabled: bool,
        step_run_info: "StepRunInfo",
    ):
        """Initialize the environment of the currently running step.

        Args:
            pipeline_name: the name of the currently running pipeline
            pipeline_run_id: the ID of the currently running pipeline
            step_name: the name of the currently running step
            cache_enabled: whether cache is enabled for this step
            step_run_info: Info about the currently running step.
        """
        super().__init__()
        self._pipeline_name = pipeline_name
        self._pipeline_run_id = pipeline_run_id
        self._step_name = step_name
        self._step_run_info = step_run_info
        self._cache_enabled = cache_enabled

    @property
    def pipeline_name(self) -> str:
        """The name of the currently running pipeline.

        Returns:
            The name of the currently running pipeline.
        """
        return self._pipeline_name

    @property
    def pipeline_run_id(self) -> str:
        """The ID of the current pipeline run.

        Returns:
            The ID of the current pipeline run.
        """
        return self._pipeline_run_id

    @property
    def step_name(self) -> str:
        """The name of the currently running step.

        Returns:
            The name of the currently running step.
        """
        return self._step_name

    @property
    def step_run_info(self) -> "StepRunInfo":
        """Info about the currently running step.

        Returns:
            Info about the currently running step.
        """
        return self._step_run_info

    @property
    def cache_enabled(self) -> bool:
        """Returns whether cache is enabled for the step.

        Returns:
            True if cache is enabled for the step, otherwise False.
        """
        return self._cache_enabled
cache_enabled: bool property readonly

Returns whether cache is enabled for the step.

Returns:

bool: True if cache is enabled for the step, otherwise False.

pipeline_name: str property readonly

The name of the currently running pipeline.

Returns:

str: The name of the currently running pipeline.

pipeline_run_id: str property readonly

The ID of the current pipeline run.

Returns:

str: The ID of the current pipeline run.

step_name: str property readonly

The name of the currently running step.

Returns:

str: The name of the currently running step.

step_run_info: StepRunInfo property readonly

Info about the currently running step.

Returns:

StepRunInfo: Info about the currently running step.

__init__(self, pipeline_name, pipeline_run_id, step_name, cache_enabled, step_run_info) special

Initialize the environment of the currently running step.

Parameters:

pipeline_name (str, required): The name of the currently running pipeline.
pipeline_run_id (str, required): The ID of the current pipeline run.
step_name (str, required): The name of the currently running step.
cache_enabled (bool, required): Whether caching is enabled for this step.
step_run_info (StepRunInfo, required): Info about the currently running step.

step_interfaces special

Initialization for step interfaces.

base_alerter_step

Base alerter step.

BaseAlerterStep (BaseStep)

Send a message to the configured chat service.

Source code in zenml/steps/step_interfaces/base_alerter_step.py
class BaseAlerterStep(BaseStep):
    """Send a message to the configured chat service."""

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        message: str,
        params: BaseAlerterStepParameters,
        context: StepContext,
    ) -> bool:
        """Entrypoint for an Alerter step.

        Args:
            message: The message to send.
            params: The parameters for the step.
            context: The context for the step.

        Returns:
            True if the message was sent successfully.
        """
PARAMETERS_CLASS (BaseParameters) pydantic-model

Step parameters definition for all alerters.

Source code in zenml/steps/step_interfaces/base_alerter_step.py
class BaseAlerterStepParameters(BaseParameters):
    """Step parameters definition for all alerters."""
entrypoint(self, message, params, context)

Entrypoint for an Alerter step.

Parameters:

message (str, required): The message to send.
params (BaseAlerterStepParameters, required): The parameters for the step.
context (StepContext, required): The context for the step.

Returns:

bool: True if the message was sent successfully.
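Concrete alerters are expected to subclass BaseAlerterStep and implement its abstract `entrypoint`. The following stdlib-only sketch shows that pattern with `abc.ABC`; `AlerterStepSketch` and `InMemoryAlerter` are hypothetical, and the `params` and `context` arguments of the real interface are omitted to keep the example self-contained.

```python
from abc import ABC, abstractmethod
from typing import List


class AlerterStepSketch(ABC):
    """Hypothetical stand-in for BaseAlerterStep."""

    @abstractmethod
    def entrypoint(self, message: str) -> bool:
        """Send `message` and return True on success."""


class InMemoryAlerter(AlerterStepSketch):
    """Concrete alerter that 'sends' messages by appending them to a list."""

    def __init__(self) -> None:
        self.sent: List[str] = []

    def entrypoint(self, message: str) -> bool:
        self.sent.append(message)
        return True
```

Instantiating `AlerterStepSketch` directly raises `TypeError`, because the abstract `entrypoint` has not been overridden; only subclasses that implement it can be created.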

BaseAlerterStepParameters (BaseParameters) pydantic-model

Step parameters definition for all alerters.

Source code in zenml/steps/step_interfaces/base_alerter_step.py
class BaseAlerterStepParameters(BaseParameters):
    """Step parameters definition for all alerters."""

base_analyzer_step

Base analyzer step.

BaseAnalyzerParameters (BaseParameters) pydantic-model

Base class for analyzer step parameters.

Source code in zenml/steps/step_interfaces/base_analyzer_step.py
class BaseAnalyzerParameters(BaseParameters):
    """Base class for analyzer step parameters."""
BaseAnalyzerStep (BaseStep)

Base step implementation for any analyzer step implementation.

Source code in zenml/steps/step_interfaces/base_analyzer_step.py
class BaseAnalyzerStep(BaseStep):
    """Base step implementation for any analyzer step implementation."""

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        dataset: DataArtifact,
        params: BaseAnalyzerParameters,
        context: StepContext,
    ) -> Output(  # type:ignore[valid-type]
        statistics=StatisticsArtifact, schema=SchemaArtifact
    ):
        """Base entrypoint for any analyzer implementation.

        Args:
            dataset: The dataset to analyze.
            params: The parameters for the step.
            context: The context for the step.

        Returns:
            The statistics and the schema of the given dataset.
        """
PARAMETERS_CLASS (BaseParameters) pydantic-model

Base class for analyzer step parameters.

Source code in zenml/steps/step_interfaces/base_analyzer_step.py
class BaseAnalyzerParameters(BaseParameters):
    """Base class for analyzer step parameters."""
entrypoint(self, dataset, params, context)

Base entrypoint for any analyzer implementation.

Parameters:

dataset (DataArtifact, required): The dataset to analyze.
params (BaseAnalyzerParameters, required): The parameters for the step.
context (StepContext, required): The context for the step.

Returns:

Output(statistics=StatisticsArtifact, schema=SchemaArtifact): The statistics and the schema of the given dataset.
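The analyzer interface returns two named outputs, `statistics` and `schema`. As a rough illustration of that contract, the sketch below computes simple per-column statistics and a type schema for a list of row dicts and returns them in a NamedTuple. The names `AnalyzerOutput` and `analyze` are hypothetical; real analyzers receive a DataArtifact and produce StatisticsArtifact and SchemaArtifact outputs. This sketch also infers each column's type from its first value only.

```python
from typing import Any, Dict, List, NamedTuple


class AnalyzerOutput(NamedTuple):
    """Two named outputs, mirroring Output(statistics=..., schema=...)."""

    statistics: Dict[str, Dict[str, float]]
    schema: Dict[str, str]


def analyze(rows: List[Dict[str, Any]]) -> AnalyzerOutput:
    """Compute min/max/mean for numeric columns and a simple type schema."""
    schema: Dict[str, str] = {}
    statistics: Dict[str, Dict[str, float]] = {}
    columns = list(rows[0].keys()) if rows else []
    for column in columns:
        values = [row[column] for row in rows]
        # Simplification: the column type is taken from the first value.
        schema[column] = type(values[0]).__name__
        # Only numeric (non-bool) columns get statistics.
        if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
            statistics[column] = {
                "min": float(min(values)),
                "max": float(max(values)),
                "mean": sum(values) / len(values),
            }
    return AnalyzerOutput(statistics=statistics, schema=schema)
```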


base_datasource_step

Base datasource step.

BaseDatasourceParameters (BaseParameters) pydantic-model

Base class for datasource parameters to inherit from.

Source code in zenml/steps/step_interfaces/base_datasource_step.py
class BaseDatasourceParameters(BaseParameters):
    """Base class for datasource parameters to inherit from."""
BaseDatasourceStep (BaseStep)

Base step implementation for any datasource step implementation.

Source code in zenml/steps/step_interfaces/base_datasource_step.py
class BaseDatasourceStep(BaseStep):
    """Base step implementation for any datasource step implementation."""

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        params: BaseDatasourceParameters,
        context: StepContext,
    ) -> DataArtifact:
        """Base entrypoint for any datasource implementation.

        Args:
            params: The parameters for the step.
            context: The context for the step.

        Returns:
            The dataset.
        """
PARAMETERS_CLASS (BaseParameters) pydantic-model

Base class for datasource parameters to inherit from.

Source code in zenml/steps/step_interfaces/base_datasource_step.py
class BaseDatasourceParameters(BaseParameters):
    """Base class for datasource parameters to inherit from."""
entrypoint(self, params, context)

Base entrypoint for any datasource implementation.

Parameters:

Name     Type                      Description                   Default
params   BaseDatasourceParameters  The parameters for the step.  required
context  StepContext               The context for the step.     required

Returns:

Type          Description
DataArtifact  The dataset.

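Because entrypoint is abstract, a concrete datasource only has to override that one method. The sketch below reproduces the pattern with Python's abc module so that it runs without ZenML. DatasourceStepSketch, DatasourceParameters (a dataclass rather than a pydantic model) and RangeDatasource are hypothetical stand-ins, and the StepContext argument is omitted.

```python
import abc
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class DatasourceParameters:
    """Hypothetical stand-in for a BaseDatasourceParameters subclass."""

    n_rows: int = 3


class DatasourceStepSketch(abc.ABC):
    """Hypothetical stand-in for BaseDatasourceStep."""

    @abc.abstractmethod
    def entrypoint(self, params: DatasourceParameters) -> List[Dict[str, int]]:
        """Returns the dataset."""


class RangeDatasource(DatasourceStepSketch):
    """Concrete datasource producing a tiny synthetic dataset."""

    def entrypoint(self, params: DatasourceParameters) -> List[Dict[str, int]]:
        return [{"x": i, "y": 2 * i} for i in range(params.n_rows)]


print(RangeDatasource().entrypoint(DatasourceParameters(n_rows=2)))
# [{'x': 0, 'y': 0}, {'x': 1, 'y': 2}]
```

In this sketch, instantiating DatasourceStepSketch directly raises TypeError because its entrypoint is still abstract.
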
base_drift_detection_step

Base drift detection step.

BaseDriftDetectionParameters (BaseParameters) pydantic-model

Base class for drift detection step parameters.

Source code in zenml/steps/step_interfaces/base_drift_detection_step.py
class BaseDriftDetectionParameters(BaseParameters):
    """Base class for drift detection step parameters."""

BaseDriftDetectionStep (BaseStep)

Base step implementation for any drift detection step implementation.

Source code in zenml/steps/step_interfaces/base_drift_detection_step.py
class BaseDriftDetectionStep(BaseStep):
    """Base step implementation for any drift detection step implementation."""

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: DataArtifact,
        comparison_dataset: DataArtifact,
        params: BaseDriftDetectionParameters,
        context: StepContext,
    ) -> Any:
        """Base entrypoint for any drift detection implementation.

        Args:
            reference_dataset: The reference dataset.
            comparison_dataset: The comparison dataset.
            params: The parameters for the step.
            context: The context for the step.

        Returns:
            The result of the drift detection.
        """

PARAMETERS_CLASS (BaseParameters) pydantic-model

Base class for drift detection step parameters.

entrypoint(self, reference_dataset, comparison_dataset, params, context)

Base entrypoint for any drift detection implementation.

Parameters:

Name                Type                          Description                   Default
reference_dataset   DataArtifact                  The reference dataset.        required
comparison_dataset  DataArtifact                  The comparison dataset.       required
params              BaseDriftDetectionParameters  The parameters for the step.  required
context             StepContext                   The context for the step.     required

Returns:

Type  Description
Any   The result of the drift detection.

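The following sketch illustrates a concrete subclass. It flags drift when the relative change in one column's mean exceeds a threshold. This mean-shift heuristic is a deliberately simple stand-in chosen for the example, not a technique ZenML prescribes. DriftDetectionStepSketch, DriftParameters and MeanShiftDriftDetector are hypothetical. Rows are plain dictionaries instead of DataArtifact, and the StepContext argument is omitted.

```python
import abc
from dataclasses import dataclass
from statistics import mean
from typing import Any, Dict, List

Rows = List[Dict[str, float]]


@dataclass
class DriftParameters:
    """Hypothetical stand-in for a BaseDriftDetectionParameters subclass."""

    column: str
    threshold: float = 0.1


class DriftDetectionStepSketch(abc.ABC):
    """Hypothetical stand-in for BaseDriftDetectionStep."""

    @abc.abstractmethod
    def entrypoint(
        self, reference_dataset: Rows, comparison_dataset: Rows,
        params: DriftParameters,
    ) -> Any:
        """Returns the result of the drift detection."""


class MeanShiftDriftDetector(DriftDetectionStepSketch):
    """Flags drift when a column's mean moves by more than `threshold`
    relative to the reference mean (a simple illustrative heuristic)."""

    def entrypoint(
        self, reference_dataset: Rows, comparison_dataset: Rows,
        params: DriftParameters,
    ) -> Dict[str, Any]:
        ref_mean = mean(row[params.column] for row in reference_dataset)
        cmp_mean = mean(row[params.column] for row in comparison_dataset)
        # Relative shift; falls back to the absolute shift if ref_mean is 0.
        shift = abs(cmp_mean - ref_mean) / (abs(ref_mean) or 1.0)
        return {"shift": shift, "drift": shift > params.threshold}


reference = [{"x": 1.0}, {"x": 3.0}]   # mean 2.0
comparison = [{"x": 2.0}, {"x": 4.0}]  # mean 3.0
detector = MeanShiftDriftDetector()
print(detector.entrypoint(reference, comparison, DriftParameters(column="x")))
# {'shift': 0.5, 'drift': True}
```
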
step_output

Step output class.

Output

A named tuple with a fixed default name (ZenOutput) that cannot be overridden.

Source code in zenml/steps/step_output.py
class Output(object):
    """A named tuple with a default name that cannot be overridden."""

    def __init__(self, **kwargs: Type[Any]):
        """Initializes the output.

        Args:
            **kwargs: The output values.
        """
        # TODO [ENG-161]: do we even need the named tuple here or is
        #  a list of tuples (name, Type) sufficient?
        self.outputs = NamedTuple("ZenOutput", **kwargs)  # type: ignore[misc]

    def items(self) -> Iterator[Tuple[str, Type[Any]]]:
        """Yields a tuple of type (output_name, output_type).

        Yields:
            A tuple of type (output_name, output_type).
        """
        yield from self.outputs.__annotations__.items()
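
The Output class only records output names and their types. The self-contained re-creation below shows what items() yields. OutputSketch is a hypothetical name, not ZenML's class. One deliberate difference: it passes the fields to typing.NamedTuple as a list of (name, type) pairs, because the keyword-argument form used above is deprecated as of Python 3.13.

```python
from typing import Any, Iterator, NamedTuple, Tuple, Type


class OutputSketch:
    """Hypothetical re-creation of the Output class above (not ZenML code)."""

    def __init__(self, **kwargs: Type[Any]):
        # Build a NamedTuple class whose field annotations are the output
        # types; fields are passed as (name, type) pairs, preserving order.
        self.outputs = NamedTuple("ZenOutput", list(kwargs.items()))  # type: ignore[misc]

    def items(self) -> Iterator[Tuple[str, Type[Any]]]:
        """Yields (output_name, output_type) pairs in declaration order."""
        yield from self.outputs.__annotations__.items()


out = OutputSketch(model=dict, accuracy=float)
print(list(out.items()))  # [('model', <class 'dict'>), ('accuracy', <class 'float'>)]
print(out.outputs.__name__)  # ZenOutput
```
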

__init__(self, **kwargs) special

Initializes the output.

Parameters:

Name      Type       Description                         Default
**kwargs  Type[Any]  The output names and their types.   {}

items(self)

Yields a tuple of type (output_name, output_type).

Yields:

Type                             Description
Iterator[Tuple[str, Type[Any]]]  A tuple of type (output_name, output_type).

utils

Utility functions for Steps.

The utility functions and classes in this module are inspired by the original implementation by the TensorFlow Extended (TFX) team, which can be found here:

https://github.com/tensorflow/tfx/blob/master/tfx/dsl/component/experimental/decorators.py

This version has been heavily adapted to work with the pipeline-step paradigm used by ZenML.

create_component_class(step)

Creates a TFX component class.

Parameters:

Name  Type      Description                                        Default
step  BaseStep  The step for which to create the component class.  required

Returns:

Type                    Description
Type[_SimpleComponent]  The component class.

Source code in zenml/steps/utils.py
def create_component_class(step: "BaseStep") -> Type["_SimpleComponent"]:
    """Creates a TFX component class.

    Args:
        step: The step for which to create the component class.

    Returns:
        The component class.
    """
    from tfx.dsl.component.experimental.decorators import _SimpleComponent

    executor_class = create_executor_class(step=step)
    component_spec_class = _create_component_spec_class(step=step)

    module, _ = source_utils.resolve_class(step.__class__).rsplit(".", 1)

    return type(
        step.configuration.name,
        (_SimpleComponent,),
        {
            "SPEC_CLASS": component_spec_class,
            "EXECUTOR_SPEC": ExecutorClassSpec(executor_class=executor_class),
            "__module__": module,
        },
    )
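
create_component_class builds the component class at runtime with the three-argument form of type(). It names the class after the step and sets __module__ to the module part of the step's resolved class path. The sketch below isolates that technique without TFX. ComponentBase, SpecStandIn and build_component_class are hypothetical stand-ins for _SimpleComponent, the generated spec class and the function itself. The dotted path "my_project.steps.TrainerStep" is a hypothetical stand-in for the output of source_utils.resolve_class.

```python
class ComponentBase:
    """Hypothetical stand-in for TFX's _SimpleComponent."""

    SPEC_CLASS = None
    EXECUTOR_SPEC = None


class SpecStandIn:
    """Hypothetical stand-in for the generated component spec class."""


def build_component_class(step_name: str, class_path: str) -> type:
    """Creates a ComponentBase subclass named after the step."""
    # Keep only the module part of the dotted path, e.g. "my_project.steps".
    module, _ = class_path.rsplit(".", 1)
    return type(
        step_name,
        (ComponentBase,),
        {
            "SPEC_CLASS": SpecStandIn,
            "EXECUTOR_SPEC": "executor-spec-placeholder",
            "__module__": module,
        },
    )


Component = build_component_class("trainer", "my_project.steps.TrainerStep")
print(Component.__name__, Component.__module__)  # trainer my_project.steps
print(issubclass(Component, ComponentBase))       # True
```
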

create_executor_class(step)

Creates an executor class for a step.

Parameters:

Name  Type      Description                                               Default
step  BaseStep  The step instance for which to create an executor class.  required

Returns:

Type                      Description
Type[_ZenMLStepExecutor]  The executor class.

Source code in zenml/steps/utils.py
def create_executor_class(
    step: "BaseStep",
) -> Type["_ZenMLStepExecutor"]:
    """Creates an executor class for a step.

    Args:
        step: The step instance for which to create an executor class.

    Returns:
        The executor class.
    """
    executor_class_name = _get_executor_class_name(step.configuration.name)
    executor_class = type(
        executor_class_name,
        (_ZenMLStepExecutor,),
        {"_STEP": step, "__module__": __name__},
    )

    # Add the executor class to the current module, so tfx can load it
    module = sys.modules[__name__]
    setattr(module, executor_class_name, executor_class)

    return executor_class
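
create_executor_class relies on two things. First, each generated subclass stores its step in the _STEP class attribute. Second, the class is attached to a module with setattr so that, per the comment in the code, TFX can load it by name. The sketch below imitates both. ZenML registers into sys.modules[__name__]; this sketch instead uses a throwaway module so it can run anywhere. ExecutorBase, executor_class_name and create_executor are hypothetical names. The real _get_executor_class_name is not shown here and may use a different naming scheme.

```python
import importlib
import sys
import types

# Throwaway module acting as the registry (ZenML uses the utils module).
REGISTRY_NAME = "executor_registry_sketch"
registry = types.ModuleType(REGISTRY_NAME)
sys.modules[REGISTRY_NAME] = registry


class ExecutorBase:
    """Hypothetical stand-in for _ZenMLStepExecutor."""

    _STEP = None

    def describe(self) -> str:
        return f"executing {self._STEP}"


def executor_class_name(step_name: str) -> str:
    """Hypothetical naming scheme; ZenML's real helper may differ."""
    return f"{step_name}_Executor"


def create_executor(step_name: str) -> type:
    """Creates an executor subclass bound to one step and registers it."""
    name = executor_class_name(step_name)
    cls = type(
        name, (ExecutorBase,), {"_STEP": step_name, "__module__": REGISTRY_NAME}
    )
    # Attach the class to the registry module so it can later be resolved
    # from the module name plus the class name.
    setattr(registry, name, cls)
    return cls


create_executor("loader")
create_executor("trainer")
resolved = getattr(importlib.import_module(REGISTRY_NAME), "trainer_Executor")
print(resolved().describe())  # executing trainer
```
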

get_executor_class(step_name)

Gets the executor class for a step.

Parameters:

Name       Type  Description                                            Default
step_name  str   Name of the step for which to get the executor class.  required

Returns:

Type                                Description
Optional[Type[_ZenMLStepExecutor]]  The executor class, or None if none has been registered for the step.

Source code in zenml/steps/utils.py
def get_executor_class(step_name: str) -> Optional[Type["_ZenMLStepExecutor"]]:
    """Gets the executor class for a step.

    Args:
        step_name: Name of the step for which to get the executor class.

    Returns:
        The executor class.
    """
    executor_class_name = _get_executor_class_name(step_name)
    module = sys.modules[__name__]
    return getattr(module, executor_class_name, None)
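
The lookup side uses the three-argument form of getattr, which returns the default instead of raising AttributeError. This is why get_executor_class is typed as returning an Optional. The minimal sketch below uses a hypothetical registry module and naming helper.

```python
import types
from typing import Optional


def executor_class_name(step_name: str) -> str:
    """Hypothetical naming scheme; ZenML's real helper may differ."""
    return f"{step_name}_Executor"


# Hypothetical registry module with a single executor class attached.
registry = types.ModuleType("executor_lookup_sketch")
setattr(registry, executor_class_name("trainer"), type("trainer_Executor", (), {}))


def get_executor(step_name: str) -> Optional[type]:
    """Returns the registered executor class, or None if none exists."""
    # getattr's third argument is returned instead of raising AttributeError.
    return getattr(registry, executor_class_name(step_name), None)


print(get_executor("trainer").__name__)  # trainer_Executor
print(get_executor("unknown"))           # None
```
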

parse_return_type_annotations(step_annotations)

Parses the return annotation of a step function into a dict of resolved types.

This function is called within BaseStepMeta.__new__() to define cls.OUTPUT_SIGNATURE, and within Do() to resolve type annotations.

Parameters:

Name              Type            Description                             Default
step_annotations  Dict[str, Any]  Type annotations of the step function.  required

Returns:

Type            Description
Dict[str, Any]  Output signature of the new step class.

Source code in zenml/steps/utils.py
def parse_return_type_annotations(
    step_annotations: Dict[str, Any]
) -> Dict[str, Any]:
    """Parse the returns of a step function into a dict of resolved types.

    Called within `BaseStepMeta.__new__()` to define `cls.OUTPUT_SIGNATURE`.
    Called within `Do()` to resolve type annotations.

    Args:
        step_annotations: Type annotations of the step function.

    Returns:
        Output signature of the new step class.
    """
    return_type = step_annotations.get("return", None)
    if return_type is None:
        return {}

    # Cast simple output types to `Output`.
    if not isinstance(return_type, Output):
        return_type = Output(**{SINGLE_RETURN_OUT_NAME: return_type})

    # Resolve type annotations of all outputs and save in new dict.
    output_signature = {
        output_name: resolve_type_annotation(output_type)
        for output_name, output_type in return_type.items()
    }
    return output_signature
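
The control flow of parse_return_type_annotations can be exercised without ZenML. In the sketch below, OutputStandIn mimics Output and parse_returns mirrors the function above. Type resolution is reduced to the Python 3.8+ branch of resolve_type_annotation. The value "output" for SINGLE_RETURN_OUT_NAME is an assumption made for the example.

```python
import typing
from typing import Any, Dict, List, NamedTuple

SINGLE_RETURN_OUT_NAME = "output"  # assumed value, for illustration only


class OutputStandIn:
    """Hypothetical stand-in for zenml.steps.Output."""

    def __init__(self, **kwargs: Any):
        self.outputs = NamedTuple("ZenOutput", list(kwargs.items()))  # type: ignore[misc]

    def items(self):
        yield from self.outputs.__annotations__.items()


def parse_returns(step_annotations: Dict[str, Any]) -> Dict[str, Any]:
    """Mirrors parse_return_type_annotations using the stand-ins above."""
    return_type = step_annotations.get("return", None)
    if return_type is None:
        return {}
    # A single, unnamed return type is wrapped under the default output name.
    if not isinstance(return_type, OutputStandIn):
        return_type = OutputStandIn(**{SINGLE_RETURN_OUT_NAME: return_type})
    # Resolve generics such as List[int] to their runtime class (Python 3.8+).
    return {
        name: typing.get_origin(tp) or tp for name, tp in return_type.items()
    }


def single_step() -> List[int]: ...
def multi_step() -> OutputStandIn(model=Dict[str, float], accuracy=float): ...
def no_return_step(): ...


print(parse_returns(single_step.__annotations__))     # {'output': <class 'list'>}
print(parse_returns(multi_step.__annotations__))      # {'model': <class 'dict'>, 'accuracy': <class 'float'>}
print(parse_returns(no_return_step.__annotations__))  # {}
```
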

resolve_type_annotation(obj)

Returns the non-generic class for generic aliases of the typing module.

If the input is not a generic typing alias, it is returned unchanged.

For example, if the input object is typing.Dict, this method returns the concrete class dict.

Parameters:

Name  Type  Description             Default
obj   Any   The object to resolve.  required

Returns:

Type  Description
Any   The non-generic class for generic aliases of the typing module.

Source code in zenml/steps/utils.py
def resolve_type_annotation(obj: Any) -> Any:
    """Returns the non-generic class for generic aliases of the typing module.

    If the input is not a generic typing alias, the input itself is returned.

    Example: if the input object is `typing.Dict`, this method will return the
    concrete class `dict`.

    Args:
        obj: The object to resolve.

    Returns:
        The non-generic class for generic aliases of the typing module.
    """
    from typing import _GenericAlias  # type: ignore[attr-defined]

    if sys.version_info >= (3, 8):
        return typing.get_origin(obj) or obj
    else:
        # python 3.7
        if isinstance(obj, _GenericAlias):
            return obj.__origin__
        else:
            return obj
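
On Python 3.8 and later, the function reduces to typing.get_origin(obj) or obj. The following sketch shows its effect on a few annotations; resolve is a hypothetical name. Type arguments are discarded, so List[int] and List[str] both resolve to list.

```python
import typing
from typing import Any, Dict, List, Tuple


def resolve(obj: Any) -> Any:
    """Python 3.8+ branch of resolve_type_annotation shown above."""
    # get_origin returns None for plain classes, so fall back to the input.
    return typing.get_origin(obj) or obj


print(resolve(Dict))             # <class 'dict'>
print(resolve(List[int]))        # <class 'list'>
print(resolve(Tuple[int, str]))  # <class 'tuple'>
print(resolve(int))              # <class 'int'>
```
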