Steps
zenml.steps
special
Initializer for ZenML steps.
A step is a single piece or stage of a ZenML pipeline. Think of each step as being one of the nodes of a Directed Acyclic Graph (or DAG). Steps are responsible for one aspect of processing or interacting with the data / artifacts in the pipeline.
Conceptually, a Step is a discrete and independent part of a pipeline that is responsible for one particular aspect of data manipulation inside a ZenML pipeline.
Steps can be subclassed from the `BaseStep` class, or used via our `@step` decorator.
base_parameters
Step parameters.
BaseParameters (BaseModel)
pydantic-model
Base class to pass parameters into a step.
Source code in zenml/steps/base_parameters.py
class BaseParameters(BaseModel):
    """Base class for the parameter objects that are passed into a step.

    The class itself declares no fields; it only derives from `BaseModel`.
    """
base_step
Base Step for ZenML.
BaseStep
Abstract base class for all ZenML steps.
Attributes:
Name | Type | Description |
---|---|---|
name |
The name of this step. |
|
pipeline_parameter_name |
The name of the pipeline parameter for which this step was passed as an argument. |
|
enable_cache |
A boolean indicating if caching is enabled for this step. |
|
enable_artifact_metadata |
A boolean indicating if artifact metadata is enabled for this step. |
|
enable_artifact_visualization |
A boolean indicating if artifact visualization is enabled for this step. |
Source code in zenml/steps/base_step.py
class BaseStep(metaclass=BaseStepMeta):
"""Abstract base class for all ZenML steps.
Attributes:
name: The name of this step.
pipeline_parameter_name: The name of the pipeline parameter for which
this step was passed as an argument.
enable_cache: A boolean indicating if caching is enabled for this step.
enable_artifact_metadata: A boolean indicating if artifact metadata
is enabled for this step.
enable_artifact_visualization: A boolean indicating if artifact
visualization is enabled for this step.
"""
INPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None # type: ignore[assignment] # noqa
OUTPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None # type: ignore[assignment] # noqa
PARAMETERS_FUNCTION_PARAMETER_NAME: ClassVar[Optional[str]] = None
PARAMETERS_CLASS: ClassVar[Optional[Type["BaseParameters"]]] = None
CONTEXT_PARAMETER_NAME: ClassVar[Optional[str]] = None
INSTANCE_CONFIGURATION: Dict[str, Any] = {}
class _OutputArtifact(NamedTuple):
    """Internal description of a single step output.

    Instances of this tuple are what `BaseStep.__call__` accepts as inputs
    and hands back as outputs. They carry the information about a step
    output that downstream steps need to finalize their own configuration.

    Attributes:
        name: Name of the output.
        step_name: Name of the step that produced this output.
        materializer_source: The source of the materializer used to
            write the output.
    """

    name: str
    step_name: str
    materializer_source: Source
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initializes a step.

    The class-level `INSTANCE_CONFIGURATION` is merged with the keyword
    arguments (explicitly passed keyword arguments take precedence). The
    step name and the cache / artifact metadata / artifact visualization
    flags are popped from the merged options to build the initial step
    configuration. The remaining options are passed on to
    `_apply_class_configuration` and `_verify_and_apply_init_params`.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.
    """

    def _state(flag: Optional[bool]) -> str:
        # Only an explicit `False` is reported as disabled; an unset
        # (`None`) flag is logged as enabled.
        return "disabled" if flag is False else "enabled"

    self._has_been_called = False
    self._upstream_steps: Set[str] = set()
    self._inputs: Dict[str, InputSpec] = {}

    options = {**self.INSTANCE_CONFIGURATION, **kwargs}
    step_name = options.pop(PARAM_STEP_NAME, None) or self.__class__.__name__

    # This value is only used in `BaseStep._created_by_functional_api()`
    options.pop(PARAM_CREATED_BY_FUNCTIONAL_API, None)

    needs_context = bool(self.CONTEXT_PARAMETER_NAME)
    cache_flag = options.pop(PARAM_ENABLE_CACHE, None)
    if cache_flag is None and needs_context:
        # Using the StepContext inside a step provides access to
        # external resources which might influence the step execution.
        # We therefore disable caching unless it is explicitly enabled
        cache_flag = False
        logger.debug(
            "Step '%s': Step context required and caching not "
            "explicitly enabled.",
            step_name,
        )
    logger.debug("Step '%s': Caching %s.", step_name, _state(cache_flag))

    metadata_flag = options.pop(PARAM_ENABLE_ARTIFACT_METADATA, None)
    logger.debug(
        "Step '%s': Artifact metadata %s.", step_name, _state(metadata_flag)
    )

    visualization_flag = options.pop(
        PARAM_ENABLE_ARTIFACT_VISUALIZATION, None
    )
    logger.debug(
        "Step '%s': Artifact visualization %s.",
        step_name,
        _state(visualization_flag),
    )

    self._configuration = PartialStepConfiguration(
        name=step_name,
        enable_cache=cache_flag,
        enable_artifact_metadata=metadata_flag,
        enable_artifact_visualization=visualization_flag,
    )

    # `_apply_class_configuration` pops the options it handles, so only
    # the leftover entries reach the init parameter verification.
    self._apply_class_configuration(options)
    self._verify_and_apply_init_params(*args, **options)
@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Core logic of the step; must be implemented by subclasses.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.

    Returns:
        The output of the step.
    """
@classmethod
def _created_by_functional_api(cls) -> bool:
    """Tells whether this step class was created by the functional API.

    The answer is looked up in the class-level `INSTANCE_CONFIGURATION`
    and defaults to `False` when the marker is absent.

    Returns:
        `True` if the class was created by the functional API,
        `False` otherwise.
    """
    class_options = cls.INSTANCE_CONFIGURATION
    return class_options.get(PARAM_CREATED_BY_FUNCTIONAL_API, False)
@classmethod
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
    """Loads a step class from a source and instantiates it.

    The loaded class is validated to be a `BaseStep` subclass and is
    instantiated without any arguments.

    Args:
        source: The path to the step source.

    Returns:
        The loaded step.
    """
    loaded_class: Type[BaseStep] = source_utils.load_and_validate_class(
        source, expected_class=BaseStep
    )
    step_instance = loaded_class()
    return step_instance
@property
def upstream_steps(self) -> Set[str]:
    """Names of the steps that have to run before this step.

    The set is only complete once the `connect(...)` method of the parent
    pipeline has been called.

    Returns:
        Set of upstream step names.
    """
    upstream_names = self._upstream_steps
    return upstream_names
def after(self, step: "BaseStep") -> None:
    """Registers another step as an upstream step of this step.

    After calling this method, this step only starts running once the given
    step has successfully finished executing.

    **Note**: This can only be called inside the pipeline connect function
    which is decorated with the `@pipeline` decorator. Any calls outside
    this function will be ignored.

    Example:
    The following pipeline will run its steps sequentially in the following
    order: step_2 -> step_1 -> step_3

    ```python
    @pipeline
    def example_pipeline(step_1, step_2, step_3):
        step_1.after(step_2)
        step_3(step_1(), step_2())
    ```

    Args:
        step: A step which should finish executing before this step is
            started.
    """
    upstream_name = step.name
    self._upstream_steps.add(upstream_name)
@property
def inputs(self) -> Dict[str, InputSpec]:
    """Input specifications of this step.

    The specifications depend on the upstream steps in a pipeline, so they
    are only available after the step has been called in a pipeline.

    Raises:
        RuntimeError: If this property is accessed before the step was
            called in a pipeline.

    Returns:
        The step input specifications.
    """
    if self._has_been_called:
        return self._inputs

    raise RuntimeError(
        "Step inputs can only be accessed once a step has been called "
        "inside a pipeline."
    )
@property
def source_object(self) -> Any:
    """The object that defines this step.

    For steps created by the functional API (the `@step` decorator) this
    is the entrypoint function, otherwise it is the custom step class.

    Returns:
        The source object of this step.
    """
    return (
        self.entrypoint
        if self._created_by_functional_api()
        else self.__class__
    )
@property
def source_code(self) -> str:
    """Source code of this step, taken from its source object.

    Returns:
        The source code of this step.
    """
    target = self.source_object
    return inspect.getsource(target)
@property
def docstring(self) -> Optional[str]:
    """Docstring of this step, as exposed by its `__doc__` attribute.

    Returns:
        The docstring of this step.
    """
    doc = self.__doc__
    return doc
@property
def caching_parameters(self) -> Dict[str, Any]:
    """Caching parameters for this step.

    The dictionary holds the hashed source code of the step source object
    and, for every configured output with a materializer source, the
    hashed source code of that materializer class.

    Returns:
        A dictionary containing the caching parameters
    """
    hash_source = source_code_utils.get_hashed_source_code
    parameters = {STEP_SOURCE_PARAMETER_NAME: hash_source(self.source_object)}

    for name, output in self.configuration.outputs.items():
        if not output.materializer_source:
            # Outputs without an explicit materializer contribute nothing.
            continue
        materializer_class = source_utils.load(output.materializer_source)
        parameters[f"{name}_materializer_source"] = hash_source(
            materializer_class
        )

    return parameters
def _apply_class_configuration(self, options: Dict[str, Any]) -> None:
    """Applies the configurations specified on the step class.

    Every option handled here is removed from `options` in place and
    forwarded to `configure(...)`. Missing options are passed as `None`,
    except for the settings which fall back to an empty dictionary.

    Args:
        options: Class configurations.
    """
    self.configure(
        experiment_tracker=options.pop(PARAM_EXPERIMENT_TRACKER, None),
        step_operator=options.pop(PARAM_STEP_OPERATOR, None),
        output_materializers=options.pop(PARAM_OUTPUT_MATERIALIZERS, None),
        settings=options.pop(PARAM_SETTINGS, None) or {},
        extra=options.pop(PARAM_EXTRA_OPTIONS, None),
        on_failure=options.pop(PARAM_ON_FAILURE, None),
        on_success=options.pop(PARAM_ON_SUCCESS, None),
    )
def _verify_and_apply_init_params(self, *args: Any, **kwargs: Any) -> None:
    """Verifies the initialization args and kwargs of this step.

    At most one parameters object may be passed at initialization, either
    positionally or under the keyword declared for it in the step. The
    object must be an instance of the step's parameters class; if it is,
    it is applied through `configure(parameters=...)`.

    Args:
        *args: The args passed to the init method of this step.
        **kwargs: The kwargs passed to the init method of this step.

    Raises:
        StepInterfaceError: If there are too many arguments or arguments
            with a wrong name/type.
    """
    maximum_arg_count = 1 if self.PARAMETERS_CLASS else 0
    arg_count = len(args) + len(kwargs)
    if arg_count > maximum_arg_count:
        raise StepInterfaceError(
            f"Too many arguments ({arg_count}, expected: "
            f"{maximum_arg_count}) passed when creating a "
            f"'{self.name}' step."
        )

    if not (self.PARAMETERS_FUNCTION_PARAMETER_NAME and self.PARAMETERS_CLASS):
        # The step does not take a parameters object, nothing to apply.
        return

    if not args and not kwargs:
        # This step requires configuration parameters but no parameters
        # object was passed as an argument. The parameters might be
        # set via default values in the parameters class or in a
        # configuration file, so we continue for now and verify
        # that all parameters are set before running the step
        return

    if args:
        config = args[0]
    else:
        key, config = kwargs.popitem()
        if key != self.PARAMETERS_FUNCTION_PARAMETER_NAME:
            raise StepInterfaceError(
                f"Unknown keyword argument '{key}' when creating a "
                f"'{self.name}' step, only expected a single "
                "argument with key "
                f"'{self.PARAMETERS_FUNCTION_PARAMETER_NAME}'."
            )

    if not isinstance(config, self.PARAMETERS_CLASS):
        raise StepInterfaceError(
            f"`{config}` object passed when creating a "
            f"'{self.name}' step is not a "
            f"`{self.PARAMETERS_CLASS.__name__}` instance."
        )

    self.configure(parameters=config)
def _validate_input_artifacts(
    self, *artifacts: _OutputArtifact, **kw_artifacts: _OutputArtifact
) -> Dict[str, _OutputArtifact]:
    """Verifies and prepares the input artifacts for running this step.

    Positional artifacts are matched to the input names in the order of
    `INPUT_SIGNATURE`, keyword artifacts are matched by name. Every
    artifact must be a `BaseStep._OutputArtifact` and the resulting set of
    names must be exactly the set of names in the input signature.

    Args:
        *artifacts: Positional input artifacts passed to
            the __call__ method.
        **kw_artifacts: Keyword input artifacts passed to
            the __call__ method.

    Returns:
        Dictionary containing both the positional and keyword input
        artifacts.

    Raises:
        StepInterfaceError: If there are too many or too few artifacts.
    """
    input_artifact_keys = list(self.INPUT_SIGNATURE.keys())
    if len(artifacts) > len(input_artifact_keys):
        raise StepInterfaceError(
            f"Too many input artifacts for step '{self.name}'. "
            f"This step expects {len(input_artifact_keys)} artifact(s) "
            f"but got {len(artifacts) + len(kw_artifacts)}."
        )

    combined_artifacts = {}

    # The length check above guarantees that every positional artifact has
    # a matching input name, so `zip` never drops an artifact here.
    for i, (key, artifact) in enumerate(zip(input_artifact_keys, artifacts)):
        if not isinstance(artifact, BaseStep._OutputArtifact):
            raise StepInterfaceError(
                f"Wrong argument type (`{type(artifact)}`) for positional "
                f"argument {i} of step '{self.name}'. Only outputs "
                f"from previous steps can be used as arguments when "
                f"connecting steps."
            )
        combined_artifacts[key] = artifact

    for key, artifact in kw_artifacts.items():
        if key in combined_artifacts:
            # an artifact for this key was already set by
            # the positional input artifacts
            raise StepInterfaceError(
                f"Unexpected keyword argument '{key}' for step "
                f"'{self.name}'. An artifact for this key was "
                f"already passed as a positional argument."
            )
        if not isinstance(artifact, BaseStep._OutputArtifact):
            raise StepInterfaceError(
                f"Wrong argument type (`{type(artifact)}`) for argument "
                f"'{key}' of step '{self.name}'. Only outputs from "
                f"previous steps can be used as arguments when "
                f"connecting steps."
            )
        combined_artifacts[key] = artifact

    # check if there are any missing or unexpected artifacts
    expected_artifacts = set(input_artifact_keys)
    actual_artifacts = set(combined_artifacts)
    missing_artifacts = expected_artifacts - actual_artifacts
    unexpected_artifacts = actual_artifacts - expected_artifacts

    if missing_artifacts:
        raise StepInterfaceError(
            f"Missing input artifact(s) for step "
            f"'{self.name}': {missing_artifacts}."
        )
    if unexpected_artifacts:
        raise StepInterfaceError(
            f"Unexpected input artifact(s) for step "
            f"'{self.name}': {unexpected_artifacts}. This step "
            f"only requires the following artifacts: {expected_artifacts}."
        )

    return combined_artifacts
def __call__(
    self, *artifacts: _OutputArtifact, **kw_artifacts: _OutputArtifact
) -> Union[_OutputArtifact, List[_OutputArtifact]]:
    """Finalizes the step input and output configuration.

    The step is marked as called, its input artifacts are validated and
    recorded (together with the steps that produced them) and the step
    configuration is finalized. One output artifact is created per entry
    of `OUTPUT_SIGNATURE`, in signature order.

    Args:
        *artifacts: Positional input artifacts passed to
            the __call__ method.
        **kw_artifacts: Keyword input artifacts passed to
            the __call__ method.

    Returns:
        A single output artifact or a list of output artifacts.

    Raises:
        StepInterfaceError: If the step has already been called.
    """
    if self._has_been_called:
        raise StepInterfaceError(
            f"Step {self.name} has already been called. A ZenML step "
            f"instance can only be called once per pipeline run."
        )
    self._has_been_called = True

    # Prepare the input artifacts and spec
    input_artifacts = self._validate_input_artifacts(
        *artifacts, **kw_artifacts
    )
    for name, input_ in input_artifacts.items():
        producer = input_.step_name
        self._upstream_steps.add(producer)
        self._inputs[name] = InputSpec(
            step_name=producer, output_name=input_.name
        )

    config = self._finalize_configuration(input_artifacts=input_artifacts)

    # Resolve the returns in the right order.
    returns = [
        BaseStep._OutputArtifact(
            name=key,
            step_name=self.name,
            materializer_source=config.outputs[key].materializer_source,
        )
        for key in self.OUTPUT_SIGNATURE
    ]

    # A single output is returned directly instead of as a one-element list
    return returns[0] if len(returns) == 1 else returns
@property
def name(self) -> str:
    """Name of the step, read from the step configuration.

    Returns:
        The name of the step.
    """
    step_config = self.configuration
    return step_config.name
@property
def enable_cache(self) -> Optional[bool]:
    """Caching flag of the step, read from the step configuration.

    Returns:
        If caching is enabled for the step.
    """
    step_config = self.configuration
    return step_config.enable_cache
@property
def configuration(self) -> PartialStepConfiguration:
    """Current (possibly still partial) configuration of the step.

    Returns:
        The configuration of the step.
    """
    current_config = self._configuration
    return current_config
def configure(
    self: T,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    merge: bool = True,
) -> T:
    """Configures the step.

    Arguments left at `None` are dropped from the update and therefore do
    not touch the corresponding configuration value.

    Configuration merging example:
    * `merge==True`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=True)
        step.configuration.extra # {"key1": 1, "key2": 2}
    * `merge==False`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=False)
        step.configuration.extra # {"key2": 2}

    Args:
        name: The name of the step.
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with three possible parameters, `StepContext`,
            `BaseParameters`, and `BaseException`, or a source path to a
            function of the same specifications
            (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with two possible parameters, `StepContext` and
            `BaseParameters`, or a source path to a function of the same
            specifications (e.g. `module.my_function`).
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False` the given configurations will
            overwrite all existing ones. See the general description of
            this method for an example.

    Returns:
        The step instance that this method was called on.
    """
    from zenml.hooks.hook_validators import resolve_and_validate_hook

    def _as_source(candidate: Union[str, Source, Type[Any]]) -> Source:
        # Strings are parsed as import paths, `Source` objects pass through
        # unchanged and everything else is resolved via `source_utils`.
        if isinstance(candidate, str):
            return Source.from_import_path(candidate)
        if isinstance(candidate, Source):
            return candidate
        return source_utils.resolve(candidate)

    outputs: Dict[str, Dict[str, Source]] = defaultdict(dict)
    allowed_output_names = set(self.OUTPUT_SIGNATURE)

    if output_materializers:
        if not isinstance(output_materializers, Mapping):
            # A single materializer (class or import path) was given and
            # is used for all outputs of this step.
            shared_source = _as_source(output_materializers)
            output_materializers = dict.fromkeys(
                allowed_output_names, shared_source
            )

        for output_name, materializer in output_materializers.items():
            outputs[output_name]["materializer_source"] = _as_source(
                materializer
            )

    # Hooks are only resolved and validated when they were actually given.
    failure_hook_source = (
        resolve_and_validate_hook(on_failure) if on_failure else None
    )
    success_hook_source = (
        resolve_and_validate_hook(on_success) if on_success else None
    )

    if isinstance(parameters, BaseParameters):
        parameters = parameters.dict()

    update_values = dict_utils.remove_none_values(
        {
            "name": name,
            "enable_cache": enable_cache,
            "enable_artifact_metadata": enable_artifact_metadata,
            "enable_artifact_visualization": enable_artifact_visualization,
            "experiment_tracker": experiment_tracker,
            "step_operator": step_operator,
            "parameters": parameters,
            "settings": settings,
            "outputs": outputs or None,
            "extra": extra,
            "failure_hook_source": failure_hook_source,
            "success_hook_source": success_hook_source,
        }
    )
    self._apply_configuration(
        StepConfigurationUpdate(**update_values), merge=merge
    )
    return self
def _apply_configuration(
    self,
    config: StepConfigurationUpdate,
    merge: bool = True,
) -> None:
    """Validates a configuration update and applies it to the step.

    Args:
        config: The configuration update.
        merge: Whether to merge the updates with the existing configuration
            or not. See the `BaseStep.configure(...)` method for a detailed
            explanation.
    """
    self._validate_configuration(config)

    updated_configuration = pydantic_utils.update_model(
        self._configuration, update=config, recursive=merge
    )
    self._configuration = updated_configuration

    logger.debug("Updated step configuration:")
    logger.debug(self._configuration)
def _validate_configuration(self, config: StepConfigurationUpdate) -> None:
    """Validates a configuration update.

    The setting keys, the function parameters and the outputs of the
    update are checked, in that order.

    Args:
        config: The configuration update to validate.
    """
    setting_keys = list(config.settings)
    settings_utils.validate_setting_keys(setting_keys)

    self._validate_function_parameters(parameters=config.parameters)
    self._validate_outputs(outputs=config.outputs)
def _validate_function_parameters(
    self, parameters: Dict[str, Any]
) -> None:
    """Validates step function parameters.

    Empty parameters are always accepted. Non-empty parameters are only
    accepted if the step declares a parameters class.

    Args:
        parameters: The parameters to validate.

    Raises:
        StepInterfaceError: If the step requires no function parameters but
            parameters were configured.
    """
    if parameters and not self.PARAMETERS_CLASS:
        raise StepInterfaceError(
            f"Function parameters configured for step {self.name} which "
            "does not accept any function parameters."
        )
def _validate_inputs(
    self, inputs: Mapping[str, ArtifactConfiguration]
) -> None:
    """Validates the step input configuration.

    Every configured input name must be part of the step's
    `INPUT_SIGNATURE`.

    Args:
        inputs: The configured step inputs.

    Raises:
        StepInterfaceError: If an input for a non-existent name is
            configured.
    """
    allowed_input_names = set(self.INPUT_SIGNATURE)

    for input_name in inputs.keys():
        if input_name in allowed_input_names:
            continue
        raise StepInterfaceError(
            f"Got unexpected artifact for non-existent "
            f"input '{input_name}' in step '{self.name}'. "
            f"Only artifacts for the inputs "
            f"{allowed_input_names} of this step can"
            f" be registered."
        )
def _validate_outputs(
    self, outputs: Mapping[str, PartialArtifactConfiguration]
) -> None:
    """Validates the step output configuration.

    Every configured output name must be part of the step's
    `OUTPUT_SIGNATURE`, and a configured materializer source must pass the
    `BaseMaterializer` source class validation.

    Args:
        outputs: The configured step outputs.

    Raises:
        StepInterfaceError: If an output for a non-existent name is
            configured or an output artifact/materializer source does not
            resolve to the correct class.
    """
    allowed_output_names = set(self.OUTPUT_SIGNATURE)

    for output_name, output in outputs.items():
        if output_name not in allowed_output_names:
            raise StepInterfaceError(
                f"Got unexpected materializers for non-existent "
                f"output '{output_name}' in step '{self.name}'. "
                f"Only materializers for the outputs "
                f"{allowed_output_names} of this step can"
                f" be registered."
            )

        # Outputs without a materializer source need no further checks.
        if output.materializer_source and not (
            source_utils.validate_source_class(
                output.materializer_source, expected_class=BaseMaterializer
            )
        ):
            raise StepInterfaceError(
                f"Materializer source `{output.materializer_source}` "
                f"for output '{output_name}' of step '{self.name}' "
                "does not resolve to a `BaseMaterializer` subclass."
            )
def _finalize_configuration(
    self, input_artifacts: Dict[str, _OutputArtifact]
) -> StepConfiguration:
    """Finalizes the configuration after the step was called.

    Once the step was called, we know the outputs of previous steps
    and that no additional user configurations will be made. That means
    we can now collect the remaining artifact and materializer types
    as well as check for the completeness of the step function parameters.

    Args:
        input_artifacts: The input artifacts of this step.

    Returns:
        The finalized step configuration.
    """
    outputs: Dict[str, Dict[str, Source]] = defaultdict(dict)

    for output_name, output_class in self.OUTPUT_SIGNATURE.items():
        output = self._configuration.outputs.get(
            output_name, PartialArtifactConfiguration()
        )
        if output.materializer_source:
            # A materializer was configured explicitly, keep it.
            continue
        # Otherwise use the materializer registered for the output type.
        materializer_class = materializer_registry[output_class]
        outputs[output_name]["materializer_source"] = source_utils.resolve(
            materializer_class
        )

    function_parameters = self._finalize_function_parameters()
    update_values = dict_utils.remove_none_values(
        {
            "outputs": outputs or None,
            "parameters": function_parameters,
        }
    )
    self._apply_configuration(StepConfigurationUpdate(**update_values))

    # Inputs use the materializer of the artifact that feeds them.
    inputs = {
        input_name: ArtifactConfiguration(
            materializer_source=artifact.materializer_source,
        )
        for input_name, artifact in input_artifacts.items()
    }
    self._validate_inputs(inputs)

    self._configuration = self._configuration.copy(
        update={
            "inputs": inputs,
            "caching_parameters": self.caching_parameters,
        }
    )

    return StepConfiguration.parse_obj(self._configuration)
def _finalize_function_parameters(self) -> Dict[str, Any]:
    """Verifies and prepares the config parameters for running this step.

    When the step requires config parameters, this method:
        - checks if config parameters were set via a config object or file
        - tries to set missing config parameters from default values of the
          config class

    Steps without a parameters class yield an empty dictionary.

    Returns:
        Values for all fields of the parameters class: the configured
        value where one exists, the field default otherwise.

    Raises:
        MissingStepParameterError: If no value could be found for one or
            more config parameters.
        StepInterfaceError: If the parameter class validation failed. The
            underlying `ValidationError` is attached as the cause.
    """
    if not self.PARAMETERS_CLASS:
        return {}

    configured_parameters = self.configuration.parameters

    # we need to store a value for all config keys inside the
    # metadata store to make sure caching works as expected
    missing_keys = []
    values = {}
    for name, field in self.PARAMETERS_CLASS.__fields__.items():
        if name in configured_parameters:
            # a value for this parameter has been set already
            values[name] = configured_parameters[name]
        elif field.required:
            # this field has no default value set and therefore needs
            # to be passed via an initialized config object
            missing_keys.append(name)
        else:
            # use default value from the pydantic config class
            values[name] = field.default

    if missing_keys:
        raise MissingStepParameterError(
            self.name, missing_keys, self.PARAMETERS_CLASS
        )

    try:
        self.PARAMETERS_CLASS(**values)
    except ValidationError as e:
        # Chain explicitly so the traceback marks the validation error as
        # the direct cause of this error.
        raise StepInterfaceError(
            "Failed to validate function parameters."
        ) from e

    return values
caching_parameters: Dict[str, Any]
property
readonly
Caching parameters for this step.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the caching parameters |
configuration: PartialStepConfiguration
property
readonly
The configuration of the step.
Returns:
Type | Description |
---|---|
PartialStepConfiguration |
The configuration of the step. |
docstring: Optional[str]
property
readonly
The docstring of this step.
Returns:
Type | Description |
---|---|
Optional[str] |
The docstring of this step. |
enable_cache: Optional[bool]
property
readonly
If caching is enabled for the step.
Returns:
Type | Description |
---|---|
Optional[bool] |
If caching is enabled for the step. |
inputs: Dict[str, zenml.config.step_configurations.InputSpec]
property
readonly
Step input specifications.
This depends on the upstream steps in a pipeline and can therefore only be accessed once the step has been called in a pipeline.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If this property is accessed before the step was called in a pipeline. |
Returns:
Type | Description |
---|---|
Dict[str, zenml.config.step_configurations.InputSpec] |
The step input specifications. |
name: str
property
readonly
The name of the step.
Returns:
Type | Description |
---|---|
str |
The name of the step. |
source_code: str
property
readonly
The source code of this step.
Returns:
Type | Description |
---|---|
str |
The source code of this step. |
source_object: Any
property
readonly
The source object of this step.
This is either a function wrapped by the @step
decorator or a custom
step class.
Returns:
Type | Description |
---|---|
Any |
The source object of this step. |
upstream_steps: Set[str]
property
readonly
Names of the upstream steps of this step.
This property will only contain the full set of upstream steps once
its parent pipeline's `connect(...)`
method was called.
Returns:
Type | Description |
---|---|
Set[str] |
Set of upstream step names. |
__call__(self, *artifacts, **kw_artifacts)
special
Finalizes the step input and output configuration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*artifacts |
_OutputArtifact |
Positional input artifacts passed to the `__call__` method. |
() |
**kw_artifacts |
_OutputArtifact |
Keyword input artifacts passed to the `__call__` method. |
{} |
Returns:
Type | Description |
---|---|
Union[zenml.steps.base_step.BaseStep._OutputArtifact, List[zenml.steps.base_step.BaseStep._OutputArtifact]] |
A single output artifact or a list of output artifacts. |
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
If the step has already been called. |
Source code in zenml/steps/base_step.py
def __call__(
    self, *artifacts: _OutputArtifact, **kw_artifacts: _OutputArtifact
) -> Union[_OutputArtifact, List[_OutputArtifact]]:
    """Finalizes the step input and output configuration.

    The step is marked as called, its input artifacts are validated and
    recorded (together with the steps that produced them) and the step
    configuration is finalized. One output artifact is created per entry
    of `OUTPUT_SIGNATURE`, in signature order.

    Args:
        *artifacts: Positional input artifacts passed to
            the __call__ method.
        **kw_artifacts: Keyword input artifacts passed to
            the __call__ method.

    Returns:
        A single output artifact or a list of output artifacts.

    Raises:
        StepInterfaceError: If the step has already been called.
    """
    if self._has_been_called:
        raise StepInterfaceError(
            f"Step {self.name} has already been called. A ZenML step "
            f"instance can only be called once per pipeline run."
        )
    self._has_been_called = True

    # Prepare the input artifacts and spec
    input_artifacts = self._validate_input_artifacts(
        *artifacts, **kw_artifacts
    )
    for name, input_ in input_artifacts.items():
        producer = input_.step_name
        self._upstream_steps.add(producer)
        self._inputs[name] = InputSpec(
            step_name=producer, output_name=input_.name
        )

    config = self._finalize_configuration(input_artifacts=input_artifacts)

    # Resolve the returns in the right order.
    returns = [
        BaseStep._OutputArtifact(
            name=key,
            step_name=self.name,
            materializer_source=config.outputs[key].materializer_source,
        )
        for key in self.OUTPUT_SIGNATURE
    ]

    # A single output is returned directly instead of as a one-element list
    return returns[0] if len(returns) == 1 else returns
__init__(self, *args, **kwargs)
special
Initializes a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments passed to the step. |
() |
**kwargs |
Any |
Keyword arguments passed to the step. |
{} |
Source code in zenml/steps/base_step.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initializes a step.

    The class-level `INSTANCE_CONFIGURATION` is merged with the keyword
    arguments (explicitly passed keyword arguments take precedence). The
    step name and the cache / artifact metadata / artifact visualization
    flags are popped from the merged options to build the initial step
    configuration. The remaining options are passed on to
    `_apply_class_configuration` and `_verify_and_apply_init_params`.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.
    """

    def _state(flag: Optional[bool]) -> str:
        # Only an explicit `False` is reported as disabled; an unset
        # (`None`) flag is logged as enabled.
        return "disabled" if flag is False else "enabled"

    self._has_been_called = False
    self._upstream_steps: Set[str] = set()
    self._inputs: Dict[str, InputSpec] = {}

    options = {**self.INSTANCE_CONFIGURATION, **kwargs}
    step_name = options.pop(PARAM_STEP_NAME, None) or self.__class__.__name__

    # This value is only used in `BaseStep._created_by_functional_api()`
    options.pop(PARAM_CREATED_BY_FUNCTIONAL_API, None)

    needs_context = bool(self.CONTEXT_PARAMETER_NAME)
    cache_flag = options.pop(PARAM_ENABLE_CACHE, None)
    if cache_flag is None and needs_context:
        # Using the StepContext inside a step provides access to
        # external resources which might influence the step execution.
        # We therefore disable caching unless it is explicitly enabled
        cache_flag = False
        logger.debug(
            "Step '%s': Step context required and caching not "
            "explicitly enabled.",
            step_name,
        )
    logger.debug("Step '%s': Caching %s.", step_name, _state(cache_flag))

    metadata_flag = options.pop(PARAM_ENABLE_ARTIFACT_METADATA, None)
    logger.debug(
        "Step '%s': Artifact metadata %s.", step_name, _state(metadata_flag)
    )

    visualization_flag = options.pop(
        PARAM_ENABLE_ARTIFACT_VISUALIZATION, None
    )
    logger.debug(
        "Step '%s': Artifact visualization %s.",
        step_name,
        _state(visualization_flag),
    )

    self._configuration = PartialStepConfiguration(
        name=step_name,
        enable_cache=cache_flag,
        enable_artifact_metadata=metadata_flag,
        enable_artifact_visualization=visualization_flag,
    )

    # `_apply_class_configuration` pops the options it handles, so only
    # the leftover entries reach the init parameter verification.
    self._apply_class_configuration(options)
    self._verify_and_apply_init_params(*args, **options)
after(self, step)
Adds an upstream step to this step.
Calling this method makes sure this step only starts running once the given step has successfully finished executing.
Note: This can only be called inside the pipeline connect function
which is decorated with the @pipeline
decorator. Any calls outside
this function will be ignored.
Examples:
The following pipeline will run its steps sequentially in the following order: step_2 -> step_1 -> step_3
@pipeline
def example_pipeline(step_1, step_2, step_3):
step_1.after(step_2)
step_3(step_1(), step_2())
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
BaseStep |
A step which should finish executing before this step is started. |
required |
Source code in zenml/steps/base_step.py
def after(self, step: "BaseStep") -> None:
    """Registers the given step as an upstream step of this step.

    Once registered, this step is only started after the given step has
    finished executing successfully.

    **Note**: Only call this inside the pipeline connect function, i.e. the
    function decorated with `@pipeline`. Calls made anywhere else will be
    ignored.

    Example:
    The steps of the pipeline below run one after another in the order
    step_2 -> step_1 -> step_3

    ```python
    @pipeline
    def example_pipeline(step_1, step_2, step_3):
        step_1.after(step_2)
        step_3(step_1(), step_2())
    ```

    Args:
        step: The step that has to finish executing before this step is
            started.
    """
    # Upstream dependencies are tracked by the name of the step.
    upstream_step_name = step.name
    self._upstream_steps.add(upstream_step_name)
configure(self, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, merge=True)
Configures the step.
Configuration merging example:
* merge==True
:
step.configure(extra={"key1": 1})
step.configure(extra={"key2": 2}, merge=True)
step.configuration.extra # {"key1": 1, "key2": 2}
* merge==False
:
step.configure(extra={"key1": 1})
step.configure(extra={"key2": 2}, merge=False)
step.configuration.extra # {"key2": 2}
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
Optional[str] |
The name of the step. |
None |
enable_cache |
Optional[bool] |
If caching should be enabled for this step. |
None |
enable_artifact_metadata |
Optional[bool] |
If artifact metadata should be enabled for this step. |
None |
enable_artifact_visualization |
Optional[bool] |
If artifact visualization should be enabled for this step. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
parameters |
Optional[ParametersOrDict] |
Function parameters for this step |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Mapping[str, SettingsOrDict]] |
settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
merge |
bool |
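If True, will merge the given dictionary configurations like parameters and settings with existing configurations. If False, the given configurations will overwrite all existing ones. See the general description of this method for an example. |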
True |
on_failure |
Optional[HookSpecification] |
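Callback function in event of failure of the step. Can be a function with three possible parameters, StepContext, BaseParameters, and BaseException, or a source path to a function of the same specifications (e.g. module.my_function). |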
None |
on_success |
Optional[HookSpecification] |
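Callback function in event of success of the step. Can be a function with two possible parameters, StepContext and BaseParameters, or a source path to a function of the same specifications (e.g. module.my_function). |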
None |
Returns:
Type | Description |
---|---|
~T |
The step instance that this method was called on. |
Source code in zenml/steps/base_step.py
def configure(
    self: T,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    merge: bool = True,
) -> T:
    """Configures the step.

    Configuration merging example:
    * `merge==True`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=True)
        step.configuration.extra # {"key1": 1, "key2": 2}
    * `merge==False`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=False)
        step.configuration.extra # {"key2": 2}

    Args:
        name: The name of the step.
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step. A `BaseParameters`
            instance is converted to a dictionary via `.dict()`.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False` the given configurations will
            overwrite all existing ones. See the general description of this
            method for an example.
        on_failure: Callback function in event of failure of the step. Can
            be a function with three possible parameters, `StepContext`,
            `BaseParameters`, and `BaseException`, or a source path to a
            function of the same specifications (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with two possible parameters, `StepContext` and
            `BaseParameters`, or a source path to a function of the same
            specifications (e.g. `module.my_function`).

    Returns:
        The step instance that this method was called on.
    """
    from zenml.hooks.hook_validators import resolve_and_validate_hook

    def _resolve_if_necessary(
        value: Union[str, Source, Type[Any]]
    ) -> Source:
        # Normalize the three accepted materializer spellings to a `Source`.
        if isinstance(value, str):
            return Source.from_import_path(value)
        elif isinstance(value, Source):
            return value
        else:
            return source_utils.resolve(value)

    outputs: Dict[str, Dict[str, Source]] = defaultdict(dict)
    allowed_output_names = set(self.OUTPUT_SIGNATURE)

    if output_materializers:
        if not isinstance(output_materializers, Mapping):
            # A single materializer (class, source or import path) was
            # given: use it for every output of this step.
            source = _resolve_if_necessary(output_materializers)
            output_materializers = {
                output_name: source for output_name in allowed_output_names
            }

        for output_name, materializer in output_materializers.items():
            source = _resolve_if_necessary(materializer)
            outputs[output_name]["materializer_source"] = source

    failure_hook_source = None
    if on_failure:
        # Resolve the failure hook (function or source path) for this step.
        failure_hook_source = resolve_and_validate_hook(on_failure)

    success_hook_source = None
    if on_success:
        # Resolve the success hook (function or source path) for this step.
        success_hook_source = resolve_and_validate_hook(on_success)

    if isinstance(parameters, BaseParameters):
        parameters = parameters.dict()

    # `outputs or None` turns an empty mapping into `None` so that it goes
    # through `remove_none_values` like every other unset argument.
    values = dict_utils.remove_none_values(
        {
            "name": name,
            "enable_cache": enable_cache,
            "enable_artifact_metadata": enable_artifact_metadata,
            "enable_artifact_visualization": enable_artifact_visualization,
            "experiment_tracker": experiment_tracker,
            "step_operator": step_operator,
            "parameters": parameters,
            "settings": settings,
            "outputs": outputs or None,
            "extra": extra,
            "failure_hook_source": failure_hook_source,
            "success_hook_source": success_hook_source,
        }
    )
    config = StepConfigurationUpdate(**values)
    self._apply_configuration(config, merge=merge)
    return self
entrypoint(self, *args, **kwargs)
Abstract method for core step logic.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments passed to the step. |
() |
**kwargs |
Any |
Keyword arguments passed to the step. |
{} |
Returns:
Type | Description |
---|---|
Any |
The output of the step. |
Source code in zenml/steps/base_step.py
@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Abstract hook that holds the core logic of the step.

    Concrete steps are expected to provide the actual implementation.

    Args:
        *args: Positional arguments the step is called with.
        **kwargs: Keyword arguments the step is called with.

    Returns:
        The output of the step.
    """
load_from_source(source)
classmethod
Loads a step from source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
Union[zenml.config.source.Source, str] |
The path to the step source. |
required |
Returns:
Type | Description |
---|---|
BaseStep |
The loaded step. |
Source code in zenml/steps/base_step.py
@classmethod
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
    """Creates a step instance from the class found at the given source.

    Args:
        source: The path to the step source.

    Returns:
        The loaded step.
    """
    # The loader is asked to validate the class against `BaseStep`.
    loaded_class: Type[BaseStep] = source_utils.load_and_validate_class(
        source, expected_class=BaseStep
    )
    # The step class is instantiated without any arguments.
    step_instance = loaded_class()
    return step_instance
BaseStepMeta (type)
Metaclass for BaseStep
.
Checks whether everything passed in: * Has a matching materializer, * Is a subclass of the Config class, * Is typed correctly.
Source code in zenml/steps/base_step.py
class BaseStepMeta(type):
"""Metaclass for `BaseStep`.
Inspects the signature of the `entrypoint` function of every class it
creates and records the result on the class:
* Variable `*args` / `**kwargs` are rejected when the class has bases,
* An argument whose resolved type annotation is missing is rejected,
* At most one argument may be a `BaseParameters` subclass and at most one
may be a `StepContext` subclass; their names are stored on the class,
* All remaining arguments are stored in `INPUT_SIGNATURE`,
* A return annotation is required and is parsed into `OUTPUT_SIGNATURE`.
Materializers are not checked here.
"""
def __new__(
mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseStepMeta":
"""Set up a new class with a qualified spec.
Args:
name: The name of the class.
bases: The base classes of the class.
dct: The attributes of the class.
Returns:
The new class.
Raises:
StepInterfaceError: If the entrypoint uses variable arguments, an
argument or the return value has no type annotation, or more than
one parameters / step context argument is declared.
"""
# Imported inside the function instead of at module level
# (presumably to avoid a circular import -- TODO confirm).
from zenml.steps.base_parameters import BaseParameters
# Make sure the class namespace has an instance configuration entry
# without overwriting one that is already present.
dct.setdefault(INSTANCE_CONFIGURATION, {})
cls = cast(Type["BaseStep"], super().__new__(mcs, name, bases, dct))
# Start every new class with fresh signature attributes.
cls.INPUT_SIGNATURE = {}
cls.OUTPUT_SIGNATURE = {}
cls.PARAMETERS_FUNCTION_PARAMETER_NAME = None
cls.PARAMETERS_CLASS = None
cls.CONTEXT_PARAMETER_NAME = None
# Get the signature of the step function. `inspect.unwrap` follows
# `__wrapped__` so that a decorated entrypoint is inspected by its
# original signature.
step_function_signature = inspect.getfullargspec(
inspect.unwrap(cls.entrypoint)
)
if bases:
# We're not creating the abstract `BaseStep` class
# but a concrete implementation. Make sure the step function
# signature does not contain variable *args or **kwargs
variable_arguments = None
if step_function_signature.varargs:
variable_arguments = f"*{step_function_signature.varargs}"
elif step_function_signature.varkw:
variable_arguments = f"**{step_function_signature.varkw}"
if variable_arguments:
raise StepInterfaceError(
f"Unable to create step '{name}' with variable arguments "
f"'{variable_arguments}'. Please make sure your step "
f"functions are defined with a fixed amount of arguments."
)
# Positional and keyword-only arguments are verified the same way.
step_function_args = (
step_function_signature.args + step_function_signature.kwonlyargs
)
# Remove 'self' from the signature if it exists
if step_function_args and step_function_args[0] == "self":
step_function_args.pop(0)
# Verify the input arguments of the step function
for arg in step_function_args:
arg_type = step_function_signature.annotations.get(arg, None)
arg_type = resolve_type_annotation(arg_type)
if not arg_type:
raise StepInterfaceError(
f"Missing type annotation for argument '{arg}' when "
f"trying to create step '{name}'. Please make sure to "
f"include type annotations for all your step inputs "
f"and outputs."
)
if issubclass(arg_type, BaseParameters):
# Raise an error if we already found a parameters argument in
# the signature
if cls.PARAMETERS_CLASS is not None:
raise StepInterfaceError(
f"Found multiple parameter arguments "
f"('{cls.PARAMETERS_FUNCTION_PARAMETER_NAME}' and '{arg}') when "
f"trying to create step '{name}'. Please make sure to "
f"only have one `Parameters` subclass as input "
f"argument for a step."
)
cls.PARAMETERS_FUNCTION_PARAMETER_NAME = arg
cls.PARAMETERS_CLASS = arg_type
elif issubclass(arg_type, StepContext):
# Raise an error if we already found a step context argument in
# the signature
if cls.CONTEXT_PARAMETER_NAME is not None:
raise StepInterfaceError(
f"Found multiple context arguments "
f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
f"trying to create step '{name}'. Please make sure to "
f"only have one `StepContext` as input "
f"argument for a step."
)
cls.CONTEXT_PARAMETER_NAME = arg
else:
# Can't do any check for existing materializers right now
# as they might be defined later, so we simply store the
# argument name and type for later use.
cls.INPUT_SIGNATURE.update({arg: arg_type})
# Parse the returns of the step function
if "return" not in step_function_signature.annotations:
raise StepInterfaceError(
"Missing return type annotation when trying to create step "
f"'{name}'. Please make sure to include type annotations for "
"all your step inputs and outputs. If your step returns "
"nothing, please annotate it with `-> None`."
)
cls.OUTPUT_SIGNATURE = parse_return_type_annotations(
step_function_signature.annotations,
)
return cls
__new__(mcs, name, bases, dct)
special
staticmethod
Set up a new class with a qualified spec.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the class. |
required |
bases |
Tuple[Type[Any], ...] |
The base classes of the class. |
required |
dct |
Dict[str, Any] |
The attributes of the class. |
required |
Returns:
Type | Description |
---|---|
BaseStepMeta |
The new class. |
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
When unable to create the step. |
Source code in zenml/steps/base_step.py
def __new__(
mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseStepMeta":
"""Set up a new class with a qualified spec.
The signature of the class' `entrypoint` function is inspected and the
result is stored on the new class (`INPUT_SIGNATURE`, `OUTPUT_SIGNATURE`,
`PARAMETERS_FUNCTION_PARAMETER_NAME`, `PARAMETERS_CLASS` and
`CONTEXT_PARAMETER_NAME`).
Args:
name: The name of the class.
bases: The base classes of the class.
dct: The attributes of the class.
Returns:
The new class.
Raises:
StepInterfaceError: If the entrypoint uses variable arguments, an
argument or the return value has no type annotation, or more than
one parameters / step context argument is declared.
"""
# Imported inside the function instead of at module level
# (presumably to avoid a circular import -- TODO confirm).
from zenml.steps.base_parameters import BaseParameters
# Make sure the class namespace has an instance configuration entry
# without overwriting one that is already present.
dct.setdefault(INSTANCE_CONFIGURATION, {})
cls = cast(Type["BaseStep"], super().__new__(mcs, name, bases, dct))
# Start every new class with fresh signature attributes.
cls.INPUT_SIGNATURE = {}
cls.OUTPUT_SIGNATURE = {}
cls.PARAMETERS_FUNCTION_PARAMETER_NAME = None
cls.PARAMETERS_CLASS = None
cls.CONTEXT_PARAMETER_NAME = None
# Get the signature of the step function. `inspect.unwrap` follows
# `__wrapped__` so that a decorated entrypoint is inspected by its
# original signature.
step_function_signature = inspect.getfullargspec(
inspect.unwrap(cls.entrypoint)
)
if bases:
# We're not creating the abstract `BaseStep` class
# but a concrete implementation. Make sure the step function
# signature does not contain variable *args or **kwargs
variable_arguments = None
if step_function_signature.varargs:
variable_arguments = f"*{step_function_signature.varargs}"
elif step_function_signature.varkw:
variable_arguments = f"**{step_function_signature.varkw}"
if variable_arguments:
raise StepInterfaceError(
f"Unable to create step '{name}' with variable arguments "
f"'{variable_arguments}'. Please make sure your step "
f"functions are defined with a fixed amount of arguments."
)
# Positional and keyword-only arguments are verified the same way.
step_function_args = (
step_function_signature.args + step_function_signature.kwonlyargs
)
# Remove 'self' from the signature if it exists
if step_function_args and step_function_args[0] == "self":
step_function_args.pop(0)
# Verify the input arguments of the step function
for arg in step_function_args:
arg_type = step_function_signature.annotations.get(arg, None)
arg_type = resolve_type_annotation(arg_type)
if not arg_type:
raise StepInterfaceError(
f"Missing type annotation for argument '{arg}' when "
f"trying to create step '{name}'. Please make sure to "
f"include type annotations for all your step inputs "
f"and outputs."
)
if issubclass(arg_type, BaseParameters):
# Raise an error if we already found a parameters argument in
# the signature
if cls.PARAMETERS_CLASS is not None:
raise StepInterfaceError(
f"Found multiple parameter arguments "
f"('{cls.PARAMETERS_FUNCTION_PARAMETER_NAME}' and '{arg}') when "
f"trying to create step '{name}'. Please make sure to "
f"only have one `Parameters` subclass as input "
f"argument for a step."
)
cls.PARAMETERS_FUNCTION_PARAMETER_NAME = arg
cls.PARAMETERS_CLASS = arg_type
elif issubclass(arg_type, StepContext):
# Raise an error if we already found a step context argument in
# the signature
if cls.CONTEXT_PARAMETER_NAME is not None:
raise StepInterfaceError(
f"Found multiple context arguments "
f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
f"trying to create step '{name}'. Please make sure to "
f"only have one `StepContext` as input "
f"argument for a step."
)
cls.CONTEXT_PARAMETER_NAME = arg
else:
# Can't do any check for existing materializers right now
# as they might be defined later, so we simply store the
# argument name and type for later use.
cls.INPUT_SIGNATURE.update({arg: arg_type})
# Parse the returns of the step function
if "return" not in step_function_signature.annotations:
raise StepInterfaceError(
"Missing return type annotation when trying to create step "
f"'{name}'. Please make sure to include type annotations for "
"all your step inputs and outputs. If your step returns "
"nothing, please annotate it with `-> None`."
)
cls.OUTPUT_SIGNATURE = parse_return_type_annotations(
step_function_signature.annotations,
)
return cls
step_context
Step context class.
StepContext
Provides additional context inside a step function.
This class is used to access pipelines, materializers, and artifacts
inside a step function. To use it, add a StepContext
object
to the signature of your step function like this:
@step
def my_step(context: StepContext, ...):
context.get_output_materializer(...)
You do not need to create a StepContext
object yourself and pass it
when creating the step, as long as you specify it in the signature ZenML
will create the StepContext
and automatically pass it when executing your
step.
Note: When using a StepContext
inside a step, ZenML disables caching
for this step by default as the context provides access to external
resources which might influence the result of your step execution. To
enable caching anyway, explicitly enable it in the @step
decorator or when
initializing your custom step class.
Source code in zenml/steps/step_context.py
class StepContext:
    """Gives a step function access to additional context.

    Through this class a step function can reach pipelines, materializers,
    and artifacts. To use it, declare a `StepContext` argument in the
    signature of your step function:

    ```python
    @step
    def my_step(context: StepContext, ...):
        context.get_output_materializer(...)
    ```

    There is no need to create the `StepContext` yourself or to pass it when
    creating the step: as long as it is part of the signature, ZenML creates
    the `StepContext` and passes it automatically when the step is executed.

    **Note**: A step that uses a `StepContext` has caching disabled by
    default, because the context provides access to external resources which
    might influence the result of the step execution. To use caching anyway,
    enable it explicitly in the `@step` decorator or when initializing your
    custom step class.
    """

    def __init__(
        self,
        step_name: str,
        output_materializers: Dict[str, Type["BaseMaterializer"]],
        output_artifact_uris: Dict[str, str],
    ):
        """Creates the context for a single step.

        Args:
            step_name: Name of the step this context belongs to.
            output_materializers: Materializer class for each output of the
                step, keyed by output name.
            output_artifact_uris: Artifact URI for each output of the step,
                keyed by output name.

        Raises:
            StepContextError: If the two dictionaries do not have the same
                keys.
        """
        materializer_keys = output_materializers.keys()
        uri_keys = output_artifact_uris.keys()
        if materializer_keys != uri_keys:
            raise StepContextError(
                f"Mismatched keys in output materializers and output "
                f"artifacts URIs for step '{step_name}'. Output materializer "
                f"keys: {set(output_materializers)}, output artifact URI "
                f"keys: {set(output_artifact_uris)}"
            )

        self.step_name = step_name

        # Pair up materializer class and artifact URI for every output.
        outputs = {}
        for output_name, materializer_class in output_materializers.items():
            outputs[output_name] = StepContextOutput(
                materializer_class, output_artifact_uris[output_name]
            )
        self._outputs = outputs

        self._stack = Client().active_stack

    def _get_output(
        self, output_name: Optional[str] = None
    ) -> StepContextOutput:
        """Looks up the materializer class and artifact URI of an output.

        Args:
            output_name: Name of the output to look up. May be omitted if
                the step has exactly one output.

        Returns:
            The materializer class / artifact URI tuple of the output.

        Raises:
            StepContextError: If the step has no outputs, if no output with
                the given `output_name` exists, or if `output_name` is
                omitted although the step has multiple outputs.
        """
        if not self._outputs:
            raise StepContextError(
                f"Unable to get step output for step '{self.step_name}': "
                f"This step does not have any outputs."
            )

        if not output_name:
            # Without a name the output is only unambiguous if there is
            # exactly one of them.
            if len(self._outputs) > 1:
                raise StepContextError(
                    f"Unable to get step output for step '{self.step_name}': "
                    f"This step has multiple outputs ({set(self._outputs)}), "
                    f"please specify which output to return."
                )
            return next(iter(self._outputs.values()))

        if output_name not in self._outputs:
            raise StepContextError(
                f"Unable to get step output '{output_name}' for "
                f"step '{self.step_name}'. This step does not have an "
                f"output with the given name, please specify one of the "
                f"available outputs: {set(self._outputs)}."
            )
        return self._outputs[output_name]

    @property
    def stack(self) -> Optional["Stack"]:
        """The stack that was active when this context was created.

        Returns:
            The current active stack or None.
        """
        return self._stack

    @property
    def pipeline_name(self) -> Optional[str]:
        """The pipeline name taken from the step environment.

        Returns:
            The current pipeline name or None.
        """
        return Environment().step_environment.pipeline_name

    @property
    def run_name(self) -> Optional[str]:
        """The run name taken from the step environment.

        Returns:
            The current run name or None.
        """
        return Environment().step_environment.run_name

    @property
    def step_run_info(self) -> "StepRunInfo":
        """The step run info taken from the step environment.

        Returns:
            Info about the currently running step.
        """
        return Environment().step_environment.step_run_info

    @property
    def cache_enabled(self) -> bool:
        """The cache flag taken from the step environment.

        Returns:
            True if cache is enabled for the step, otherwise False.
        """
        return Environment().step_environment.cache_enabled

    def get_output_materializer(
        self,
        output_name: Optional[str] = None,
        custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
    ) -> "BaseMaterializer":
        """Creates a materializer for one of the step outputs.

        Args:
            output_name: Name of the output to create the materializer for.
                May be omitted if the step has a single output, in which
                case the materializer of that output is returned. Omitting
                it for a step with multiple outputs raises an exception.
            custom_materializer_class: If given, this `BaseMaterializer`
                subclass is initialized with the output artifact instead of
                the materializer registered for the output.

        Returns:
            A materializer initialized with the output artifact of the
            requested output.
        """
        registered_class, uri = self._get_output(output_name)
        # A custom class, if provided, wins over the registered one.
        if custom_materializer_class:
            chosen_class = custom_materializer_class
        else:
            chosen_class = registered_class
        return chosen_class(uri)

    def get_output_artifact_uri(
        self, output_name: Optional[str] = None
    ) -> str:
        """Looks up the artifact URI of one of the step outputs.

        Args:
            output_name: Name of the output to get the URI for. May be
                omitted if the step has a single output, in which case the
                URI of that output is returned. Omitting it for a step with
                multiple outputs raises an exception.

        Returns:
            Artifact URI for the given output.
        """
        output = self._get_output(output_name)
        return output.artifact_uri
cache_enabled: bool
property
readonly
Returns whether cache is enabled for the step.
Returns:
Type | Description |
---|---|
bool |
True if cache is enabled for the step, otherwise False. |
pipeline_name: Optional[str]
property
readonly
Returns the current pipeline name.
Returns:
Type | Description |
---|---|
Optional[str] |
The current pipeline name or None. |
run_name: Optional[str]
property
readonly
Returns the current run name.
Returns:
Type | Description |
---|---|
Optional[str] |
The current run name or None. |
stack: Optional[Stack]
property
readonly
Returns the current active stack.
Returns:
Type | Description |
---|---|
Optional[Stack] |
The current active stack or None. |
step_run_info: StepRunInfo
property
readonly
Info about the currently running step.
Returns:
Type | Description |
---|---|
StepRunInfo |
Info about the currently running step. |
__init__(self, step_name, output_materializers, output_artifact_uris)
special
Initializes a StepContext instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step that this context is used in. |
required |
output_materializers |
Dict[str, Type[BaseMaterializer]] |
The output materializers of the step that this context is used in. |
required |
output_artifact_uris |
Dict[str, str] |
The output artifacts of the step that this context is used in. |
required |
Exceptions:
Type | Description |
---|---|
StepContextError |
If the keys of the output materializers and output artifacts do not match. |
Source code in zenml/steps/step_context.py
def __init__(
    self,
    step_name: str,
    output_materializers: Dict[str, Type["BaseMaterializer"]],
    output_artifact_uris: Dict[str, str],
):
    """Creates the context for a single step.

    Args:
        step_name: Name of the step this context belongs to.
        output_materializers: Materializer class for each output of the
            step, keyed by output name.
        output_artifact_uris: Artifact URI for each output of the step,
            keyed by output name.

    Raises:
        StepContextError: If the two dictionaries do not have the same keys.
    """
    materializer_keys = output_materializers.keys()
    uri_keys = output_artifact_uris.keys()
    if materializer_keys != uri_keys:
        raise StepContextError(
            f"Mismatched keys in output materializers and output "
            f"artifacts URIs for step '{step_name}'. Output materializer "
            f"keys: {set(output_materializers)}, output artifact URI "
            f"keys: {set(output_artifact_uris)}"
        )

    self.step_name = step_name

    # Pair up materializer class and artifact URI for every output.
    outputs = {}
    for output_name, materializer_class in output_materializers.items():
        outputs[output_name] = StepContextOutput(
            materializer_class, output_artifact_uris[output_name]
        )
    self._outputs = outputs

    self._stack = Client().active_stack
get_output_artifact_uri(self, output_name=None)
Returns the artifact URI for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the URI. If no name is given and the step only has a single output, the URI of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
Returns:
Type | Description |
---|---|
str |
Artifact URI for the given output. |
Source code in zenml/steps/step_context.py
def get_output_artifact_uri(
    self, output_name: Optional[str] = None
) -> str:
    """Looks up the artifact URI of one of the step outputs.

    Args:
        output_name: Name of the output to get the URI for. May be omitted
            if the step has a single output, in which case the URI of that
            output is returned. Omitting it for a step with multiple
            outputs raises an exception.

    Returns:
        Artifact URI for the given output.
    """
    output = self._get_output(output_name)
    return output.artifact_uri
get_output_materializer(self, output_name=None, custom_materializer_class=None)
Returns a materializer for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the materializer. If no name is given and the step only has a single output, the materializer of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
custom_materializer_class |
Optional[Type[BaseMaterializer]] |
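If given, this BaseMaterializer subclass will be initialized with the output artifact instead of the materializer that was registered for this step output. |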
None |
Returns:
Type | Description |
---|---|
BaseMaterializer |
A materializer initialized with the output artifact for the given output. |
Source code in zenml/steps/step_context.py
def get_output_materializer(
    self,
    output_name: Optional[str] = None,
    custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
) -> "BaseMaterializer":
    """Creates a materializer for one of the step outputs.

    Args:
        output_name: Name of the output to create the materializer for. May
            be omitted if the step has a single output, in which case the
            materializer of that output is returned. Omitting it for a step
            with multiple outputs raises an exception.
        custom_materializer_class: If given, this `BaseMaterializer`
            subclass is initialized with the output artifact instead of the
            materializer registered for the output.

    Returns:
        A materializer initialized with the output artifact of the requested
        output.
    """
    registered_class, uri = self._get_output(output_name)
    # A custom class, if provided, wins over the registered one.
    if custom_materializer_class:
        chosen_class = custom_materializer_class
    else:
        chosen_class = registered_class
    return chosen_class(uri)
StepContextOutput (tuple)
Tuple containing materializer class and URI for a step output.
Source code in zenml/steps/step_context.py
class StepContextOutput(NamedTuple):
    """Pairs the materializer class of a step output with its URI."""

    # Class of the materializer for the output.
    materializer_class: Type["BaseMaterializer"]
    # URI of the output artifact.
    artifact_uri: str
__getnewargs__(self)
special
Return self as a plain tuple. Used by copy and pickle.
Source code in zenml/steps/step_context.py
def __getnewargs__(self):
    """Return self as a plain tuple. Used by copy and pickle."""
    plain_tuple = _tuple(self)
    return plain_tuple
__new__(_cls, materializer_class, artifact_uri)
special
staticmethod
Create new instance of StepContextOutput(materializer_class, artifact_uri)
__repr__(self)
special
Return a nicely formatted representation string
Source code in zenml/steps/step_context.py
def __repr__(self):
    """Return a nicely formatted representation string."""
    class_name = self.__class__.__name__
    return class_name + repr_fmt % self
step_decorator
Step decorator function.
step(_func=None, *, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, experiment_tracker=None, step_operator=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None)
Outer decorator function for the creation of a ZenML step.
In order to be able to work with parameters such as name
, it features a
nested decorator structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_func |
Optional[~F] |
The decorated function. |
None |
name |
Optional[str] |
The name of the step. If left empty, the name of the decorated function will be used as a fallback. |
None |
enable_cache |
Optional[bool] |
Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default unless the step requires a StepContext (see zenml.steps.step_context.StepContext for more information). |
None |
enable_artifact_metadata |
Optional[bool] |
Specify whether metadata is enabled for this step. If no value is passed, metadata is enabled by default. |
None |
enable_artifact_visualization |
Optional[bool] |
Specify whether visualization is enabled for this step. If no value is passed, visualization is enabled by default. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Dict[str, SettingsOrDict]] |
Settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
Callback function in event of failure of the step. Can be a function with three possible parameters, StepContext, BaseParameters, and BaseException. |
None |
on_success |
Optional[HookSpecification] |
Callback function in event of success of the step. Can be a function with two possible parameters, StepContext and BaseParameters. |
None |
Returns:
Type | Description |
---|---|
Union[Type[zenml.steps.base_step.BaseStep], Callable[[~F], Type[zenml.steps.base_step.BaseStep]]] |
The inner decorator which creates the step class based on the ZenML BaseStep |
Source code in zenml/steps/step_decorator.py
def step(
    _func: Optional[F] = None,
    *,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    output_materializers: Optional["OutputMaterializersSpecification"] = None,
    settings: Optional[Dict[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
) -> Union[Type[BaseStep], Callable[[F], Type[BaseStep]]]:
    """Outer decorator function for the creation of a ZenML step.

    In order to be able to work with parameters such as `name`, it features a
    nested decorator structure.

    Args:
        _func: The decorated function.
        name: The name of the step. If left empty, the name of the decorated
            function will be used as a fallback.
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default unless the step
            requires a `StepContext` (see
            `zenml.steps.step_context.StepContext` for more information).
        enable_artifact_metadata: Specify whether metadata is enabled for this
            step. If no value is passed, metadata is enabled by default.
        enable_artifact_visualization: Specify whether visualization is enabled
            for this step. If no value is passed, visualization is enabled by
            default.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can be
            a function with three possible parameters,
            `StepContext`, `BaseParameters`, and `BaseException`,
            or a source path to a function of the same specifications
            (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can be
            a function with two possible parameters, `StepContext` and
            `BaseParameters`, or a source path to a function of the same
            specifications (e.g. `module.my_function`).

    Returns:
        The inner decorator which creates the step class based on the
        ZenML BaseStep if `_func` is not given (i.e. the decorator was
        called with arguments), otherwise the step class generated from
        `_func` directly.
    """

    def inner_decorator(func: F) -> Type[BaseStep]:
        """Inner decorator function for the creation of a ZenML Step.

        Args:
            func: types.FunctionType, this function will be used as the
                "process" method of the generated Step.

        Returns:
            The class of a newly generated ZenML Step.
        """
        # Dynamically build a `BaseStep` subclass that is named after the
        # decorated function and carries the decorator arguments as its
        # instance configuration.
        return type(  # noqa
            func.__name__,
            (BaseStep,),
            {
                STEP_INNER_FUNC_NAME: staticmethod(func),
                INSTANCE_CONFIGURATION: {
                    PARAM_STEP_NAME: name,
                    PARAM_CREATED_BY_FUNCTIONAL_API: True,
                    PARAM_ENABLE_CACHE: enable_cache,
                    PARAM_ENABLE_ARTIFACT_METADATA: enable_artifact_metadata,
                    PARAM_ENABLE_ARTIFACT_VISUALIZATION: enable_artifact_visualization,
                    PARAM_EXPERIMENT_TRACKER: experiment_tracker,
                    PARAM_STEP_OPERATOR: step_operator,
                    PARAM_OUTPUT_MATERIALIZERS: output_materializers,
                    PARAM_SETTINGS: settings,
                    PARAM_EXTRA_OPTIONS: extra,
                    PARAM_ON_FAILURE: on_failure,
                    PARAM_ON_SUCCESS: on_success,
                },
                # Preserve the origin of the function so the generated class
                # reports the same module and documentation.
                "__module__": func.__module__,
                "__doc__": func.__doc__,
            },
        )

    # `@step(...)` usage: no function yet, hand back the actual decorator.
    if _func is None:
        return inner_decorator
    # Bare `@step` usage: the function was passed in directly.
    return inner_decorator(_func)
step_environment
Step environment class.
StepEnvironment (BaseEnvironmentComponent)
Added information about a step runtime inside a step function.
This takes the form of an Environment component. This class can be used from within a pipeline step implementation to access additional information about the runtime parameters of a pipeline step, such as the pipeline name, pipeline run ID and other pipeline runtime information. To use it, access it inside your step function like this:
from zenml.environment import Environment
@step
def my_step(...):
    env = Environment().step_environment
    do_something_with(env.pipeline_name, env.run_name, env.step_name)
Source code in zenml/steps/step_environment.py
class StepEnvironment(BaseEnvironmentComponent):
    """Added information about a step runtime inside a step function.

    This takes the form of an Environment component. This class can be used from
    within a pipeline step implementation to access additional information about
    the runtime parameters of a pipeline step, such as the pipeline name,
    pipeline run ID and other pipeline runtime information. To use it, access it
    inside your step function like this:

    ```python
    from zenml.environment import Environment

    @step
    def my_step(...):
        env = Environment().step_environment
        do_something_with(env.pipeline_name, env.run_name, env.step_name)
    ```
    """

    NAME = STEP_ENVIRONMENT_NAME

    def __init__(
        self,
        step_run_info: "StepRunInfo",
        cache_enabled: bool,
    ):
        """Initialize the environment of the currently running step.

        Args:
            step_run_info: Info about the currently running step.
            cache_enabled: Whether caching is enabled for the current step run.
        """
        super().__init__()
        self._step_run_info = step_run_info
        self._cache_enabled = cache_enabled

    @property
    def pipeline_name(self) -> str:
        """The name of the currently running pipeline.

        Returns:
            The name of the currently running pipeline.
        """
        return self._step_run_info.pipeline.name

    @property
    def run_name(self) -> str:
        """The name of the current pipeline run.

        Returns:
            The name of the current pipeline run.
        """
        return self._step_run_info.run_name

    @property
    def pipeline_run_id(self) -> str:
        """The ID of the current pipeline run.

        Deprecated: logs a warning on every access and returns the same value
        as `run_name`.

        Returns:
            The ID of the current pipeline run.
        """
        # The message names this class (not `StepContext`), since that is
        # where the deprecated property and its replacement live.
        logger.warning(
            "`StepEnvironment.pipeline_run_id` is deprecated. Use "
            "`StepEnvironment.run_name` instead."
        )
        return self.run_name

    @property
    def step_name(self) -> str:
        """The name of the currently running step.

        Returns:
            The name of the currently running step.
        """
        return self._step_run_info.pipeline_step_name

    @property
    def step_run_info(self) -> "StepRunInfo":
        """Info about the currently running step.

        Returns:
            Info about the currently running step.
        """
        return self._step_run_info

    @property
    def cache_enabled(self) -> bool:
        """Returns whether cache is enabled for the step.

        Returns:
            True if cache is enabled for the step, otherwise False.
        """
        return self._cache_enabled
cache_enabled: bool
property
readonly
Returns whether cache is enabled for the step.
Returns:
Type | Description |
---|---|
bool |
True if cache is enabled for the step, otherwise False. |
pipeline_name: str
property
readonly
The name of the currently running pipeline.
Returns:
Type | Description |
---|---|
str |
The name of the currently running pipeline. |
pipeline_run_id: str
property
readonly
The ID of the current pipeline run.
Returns:
Type | Description |
---|---|
str |
The ID of the current pipeline run. |
run_name: str
property
readonly
The name of the current pipeline run.
Returns:
Type | Description |
---|---|
str |
The name of the current pipeline run. |
step_name: str
property
readonly
The name of the currently running step.
Returns:
Type | Description |
---|---|
str |
The name of the currently running step. |
step_run_info: StepRunInfo
property
readonly
Info about the currently running step.
Returns:
Type | Description |
---|---|
StepRunInfo |
Info about the currently running step. |
__init__(self, step_run_info, cache_enabled)
special
Initialize the environment of the currently running step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_run_info |
StepRunInfo |
Info about the currently running step. |
required |
cache_enabled |
bool |
Whether caching is enabled for the current step run. |
required |
Source code in zenml/steps/step_environment.py
def __init__(self, step_run_info: "StepRunInfo", cache_enabled: bool):
    """Set up the environment component for the step that is running.

    Args:
        step_run_info: Info about the currently running step.
        cache_enabled: Whether caching is enabled for the current step run.
    """
    # Initialize the base component first, then remember both arguments on
    # private attributes.
    super().__init__()
    self._step_run_info, self._cache_enabled = step_run_info, cache_enabled
step_output
Step output class.
Output
A named tuple with a default name that cannot be overridden.
Source code in zenml/steps/step_output.py
class Output(object):
    """A named tuple with a default name that cannot be overridden."""

    def __init__(self, **kwargs: Type[Any]):
        """Initializes the output.

        Args:
            **kwargs: The outputs, given as `output_name=output_type` pairs.
        """
        # TODO [ENG-161]: do we even need the named tuple here or is
        #  a list of tuples (name, Type) sufficient?
        # Pass the fields as a list of (name, type) pairs: the keyword
        # argument form `NamedTuple("ZenOutput", **kwargs)` is deprecated
        # since Python 3.13.
        self.outputs = NamedTuple("ZenOutput", list(kwargs.items()))  # type: ignore[misc]

    def items(self) -> Iterator[Tuple[str, Type[Any]]]:
        """Yields a tuple of type (output_name, output_type).

        Yields:
            A tuple of type (output_name, output_type).
        """
        # The annotations of the generated named tuple class map each field
        # name to the type it was declared with.
        yield from self.outputs.__annotations__.items()
__init__(self, **kwargs)
special
Initializes the output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Type[Any] |
The output values. |
{} |
Source code in zenml/steps/step_output.py
def __init__(self, **kwargs: Type[Any]):
"""Initializes the output.
Args:
**kwargs: The output values.
"""
# TODO [ENG-161]: do we even need the named tuple here or is
# a list of tuples (name, Type) sufficient?
self.outputs = NamedTuple("ZenOutput", **kwargs) # type: ignore[misc]
items(self)
Yields a tuple of type (output_name, output_type).
Yields:
Type | Description |
---|---|
Iterator[Tuple[str, Type[Any]]] |
A tuple of type (output_name, output_type). |
Source code in zenml/steps/step_output.py
def items(self) -> Iterator[Tuple[str, Type[Any]]]:
    """Iterate over the outputs as (output_name, output_type) pairs.

    Yields:
        A tuple of type (output_name, output_type).
    """
    # The pairs come from the annotations stored on `self.outputs`.
    declared_types = self.outputs.__annotations__
    for output_pair in declared_types.items():
        yield output_pair
utils
Utility functions and classes to run ZenML steps.
parse_return_type_annotations(step_annotations)
Parse the returns of a step function into a dict of resolved types.
Called within `BaseStepMeta.__new__()` to define `cls.OUTPUT_SIGNATURE`.
Called within `Do()` to resolve type annotations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_annotations |
Dict[str, Any] |
Type annotations of the step function. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Output signature of the new step class. |
Source code in zenml/steps/utils.py
def parse_return_type_annotations(
    step_annotations: Dict[str, Any]
) -> Dict[str, Any]:
    """Build the output signature from a step function's annotations.

    Called within `BaseStepMeta.__new__()` to define `cls.OUTPUT_SIGNATURE`.
    Called within `Do()` to resolve type annotations.

    Args:
        step_annotations: Type annotations of the step function.

    Returns:
        Output signature of the new step class.
    """
    annotation = step_annotations.get("return")

    # No return annotation (or an explicit `None`) means no outputs.
    if annotation is None:
        return {}

    # A plain type is wrapped into an `Output` under the default
    # single-output name so both cases can be handled uniformly below.
    if isinstance(annotation, Output):
        outputs = annotation
    else:
        outputs = Output(**{SINGLE_RETURN_OUT_NAME: annotation})

    # Map every output name to its resolved type.
    signature = {}
    for name, declared_type in outputs.items():
        signature[name] = resolve_type_annotation(declared_type)
    return signature
resolve_type_annotation(obj)
Returns the non-generic class for generic aliases of the typing module.
If the input is not a generic typing alias, the input itself is returned.
Example: if the input object is `typing.Dict`, this method will return the concrete class `dict`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any |
The object to resolve. |
required |
Returns:
Type | Description |
---|---|
Any |
The non-generic class for generic aliases of the typing module. |
Source code in zenml/steps/utils.py
def resolve_type_annotation(obj: Any) -> Any:
    """Returns the non-generic class for generic aliases of the typing module.

    If the input is not a generic typing alias, the input itself is returned.

    Example: if the input object is `typing.Dict`, this method will return the
    concrete class `dict`.

    Args:
        obj: The object to resolve.

    Returns:
        The non-generic class for generic aliases of the typing module.
    """
    if sys.version_info >= (3, 8):
        # `get_origin` returns None for anything that is not a generic
        # alias, in which case the object itself is handed back.
        return typing.get_origin(obj) or obj

    # Python 3.7 has no `typing.get_origin`. The private alias class is
    # imported only on this path so newer interpreters never depend on it.
    from typing import _GenericAlias  # type: ignore[attr-defined]

    if isinstance(obj, _GenericAlias):
        return obj.__origin__
    return obj