Steps
zenml.steps
special
Initializer for ZenML steps.
A step is a single piece or stage of a ZenML pipeline. Think of each step as being one of the nodes of a Directed Acyclic Graph (or DAG). Steps are responsible for one aspect of processing or interacting with the data / artifacts in the pipeline.
Conceptually, a Step is a discrete and independent part of a pipeline that is responsible for one particular aspect of data manipulation inside a ZenML pipeline.
Steps can be subclassed from the `BaseStep` class, or used via our `@step` decorator.
base_parameters
Step parameters.
BaseParameters (BaseModel)
pydantic-model
Base class to pass parameters into a step.
Source code in zenml/steps/base_parameters.py
class BaseParameters(BaseModel):
    """Pydantic base class used to pass parameters into a step.

    Deprecated: steps using this class log a deprecation warning when their
    parameters are finalized. Parameter classes should inherit from
    `pydantic.BaseModel` directly instead.
    """
base_step
Base Step for ZenML.
BaseStep
Abstract base class for all ZenML steps.
Source code in zenml/steps/base_step.py
class BaseStep(metaclass=BaseStepMeta):
    """Abstract base class for all ZenML steps."""

    def __init__(
        self,
        *args: Any,
        name: Optional[str] = None,
        enable_cache: Optional[bool] = None,
        enable_artifact_metadata: Optional[bool] = None,
        enable_artifact_visualization: Optional[bool] = None,
        experiment_tracker: Optional[str] = None,
        step_operator: Optional[str] = None,
        parameters: Optional["ParametersOrDict"] = None,
        output_materializers: Optional[
            "OutputMaterializersSpecification"
        ] = None,
        settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
        extra: Optional[Dict[str, Any]] = None,
        on_failure: Optional["HookSpecification"] = None,
        on_success: Optional["HookSpecification"] = None,
        **kwargs: Any,
    ) -> None:
        """Initializes a step.

        Args:
            *args: Positional arguments passed to the step.
            name: The name of the step. Defaults to the class name.
            enable_cache: If caching should be enabled for this step. When
                left unset and the entrypoint requires a step context,
                caching is disabled.
            enable_artifact_metadata: If artifact metadata should be enabled
                for this step.
            enable_artifact_visualization: If artifact visualization should be
                enabled for this step.
            experiment_tracker: The experiment tracker to use for this step.
            step_operator: The step operator to use for this step.
            parameters: Function parameters for this step.
            output_materializers: Output materializers for this step. If
                given as a dict, the keys must be a subset of the output names
                of this step. If a single value (type or string) is given, the
                materializer will be used for all outputs.
            settings: Settings for this step.
            extra: Extra configurations for this step.
            on_failure: Callback function in event of failure of the step. Can
                be a function with three possible parameters, `StepContext`,
                `BaseParameters`, and `BaseException`, or a source path to a
                function of the same specifications (e.g. `module.my_function`).
            on_success: Callback function in event of success of the step. Can
                be a function with two possible parameters, `StepContext` and
                `BaseParameters`, or a source path to a function of the same
                specifications (e.g. `module.my_function`).
            **kwargs: Keyword arguments passed to the step.
        """
        self._upstream_steps: Set["BaseStep"] = set()
        self.entrypoint_definition = validate_entrypoint_function(
            self.entrypoint, reserved_arguments=["after", "id"]
        )

        name = name or self.__class__.__name__

        requires_context = self.entrypoint_definition.context is not None
        if enable_cache is None:
            if requires_context:
                # Using the StepContext inside a step provides access to
                # external resources which might influence the step execution.
                # We therefore disable caching unless it is explicitly enabled
                enable_cache = False
                logger.debug(
                    "Step '%s': Step context required and caching not "
                    "explicitly enabled.",
                    name,
                )

        logger.debug(
            "Step '%s': Caching %s.",
            name,
            "enabled" if enable_cache is not False else "disabled",
        )
        logger.debug(
            "Step '%s': Artifact metadata %s.",
            name,
            "enabled" if enable_artifact_metadata is not False else "disabled",
        )
        logger.debug(
            "Step '%s': Artifact visualization %s.",
            name,
            "enabled"
            if enable_artifact_visualization is not False
            else "disabled",
        )

        self._configuration = PartialStepConfiguration(
            name=name,
            enable_cache=enable_cache,
            enable_artifact_metadata=enable_artifact_metadata,
            enable_artifact_visualization=enable_artifact_visualization,
        )
        self.configure(
            experiment_tracker=experiment_tracker,
            step_operator=step_operator,
            output_materializers=output_materializers,
            parameters=parameters,
            settings=settings,
            extra=extra,
            on_failure=on_failure,
            on_success=on_success,
        )
        self._verify_and_apply_init_params(*args, **kwargs)

    @abstractmethod
    def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
        """Abstract method for core step logic.

        Args:
            *args: Positional arguments passed to the step.
            **kwargs: Keyword arguments passed to the step.

        Returns:
            The output of the step.
        """

    @classmethod
    def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
        """Loads a step from source.

        Args:
            source: The path to the step source.

        Returns:
            The loaded step. If the source resolves to a `BaseStep` subclass
            instead of an instance, the class is instantiated without
            arguments.

        Raises:
            ValueError: If the source is not a valid step source.
        """
        obj = source_utils.load(source)

        if isinstance(obj, BaseStep):
            return obj
        elif isinstance(obj, type) and issubclass(obj, BaseStep):
            return obj()
        else:
            raise ValueError("Invalid step source.")

    def resolve(self) -> Source:
        """Resolves the step.

        Returns:
            The step source.
        """
        return source_utils.resolve(self.__class__)

    @property
    def upstream_steps(self) -> Set["BaseStep"]:
        """The upstream steps of this step.

        This property will only contain the full set of upstream steps once
        its parent pipeline's `connect(...)` method was called.

        Returns:
            Set of upstream steps.
        """
        return self._upstream_steps

    def after(self, step: "BaseStep") -> None:
        """Adds an upstream step to this step.

        Calling this method makes sure this step only starts running once the
        given step has successfully finished executing.

        **Note**: This can only be called inside the pipeline connect function
        which is decorated with the `@pipeline` decorator. Any calls outside
        this function will be ignored.

        Example:
        The following pipeline will run its steps sequentially in the following
        order: step_2 -> step_1 -> step_3

        ```python
        @pipeline
        def example_pipeline(step_1, step_2, step_3):
            step_1.after(step_2)
            step_3(step_1(), step_2())
        ```

        Args:
            step: A step which should finish executing before this step is
                started.
        """
        self._upstream_steps.add(step)

    @property
    def source_object(self) -> Any:
        """The source object of this step.

        Returns:
            The source object of this step.
        """
        return self.__class__

    @property
    def source_code(self) -> str:
        """The source code of this step.

        Returns:
            The source code of this step.
        """
        return inspect.getsource(self.source_object)

    @property
    def docstring(self) -> Optional[str]:
        """The docstring of this step.

        Returns:
            The docstring of this step.
        """
        return self.__doc__

    @property
    def caching_parameters(self) -> Dict[str, Any]:
        """Caching parameters for this step.

        The step source code hash is always included. For every output with
        configured materializers, an MD5 digest over the source-code hashes
        of all its materializer classes is added.

        Returns:
            A dictionary containing the caching parameters.
        """
        parameters = {}
        parameters[
            STEP_SOURCE_PARAMETER_NAME
        ] = source_code_utils.get_hashed_source_code(self.source_object)

        for name, output in self.configuration.outputs.items():
            if output.materializer_source:
                key = f"{name}_materializer_source"
                hash_ = hashlib.md5()

                for source in output.materializer_source:
                    materializer_class = source_utils.load(source)
                    code_hash = source_code_utils.get_hashed_source_code(
                        materializer_class
                    )
                    hash_.update(code_hash.encode())

                parameters[key] = hash_.hexdigest()

        return parameters

    def _verify_and_apply_init_params(self, *args: Any, **kwargs: Any) -> None:
        """Verifies the initialization args and kwargs of this step.

        This method makes sure that there is only one parameters object passed
        at initialization and that it was passed using the correct name and
        type specified in the step declaration.

        Args:
            *args: The args passed to the init method of this step.
            **kwargs: The kwargs passed to the init method of this step.

        Raises:
            StepInterfaceError: If there are too many arguments or arguments
                with a wrong name/type.
        """
        # Only steps with a legacy `BaseParameters` argument accept a single
        # init argument; all other steps take none.
        maximum_arg_count = (
            1 if self.entrypoint_definition.legacy_params else 0
        )
        arg_count = len(args) + len(kwargs)
        if arg_count > maximum_arg_count:
            raise StepInterfaceError(
                f"Too many arguments ({arg_count}, expected: "
                f"{maximum_arg_count}) passed when creating a "
                f"'{self.name}' step."
            )

        if self.entrypoint_definition.legacy_params:
            if args:
                config = args[0]
            elif kwargs:
                key, config = kwargs.popitem()

                if key != self.entrypoint_definition.legacy_params.name:
                    raise StepInterfaceError(
                        f"Unknown keyword argument '{key}' when creating a "
                        f"'{self.name}' step, only expected a single "
                        "argument with key "
                        f"'{self.entrypoint_definition.legacy_params.name}'."
                    )
            else:
                # This step requires configuration parameters but no parameters
                # object was passed as an argument. The parameters might be
                # set via default values in the parameters class or in a
                # configuration file, so we continue for now and verify
                # that all parameters are set before running the step
                return

            if not isinstance(
                config, self.entrypoint_definition.legacy_params.annotation
            ):
                raise StepInterfaceError(
                    f"`{config}` object passed when creating a "
                    f"'{self.name}' step is not a "
                    f"`{self.entrypoint_definition.legacy_params.annotation.__name__}` "
                    "instance."
                )

            self.configure(parameters=config)

    def _parse_call_args(
        self, *args: Any, **kwargs: Any
    ) -> Tuple[
        Dict[str, "StepArtifact"],
        Dict[str, "ExternalArtifact"],
        Dict[str, Any],
    ]:
        """Parses the call args for the step entrypoint.

        Args:
            *args: Entrypoint function arguments.
            **kwargs: Entrypoint function keyword arguments.

        Raises:
            StepInterfaceError: If invalid function arguments were passed.

        Returns:
            The artifacts, external artifacts and parameters for the step.
        """
        signature = get_step_entrypoint_signature(step=self)

        try:
            bound_args = signature.bind_partial(*args, **kwargs)
        except TypeError as e:
            raise StepInterfaceError(
                f"Wrong arguments when calling step '{self.name}': {e}"
            ) from e

        artifacts = {}
        external_artifacts = {}
        parameters = {}

        for key, value in bound_args.arguments.items():
            self.entrypoint_definition.validate_input(key=key, value=value)

            if isinstance(value, StepArtifact):
                artifacts[key] = value
                if key in self.configuration.parameters:
                    logger.warning(
                        "Got duplicate value for step input %s, using value "
                        "provided as artifact.",
                        key,
                    )
            elif isinstance(value, ExternalArtifact):
                external_artifacts[key] = value
                if not value._id:
                    # External artifacts that reference a fixed artifact by
                    # ID keep caching working as expected; only the other
                    # kinds trigger this warning.
                    logger.warning(
                        "Using an external artifact as step input currently "
                        "invalidates caching for the step and all downstream "
                        "steps. Future releases will introduce hashing of "
                        "artifacts which will improve this behavior."
                    )
            else:
                parameters[key] = value

        # Above we iterated over the provided arguments which should overwrite
        # any parameters previously defined on the step instance. Now we apply
        # the default values on the entrypoint function and add those as
        # parameters for any argument that has no value yet. If we were to do
        # that in the above loop, we would overwrite previously configured
        # parameters with the default values.
        bound_args.apply_defaults()
        for key, value in bound_args.arguments.items():
            self.entrypoint_definition.validate_input(key=key, value=value)
            if (
                key not in artifacts
                and key not in external_artifacts
                and key not in self.configuration.parameters
            ):
                parameters[key] = value

        return artifacts, external_artifacts, parameters

    def __call__(
        self,
        *args: Any,
        id: Optional[str] = None,
        after: Union[str, Sequence[str], None] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a call of the step.

        This method does one of two things:
        * If there is an active pipeline context, it adds an invocation of the
          step instance to the pipeline.
        * If no pipeline is active, it calls the step entrypoint function.

        Args:
            *args: Entrypoint function arguments.
            id: Invocation ID to use.
            after: Upstream steps for the invocation. Either a single
                invocation ID or a sequence of invocation IDs.
            **kwargs: Entrypoint function keyword arguments.

        Returns:
            The outputs of the entrypoint function call. Inside a pipeline
            this is a single `StepArtifact` for single-output steps and a
            list of them otherwise.
        """
        from zenml.new.pipelines.pipeline import Pipeline

        if not Pipeline.ACTIVE_PIPELINE:
            # The step is being called outside of the context of a pipeline,
            # we simply call the entrypoint
            return self.call_entrypoint(*args, **kwargs)

        (
            input_artifacts,
            external_artifacts,
            parameters,
        ) = self._parse_call_args(*args, **kwargs)

        upstream_steps = {
            artifact.invocation_id for artifact in input_artifacts.values()
        }
        if isinstance(after, str):
            upstream_steps.add(after)
        elif isinstance(after, Sequence):
            # `set.union` returns a new set and would silently drop these
            # invocation IDs; `update` adds them in place.
            upstream_steps.update(after)

        invocation_id = Pipeline.ACTIVE_PIPELINE.add_step_invocation(
            step=self,
            input_artifacts=input_artifacts,
            external_artifacts=external_artifacts,
            parameters=parameters,
            upstream_steps=upstream_steps,
            custom_id=id,
            allow_id_suffix=not id,
        )

        outputs = []
        for key, annotation in self.entrypoint_definition.outputs.items():
            output = StepArtifact(
                invocation_id=invocation_id,
                output_name=key,
                annotation=annotation,
                pipeline=Pipeline.ACTIVE_PIPELINE,
            )
            outputs.append(output)

        if len(outputs) == 1:
            return outputs[0]
        else:
            return outputs

    def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any:
        """Calls the entrypoint function of the step.

        Args:
            *args: Entrypoint function arguments.
            **kwargs: Entrypoint function keyword arguments.

        Returns:
            The return value of the entrypoint function.

        Raises:
            StepInterfaceError: If the arguments to the entrypoint function are
                invalid.
        """
        try:
            validated_args = pydantic_utils.validate_function_args(
                self.entrypoint,
                {"arbitrary_types_allowed": True, "smart_union": True},
                *args,
                **kwargs,
            )
        except ValidationError as e:
            raise StepInterfaceError("Invalid entrypoint arguments.") from e

        return self.entrypoint(**validated_args)

    @property
    def name(self) -> str:
        """The name of the step.

        Returns:
            The name of the step.
        """
        return self.configuration.name

    @property
    def enable_cache(self) -> Optional[bool]:
        """If caching is enabled for the step.

        Returns:
            If caching is enabled for the step.
        """
        return self.configuration.enable_cache

    @property
    def configuration(self) -> PartialStepConfiguration:
        """The configuration of the step.

        Returns:
            The configuration of the step.
        """
        return self._configuration

    def configure(
        self: T,
        name: Optional[str] = None,
        enable_cache: Optional[bool] = None,
        enable_artifact_metadata: Optional[bool] = None,
        enable_artifact_visualization: Optional[bool] = None,
        experiment_tracker: Optional[str] = None,
        step_operator: Optional[str] = None,
        parameters: Optional["ParametersOrDict"] = None,
        output_materializers: Optional[
            "OutputMaterializersSpecification"
        ] = None,
        settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
        extra: Optional[Dict[str, Any]] = None,
        on_failure: Optional["HookSpecification"] = None,
        on_success: Optional["HookSpecification"] = None,
        merge: bool = True,
    ) -> T:
        """Configures the step.

        Configuration merging example:
        * `merge==True`:
            step.configure(extra={"key1": 1})
            step.configure(extra={"key2": 2}, merge=True)
            step.configuration.extra # {"key1": 1, "key2": 2}
        * `merge==False`:
            step.configure(extra={"key1": 1})
            step.configure(extra={"key2": 2}, merge=False)
            step.configuration.extra # {"key2": 2}

        Args:
            name: DEPRECATED: The name of the step. Passing it only logs a
                warning; the value is not applied.
            enable_cache: If caching should be enabled for this step.
            enable_artifact_metadata: If artifact metadata should be enabled
                for this step.
            enable_artifact_visualization: If artifact visualization should be
                enabled for this step.
            experiment_tracker: The experiment tracker to use for this step.
            step_operator: The step operator to use for this step.
            parameters: Function parameters for this step.
            output_materializers: Output materializers for this step. If
                given as a dict, the keys must be a subset of the output names
                of this step. If a single value (type or string) is given, the
                materializer will be used for all outputs.
            settings: Settings for this step.
            extra: Extra configurations for this step.
            on_failure: Callback function in event of failure of the step. Can
                be a function with three possible parameters, `StepContext`,
                `BaseParameters`, and `BaseException`, or a source path to a
                function of the same specifications (e.g. `module.my_function`).
            on_success: Callback function in event of success of the step. Can
                be a function with two possible parameters, `StepContext` and
                `BaseParameters`, or a source path to a function of the same
                specifications (e.g. `module.my_function`).
            merge: If `True`, will merge the given dictionary configurations
                like `parameters` and `settings` with existing
                configurations. If `False` the given configurations will
                overwrite all existing ones. See the general description of this
                method for an example.

        Returns:
            The step instance that this method was called on.
        """
        from zenml.hooks.hook_validators import resolve_and_validate_hook

        if name:
            logger.warning("Configuring the name of a step is deprecated.")

        def _resolve_if_necessary(
            value: Union[str, Source, Type[Any]]
        ) -> Source:
            # Strings are import paths, `Source`s pass through, anything else
            # (a class) is resolved to its source.
            if isinstance(value, str):
                return Source.from_import_path(value)
            elif isinstance(value, Source):
                return value
            else:
                return source_utils.resolve(value)

        def _convert_to_tuple(value: Any) -> Tuple[Source, ...]:
            # A string is a sequence too, so it must be treated as a single
            # value explicitly.
            if isinstance(value, str) or not isinstance(value, Sequence):
                return (_resolve_if_necessary(value),)
            else:
                return tuple(_resolve_if_necessary(v) for v in value)

        outputs: Dict[str, Dict[str, Tuple[Source, ...]]] = defaultdict(dict)
        allowed_output_names = set(self.entrypoint_definition.outputs)

        if output_materializers:
            if not isinstance(output_materializers, Mapping):
                # A single materializer spec applies to every output.
                sources = _convert_to_tuple(output_materializers)
                output_materializers = {
                    output_name: sources
                    for output_name in allowed_output_names
                }

            for output_name, materializer in output_materializers.items():
                sources = _convert_to_tuple(materializer)
                outputs[output_name]["materializer_source"] = sources

        failure_hook_source = None
        if on_failure:
            # string of on_failure hook function to be used for this step
            failure_hook_source = resolve_and_validate_hook(on_failure)

        success_hook_source = None
        if on_success:
            # string of on_success hook function to be used for this step
            success_hook_source = resolve_and_validate_hook(on_success)

        if isinstance(parameters, BaseParameters):
            parameters = parameters.dict()

        values = dict_utils.remove_none_values(
            {
                "enable_cache": enable_cache,
                "enable_artifact_metadata": enable_artifact_metadata,
                "enable_artifact_visualization": enable_artifact_visualization,
                "experiment_tracker": experiment_tracker,
                "step_operator": step_operator,
                "parameters": parameters,
                "settings": settings,
                "outputs": outputs or None,
                "extra": extra,
                "failure_hook_source": failure_hook_source,
                "success_hook_source": success_hook_source,
            }
        )
        config = StepConfigurationUpdate(**values)
        self._apply_configuration(config, merge=merge)
        return self

    def with_options(
        self,
        enable_cache: Optional[bool] = None,
        enable_artifact_metadata: Optional[bool] = None,
        enable_artifact_visualization: Optional[bool] = None,
        experiment_tracker: Optional[str] = None,
        step_operator: Optional[str] = None,
        parameters: Optional["ParametersOrDict"] = None,
        output_materializers: Optional[
            "OutputMaterializersSpecification"
        ] = None,
        settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
        extra: Optional[Dict[str, Any]] = None,
        on_failure: Optional["HookSpecification"] = None,
        on_success: Optional["HookSpecification"] = None,
        merge: bool = True,
    ) -> "BaseStep":
        """Copies the step and applies the given configurations.

        Args:
            enable_cache: If caching should be enabled for this step.
            enable_artifact_metadata: If artifact metadata should be enabled
                for this step.
            enable_artifact_visualization: If artifact visualization should be
                enabled for this step.
            experiment_tracker: The experiment tracker to use for this step.
            step_operator: The step operator to use for this step.
            parameters: Function parameters for this step.
            output_materializers: Output materializers for this step. If
                given as a dict, the keys must be a subset of the output names
                of this step. If a single value (type or string) is given, the
                materializer will be used for all outputs.
            settings: Settings for this step.
            extra: Extra configurations for this step.
            on_failure: Callback function in event of failure of the step. Can
                be a function with three possible parameters, `StepContext`,
                `BaseParameters`, and `BaseException`, or a source path to a
                function of the same specifications (e.g. `module.my_function`).
            on_success: Callback function in event of success of the step. Can
                be a function with two possible parameters, `StepContext` and
                `BaseParameters`, or a source path to a function of the same
                specifications (e.g. `module.my_function`).
            merge: If `True`, will merge the given dictionary configurations
                like `parameters` and `settings` with existing
                configurations. If `False` the given configurations will
                overwrite all existing ones. See the description of
                `BaseStep.configure(...)` for an example.

        Returns:
            The copied step instance.
        """
        step_copy = self.copy()
        step_copy.configure(
            enable_cache=enable_cache,
            enable_artifact_metadata=enable_artifact_metadata,
            enable_artifact_visualization=enable_artifact_visualization,
            experiment_tracker=experiment_tracker,
            step_operator=step_operator,
            parameters=parameters,
            output_materializers=output_materializers,
            settings=settings,
            extra=extra,
            on_failure=on_failure,
            on_success=on_success,
            merge=merge,
        )
        return step_copy

    def copy(self) -> "BaseStep":
        """Copies the step.

        Returns:
            A deep copy of the step.
        """
        return copy.deepcopy(self)

    def _apply_configuration(
        self,
        config: StepConfigurationUpdate,
        merge: bool = True,
    ) -> None:
        """Applies an update to the step configuration.

        Args:
            config: The configuration update.
            merge: Whether to merge the updates with the existing configuration
                or not. See the `BaseStep.configure(...)` method for a detailed
                explanation.
        """
        self._validate_configuration(config)

        self._configuration = pydantic_utils.update_model(
            self._configuration, update=config, recursive=merge
        )

        logger.debug("Updated step configuration:")
        logger.debug(self._configuration)

    def _validate_configuration(self, config: StepConfigurationUpdate) -> None:
        """Validates a configuration update.

        Args:
            config: The configuration update to validate.
        """
        settings_utils.validate_setting_keys(list(config.settings))
        self._validate_function_parameters(parameters=config.parameters)
        self._validate_outputs(outputs=config.outputs)

    def _validate_function_parameters(
        self, parameters: Dict[str, Any]
    ) -> None:
        """Validates step function parameters.

        Args:
            parameters: The parameters to validate.

        Raises:
            StepInterfaceError: If the step requires no function parameters but
                parameters were configured.
        """
        if not parameters:
            return

        for key, value in parameters.items():
            if key in self.entrypoint_definition.inputs:
                self.entrypoint_definition.validate_input(key=key, value=value)
            elif not self.entrypoint_definition.legacy_params:
                # Unknown keys are only tolerated when a legacy parameters
                # class may consume them.
                raise StepInterfaceError(
                    "Can't set parameter without param class."
                )

    def _validate_outputs(
        self, outputs: Mapping[str, PartialArtifactConfiguration]
    ) -> None:
        """Validates the step output configuration.

        Args:
            outputs: The configured step outputs.

        Raises:
            StepInterfaceError: If an output for a non-existent name is
                configured or an output artifact/materializer source does not
                resolve to the correct class.
        """
        allowed_output_names = set(self.entrypoint_definition.outputs)
        for output_name, output in outputs.items():
            if output_name not in allowed_output_names:
                raise StepInterfaceError(
                    f"Got unexpected materializers for non-existent "
                    f"output '{output_name}' in step '{self.name}'. "
                    f"Only materializers for the outputs "
                    f"{allowed_output_names} of this step can"
                    f" be registered."
                )

            if output.materializer_source:
                for source in output.materializer_source:
                    if not source_utils.validate_source_class(
                        source, expected_class=BaseMaterializer
                    ):
                        raise StepInterfaceError(
                            f"Materializer source `{source}` "
                            f"for output '{output_name}' of step '{self.name}' "
                            "does not resolve to a `BaseMaterializer` subclass."
                        )

    def _validate_inputs(
        self,
        input_artifacts: Dict[str, "StepArtifact"],
        external_artifacts: Dict[str, UUID],
    ) -> None:
        """Validates the step inputs.

        This method makes sure that all inputs are provided either as an
        artifact or parameter.

        Args:
            input_artifacts: The input artifacts.
            external_artifacts: The external input artifacts.

        Raises:
            StepInterfaceError: If an entrypoint input is missing.
        """
        for key in self.entrypoint_definition.inputs.keys():
            if (
                key in input_artifacts
                or key in self.configuration.parameters
                or key in external_artifacts
            ):
                continue
            raise StepInterfaceError(f"Missing entrypoint input {key}.")

    def _finalize_configuration(
        self,
        input_artifacts: Dict[str, "StepArtifact"],
        external_artifacts: Dict[str, UUID],
    ) -> StepConfiguration:
        """Finalizes the configuration after the step was called.

        Once the step was called, we know the outputs of previous steps
        and that no additional user configurations will be made. That means
        we can now collect the remaining artifact and materializer types
        as well as check for the completeness of the step function parameters.

        Args:
            input_artifacts: The input artifacts of this step.
            external_artifacts: The external artifacts of this step.

        Returns:
            The finalized step configuration.
        """
        # Loop-invariant imports, previously repeated on every iteration.
        from pydantic.typing import (
            get_origin,
            is_none_type,
            is_union,
        )

        from zenml.materializers import CloudpickleMaterializer
        from zenml.steps.utils import get_args

        outputs: Dict[str, Dict[str, Tuple[Source, ...]]] = defaultdict(dict)

        for (
            output_name,
            output_annotation,
        ) in self.entrypoint_definition.outputs.items():
            output = self._configuration.outputs.get(
                output_name, PartialArtifactConfiguration()
            )

            if output.materializer_source:
                # Explicitly configured materializers are kept as they are.
                continue

            if output_annotation is Any:
                logger.warning(
                    f"No materializer specified for output with `Any` type "
                    f"annotation (output {output_name} of step {self.name}"
                    "). The Cloudpickle materializer will be used for the "
                    "artifact but the artifact won't be readable in "
                    "different Python versions. Please consider specifying "
                    "an explicit materializer for this output by following "
                    "this guide: https://docs.zenml.io/advanced-guide/pipelines/materializers."
                )
                outputs[output_name]["materializer_source"] = (
                    source_utils.resolve(CloudpickleMaterializer),
                )
                # Without this, the registry lookup below would overwrite
                # the Cloudpickle materializer announced in the warning.
                continue

            if is_union(get_origin(output_annotation) or output_annotation):
                # One materializer per union member; `None` members are
                # normalized to `NoneType` for the registry lookup.
                output_types = tuple(
                    type(None) if is_none_type(output_type) else output_type
                    for output_type in get_args(output_annotation)
                )
            else:
                output_types = (output_annotation,)

            materializer_sources = []
            for output_type in output_types:
                materializer_class = materializer_registry[output_type]
                materializer_sources.append(
                    source_utils.resolve(materializer_class)
                )

            outputs[output_name]["materializer_source"] = tuple(
                materializer_sources
            )

        parameters = self._finalize_parameters()
        self.configure(parameters=parameters, merge=False)
        self._validate_inputs(
            input_artifacts=input_artifacts,
            external_artifacts=external_artifacts,
        )

        values = dict_utils.remove_none_values({"outputs": outputs or None})
        config = StepConfigurationUpdate(**values)
        self._apply_configuration(config)

        self._configuration = self._configuration.copy(
            update={
                "caching_parameters": self.caching_parameters,
                "external_input_artifacts": external_artifacts,
            }
        )

        complete_configuration = StepConfiguration.parse_obj(
            self._configuration
        )
        return complete_configuration

    def _finalize_parameters(self) -> Dict[str, Any]:
        """Finalizes the config parameters for running this step.

        Returns:
            All parameter values for running this step.
        """
        params = {}
        for key, value in self.configuration.parameters.items():
            if key not in self.entrypoint_definition.inputs:
                continue

            annotation = self.entrypoint_definition.inputs[key].annotation
            if inspect.isclass(annotation) and issubclass(
                annotation, BaseModel
            ):
                # Make sure we have all necessary values to instantiate the
                # pydantic model later
                model = annotation(**value)
                params[key] = model.dict()
            else:
                params[key] = value

        if self.entrypoint_definition.legacy_params:
            legacy_params = self._finalize_legacy_parameters()
            params[
                self.entrypoint_definition.legacy_params.name
            ] = legacy_params

        return params

    def _finalize_legacy_parameters(self) -> Dict[str, Any]:
        """Verifies and prepares the config parameters for running this step.

        When the step requires config parameters, this method:
            - checks if config parameters were set via a config object or file
            - tries to set missing config parameters from default values of the
              config class

        Returns:
            Values for the previously unconfigured function parameters.

        Raises:
            MissingStepParameterError: If no value could be found for one or
                more config parameters.
            StepInterfaceError: If the parameter class validation failed.
        """
        if not self.entrypoint_definition.legacy_params:
            return {}

        logger.warning(
            "The `BaseParameters` class to define step parameters is "
            "deprecated. Check out our docs "
            "https://docs.zenml.io/user-guide/advanced-guide/configure-steps-pipelines "
            "for information on how to parameterize your steps. As a quick "
            "fix to get rid of this warning, make sure your parameter class "
            "inherits from `pydantic.BaseModel` instead of the "
            "`BaseParameters` class."
        )

        # parameters for the `BaseParameters` class specified in the "new" way
        # by specifying a dict of parameters for the corresponding key
        params_defined_in_new_way = (
            self.configuration.parameters.get(
                self.entrypoint_definition.legacy_params.name
            )
            or {}
        )

        values = {}
        missing_keys = []
        for (
            name,
            field,
        ) in (
            self.entrypoint_definition.legacy_params.annotation.__fields__.items()
        ):
            if name in self.configuration.parameters:
                # a value for this parameter has been set already
                values[name] = self.configuration.parameters[name]
            elif name in params_defined_in_new_way:
                # a value for this parameter has been set in the "new" way
                # already
                values[name] = params_defined_in_new_way[name]
            elif field.required:
                # this field has no default value set and therefore needs
                # to be passed via an initialized config object
                missing_keys.append(name)
            else:
                # use default value from the pydantic config class
                values[name] = field.default

        if missing_keys:
            raise MissingStepParameterError(
                self.name,
                missing_keys,
                self.entrypoint_definition.legacy_params.annotation,
            )

        if (
            self.entrypoint_definition.legacy_params.annotation.__config__.extra
            == Extra.allow
        ):
            # Add all parameters for the config class for backwards
            # compatibility if the config class allows extra attributes
            values.update(self.configuration.parameters)

        try:
            self.entrypoint_definition.legacy_params.annotation(**values)
        except ValidationError as e:
            # Chain the pydantic error so the failing field stays visible.
            raise StepInterfaceError(
                "Failed to validate function parameters."
            ) from e

        return values
caching_parameters: Dict[str, Any]
property
readonly
Caching parameters for this step.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the caching parameters |
configuration: PartialStepConfiguration
property
readonly
The configuration of the step.
Returns:
Type | Description |
---|---|
PartialStepConfiguration |
The configuration of the step. |
docstring: Optional[str]
property
readonly
The docstring of this step.
Returns:
Type | Description |
---|---|
Optional[str] |
The docstring of this step. |
enable_cache: Optional[bool]
property
readonly
If caching is enabled for the step.
Returns:
Type | Description |
---|---|
Optional[bool] |
If caching is enabled for the step. |
name: str
property
readonly
The name of the step.
Returns:
Type | Description |
---|---|
str |
The name of the step. |
source_code: str
property
readonly
The source code of this step.
Returns:
Type | Description |
---|---|
str |
The source code of this step. |
source_object: Any
property
readonly
The source object of this step.
Returns:
Type | Description |
---|---|
Any |
The source object of this step. |
upstream_steps: Set[BaseStep]
property
readonly
The upstream steps of this step.
This property will only contain the full set of upstream steps once
its parent pipeline's `connect(...)` method was called.
Returns:
Type | Description |
---|---|
Set[BaseStep] |
Set of upstream steps. |
__call__(self, *args, id=None, after=None, **kwargs)
special
Handle a call of the step.
This method does one of two things:
* If there is an active pipeline context, it adds an invocation of the step instance to the pipeline.
* If no pipeline is active, it calls the step entrypoint function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Entrypoint function arguments. |
() |
id |
Optional[str] |
Invocation ID to use. |
None |
after |
Union[str, Sequence[str]] |
Upstream steps for the invocation. |
None |
**kwargs |
Any |
Entrypoint function keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
Any |
The outputs of the entrypoint function call. |
Source code in zenml/steps/base_step.py
def __call__(
    self,
    *args: Any,
    id: Optional[str] = None,
    after: Union[str, Sequence[str], None] = None,
    **kwargs: Any,
) -> Any:
    """Handle a call of the step.

    This method does one of two things:
    * If there is an active pipeline context, it adds an invocation of the
      step instance to the pipeline.
    * If no pipeline is active, it calls the step entrypoint function.

    Args:
        *args: Entrypoint function arguments.
        id: Invocation ID to use. If not given, the pipeline is allowed to
            derive an ID (with a suffix if needed).
        after: Upstream steps for the invocation. Either a single invocation
            ID or a sequence of invocation IDs.
        **kwargs: Entrypoint function keyword arguments.

    Returns:
        The outputs of the entrypoint function call. Outside of a pipeline
        this is the entrypoint's return value. Inside a pipeline context it
        is a single `StepArtifact` if the step has exactly one output, and
        otherwise a list of `StepArtifact`s (empty if there are no outputs).
    """
    from zenml.new.pipelines.pipeline import Pipeline

    if not Pipeline.ACTIVE_PIPELINE:
        # The step is being called outside of the context of a pipeline,
        # we simply call the entrypoint
        return self.call_entrypoint(*args, **kwargs)

    (
        input_artifacts,
        external_artifacts,
        parameters,
    ) = self._parse_call_args(*args, **kwargs)

    # Every invocation that produced one of our input artifacts is an
    # implicit upstream dependency.
    upstream_steps = {
        artifact.invocation_id for artifact in input_artifacts.values()
    }
    if isinstance(after, str):
        upstream_steps.add(after)
    elif isinstance(after, Sequence):
        # `set.union` returns a new set and would silently drop these
        # explicit dependencies; `update` adds them in place.
        upstream_steps.update(after)

    invocation_id = Pipeline.ACTIVE_PIPELINE.add_step_invocation(
        step=self,
        input_artifacts=input_artifacts,
        external_artifacts=external_artifacts,
        parameters=parameters,
        upstream_steps=upstream_steps,
        custom_id=id,
        allow_id_suffix=not id,
    )

    outputs = []
    for key, annotation in self.entrypoint_definition.outputs.items():
        output = StepArtifact(
            invocation_id=invocation_id,
            output_name=key,
            annotation=annotation,
            pipeline=Pipeline.ACTIVE_PIPELINE,
        )
        outputs.append(output)

    if len(outputs) == 1:
        return outputs[0]
    else:
        return outputs
__init__(self, *args, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, **kwargs)
special
Initializes a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments passed to the step. |
() |
name |
Optional[str] |
The name of the step. |
None |
enable_cache |
Optional[bool] |
If caching should be enabled for this step. |
None |
enable_artifact_metadata |
Optional[bool] |
If artifact metadata should be enabled for this step. |
None |
enable_artifact_visualization |
Optional[bool] |
If artifact visualization should be enabled for this step. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
parameters |
Optional[ParametersOrDict] |
Function parameters for this step |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Mapping[str, SettingsOrDict]] |
settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
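Callback function in event of failure of the step. Can be a function with three possible parameters, `StepContext`, `BaseParameters`, and `BaseException`, or a source path to a function of the same specifications (e.g. `module.my_function`). |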
None |
on_success |
Optional[HookSpecification] |
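Callback function in event of success of the step. Can be a function with two possible parameters, `StepContext` and `BaseParameters`, or a source path to a function of the same specifications (e.g. `module.my_function`). |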
None |
**kwargs |
Any |
Keyword arguments passed to the step. |
{} |
Source code in zenml/steps/base_step.py
def __init__(
    self,
    *args: Any,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    **kwargs: Any,
) -> None:
    """Initializes a step.

    Args:
        *args: Positional arguments passed to the step.
        name: The name of the step. Defaults to the class name.
        enable_cache: If caching should be enabled for this step. If not
            given and the entrypoint requires a `StepContext`, caching is
            disabled.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with three possible parameters, `StepContext`,
            `BaseParameters`, and `BaseException`, or a source path to a
            function of the same specifications (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with two possible parameters, `StepContext` and
            `BaseParameters`, or a source path to a function of the same
            specifications (e.g. `module.my_function`).
        **kwargs: Keyword arguments passed to the step.
    """
    self._upstream_steps: Set["BaseStep"] = set()
    # `after` and `id` are keyword arguments consumed by `__call__` itself;
    # presumably the validator rejects entrypoint inputs with those names.
    self.entrypoint_definition = validate_entrypoint_function(
        self.entrypoint, reserved_arguments=["after", "id"]
    )

    name = name or self.__class__.__name__
    requires_context = self.entrypoint_definition.context is not None
    if enable_cache is None:
        if requires_context:
            # Using the StepContext inside a step provides access to
            # external resources which might influence the step execution.
            # We therefore disable caching unless it is explicitly enabled
            enable_cache = False
            logger.debug(
                "Step '%s': Step context required and caching not "
                "explicitly enabled.",
                name,
            )

    # `None` means "not explicitly disabled" and is reported as enabled.
    logger.debug(
        "Step '%s': Caching %s.",
        name,
        "enabled" if enable_cache is not False else "disabled",
    )
    logger.debug(
        "Step '%s': Artifact metadata %s.",
        name,
        "enabled" if enable_artifact_metadata is not False else "disabled",
    )
    logger.debug(
        "Step '%s': Artifact visualization %s.",
        name,
        "enabled"
        if enable_artifact_visualization is not False
        else "disabled",
    )

    self._configuration = PartialStepConfiguration(
        name=name,
        enable_cache=enable_cache,
        enable_artifact_metadata=enable_artifact_metadata,
        enable_artifact_visualization=enable_artifact_visualization,
    )
    self.configure(
        experiment_tracker=experiment_tracker,
        step_operator=step_operator,
        output_materializers=output_materializers,
        parameters=parameters,
        settings=settings,
        extra=extra,
        on_failure=on_failure,
        on_success=on_success,
    )
    self._verify_and_apply_init_params(*args, **kwargs)
after(self, step)
Adds an upstream step to this step.
Calling this method makes sure this step only starts running once the given step has successfully finished executing.
Note: This can only be called inside the pipeline connect function
which is decorated with the @pipeline
decorator. Any calls outside
this function will be ignored.
Examples:
The following pipeline will run its steps sequentially in the following order: step_2 -> step_1 -> step_3
@pipeline
def example_pipeline(step_1, step_2, step_3):
step_1.after(step_2)
step_3(step_1(), step_2())
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
BaseStep |
A step which should finish executing before this step is started. |
required |
Source code in zenml/steps/base_step.py
def after(self, step: "BaseStep") -> None:
    """Register `step` as an upstream dependency of this step.

    Once registered, this step only starts running after `step` has
    finished executing successfully.

    **Note**: This can only be called inside the pipeline connect function
    which is decorated with the `@pipeline` decorator. Any calls outside
    this function will be ignored.

    Example:
        The following pipeline will run its steps sequentially in the
        following order: step_2 -> step_1 -> step_3

        ```python
        @pipeline
        def example_pipeline(step_1, step_2, step_3):
            step_1.after(step_2)
            step_3(step_1(), step_2())
        ```

    Args:
        step: A step which should finish executing before this step is
            started.
    """
    upstream = self._upstream_steps
    upstream.add(step)
call_entrypoint(self, *args, **kwargs)
Calls the entrypoint function of the step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Entrypoint function arguments. |
() |
**kwargs |
Any |
Entrypoint function keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
Any |
The return value of the entrypoint function. |
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
If the arguments to the entrypoint function are invalid. |
Source code in zenml/steps/base_step.py
def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Validate the given arguments and invoke the step entrypoint.

    Args:
        *args: Entrypoint function arguments.
        **kwargs: Entrypoint function keyword arguments.

    Returns:
        Whatever the entrypoint function returns.

    Raises:
        StepInterfaceError: If pydantic rejects the arguments for the
            entrypoint function.
    """
    validation_config = {"arbitrary_types_allowed": True, "smart_union": True}
    try:
        bound_arguments = pydantic_utils.validate_function_args(
            self.entrypoint, validation_config, *args, **kwargs
        )
    except ValidationError as error:
        raise StepInterfaceError("Invalid entrypoint arguments.") from error

    return self.entrypoint(**bound_arguments)
configure(self, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, merge=True)
Configures the step.
Configuration merging example:
* `merge==True`:
step.configure(extra={"key1": 1})
step.configure(extra={"key2": 2}, merge=True)
step.configuration.extra # {"key1": 1, "key2": 2}
* `merge==False`:
step.configure(extra={"key1": 1})
step.configure(extra={"key2": 2}, merge=False)
step.configuration.extra # {"key2": 2}
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
Optional[str] |
DEPRECATED: The name of the step. |
None |
enable_cache |
Optional[bool] |
If caching should be enabled for this step. |
None |
enable_artifact_metadata |
Optional[bool] |
If artifact metadata should be enabled for this step. |
None |
enable_artifact_visualization |
Optional[bool] |
If artifact visualization should be enabled for this step. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
parameters |
Optional[ParametersOrDict] |
Function parameters for this step |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Mapping[str, SettingsOrDict]] |
settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
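Callback function in event of failure of the step. Can be a function with three possible parameters, `StepContext`, `BaseParameters`, and `BaseException`, or a source path to a function of the same specifications (e.g. `module.my_function`). |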
None |
on_success |
Optional[HookSpecification] |
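Callback function in event of success of the step. Can be a function with two possible parameters, `StepContext` and `BaseParameters`, or a source path to a function of the same specifications (e.g. `module.my_function`). |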
None |
merge |
bool |
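If `True`, will merge the given dictionary configurations like `parameters` and `settings` with existing configurations. If `False` the given configurations will overwrite all existing ones. |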
True |
Returns:
Type | Description |
---|---|
~T |
The step instance that this method was called on. |
Source code in zenml/steps/base_step.py
def configure(
    self: T,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    merge: bool = True,
) -> T:
    """Configures the step.

    Configuration merging example:
    * `merge==True`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=True)
        step.configuration.extra # {"key1": 1, "key2": 2}
    * `merge==False`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=False)
        step.configuration.extra # {"key2": 2}

    Args:
        name: DEPRECATED: The name of the step. Passing a value only logs a
            deprecation warning; it is not applied to the configuration.
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with three possible parameters, `StepContext`,
            `BaseParameters`, and `BaseException`, or a source path to a
            function of the same specifications (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with two possible parameters, `StepContext` and
            `BaseParameters`, or a source path to a function of the same
            specifications (e.g. `module.my_function`).
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False` the given configurations will
            overwrite all existing ones. See the general description of this
            method for an example.

    Returns:
        The step instance that this method was called on.
    """
    from zenml.hooks.hook_validators import resolve_and_validate_hook

    if name:
        logger.warning("Configuring the name of a step is deprecated.")

    def _resolve_if_necessary(
        value: Union[str, Source, Type[Any]]
    ) -> Source:
        # Strings are import paths, `Source`s are used as-is and anything
        # else (e.g. a materializer class) is resolved to its source.
        if isinstance(value, str):
            return Source.from_import_path(value)
        elif isinstance(value, Source):
            return value
        else:
            return source_utils.resolve(value)

    def _convert_to_tuple(value: Any) -> Tuple[Source, ...]:
        # A `str` is itself a `Sequence`, so it is special-cased to be
        # treated as one value rather than a sequence of characters.
        if isinstance(value, str) or not isinstance(value, Sequence):
            return (_resolve_if_necessary(value),)
        else:
            return tuple(_resolve_if_necessary(v) for v in value)

    outputs: Dict[str, Dict[str, Tuple[Source, ...]]] = defaultdict(dict)
    allowed_output_names = set(self.entrypoint_definition.outputs)

    if output_materializers:
        if not isinstance(output_materializers, Mapping):
            # A single (non-mapping) specification applies to every output
            # declared by the entrypoint.
            sources = _convert_to_tuple(output_materializers)
            output_materializers = {
                output_name: sources
                for output_name in allowed_output_names
            }

        for output_name, materializer in output_materializers.items():
            sources = _convert_to_tuple(materializer)
            outputs[output_name]["materializer_source"] = sources

    failure_hook_source = None
    if on_failure:
        # string of on_failure hook function to be used for this step
        failure_hook_source = resolve_and_validate_hook(on_failure)

    success_hook_source = None
    if on_success:
        # string of on_success hook function to be used for this step
        success_hook_source = resolve_and_validate_hook(on_success)

    if isinstance(parameters, BaseParameters):
        parameters = parameters.dict()

    # `outputs or None` turns an empty mapping into `None`, presumably so
    # that `remove_none_values` drops it like any other unset option.
    values = dict_utils.remove_none_values(
        {
            "enable_cache": enable_cache,
            "enable_artifact_metadata": enable_artifact_metadata,
            "enable_artifact_visualization": enable_artifact_visualization,
            "experiment_tracker": experiment_tracker,
            "step_operator": step_operator,
            "parameters": parameters,
            "settings": settings,
            "outputs": outputs or None,
            "extra": extra,
            "failure_hook_source": failure_hook_source,
            "success_hook_source": success_hook_source,
        }
    )
    config = StepConfigurationUpdate(**values)
    self._apply_configuration(config, merge=merge)
    return self
copy(self)
Copies the step.
Returns:
Type | Description |
---|---|
BaseStep |
The step copy. |
Source code in zenml/steps/base_step.py
def copy(self) -> "BaseStep":
    """Create an independent deep copy of this step.

    Returns:
        The step copy.
    """
    duplicate = copy.deepcopy(self)
    return duplicate
entrypoint(self, *args, **kwargs)
Abstract method for core step logic.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments passed to the step. |
() |
**kwargs |
Any |
Keyword arguments passed to the step. |
{} |
Returns:
Type | Description |
---|---|
Any |
The output of the step. |
Source code in zenml/steps/base_step.py
@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Core logic of the step; must be implemented by subclasses.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.

    Returns:
        The output of the step.
    """
load_from_source(source)
classmethod
Loads a step from source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
Union[zenml.config.source.Source, str] |
The path to the step source. |
required |
Returns:
Type | Description |
---|---|
BaseStep |
The loaded step. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the source is not a valid step source. |
Source code in zenml/steps/base_step.py
@classmethod
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
    """Load a step instance from a source.

    The source may point either to a step instance, which is returned
    unchanged, or to a `BaseStep` subclass, which is instantiated without
    arguments.

    Args:
        source: The path to the step source.

    Returns:
        The loaded step.

    Raises:
        ValueError: If the source is neither a step instance nor a step
            class.
    """
    loaded = source_utils.load(source)

    if isinstance(loaded, BaseStep):
        return loaded

    is_step_class = isinstance(loaded, type) and issubclass(loaded, BaseStep)
    if not is_step_class:
        raise ValueError("Invalid step source.")

    return loaded()
resolve(self)
Resolves the step.
Returns:
Type | Description |
---|---|
Source |
The step source. |
Source code in zenml/steps/base_step.py
def resolve(self) -> Source:
    """Resolve the source of this step's class.

    Returns:
        The step source.
    """
    step_class = self.__class__
    return source_utils.resolve(step_class)
with_options(self, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, merge=True)
Copies the step and applies the given configurations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
enable_cache |
Optional[bool] |
If caching should be enabled for this step. |
None |
enable_artifact_metadata |
Optional[bool] |
If artifact metadata should be enabled for this step. |
None |
enable_artifact_visualization |
Optional[bool] |
If artifact visualization should be enabled for this step. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
parameters |
Optional[ParametersOrDict] |
Function parameters for this step |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Mapping[str, SettingsOrDict]] |
settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
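Callback function in event of failure of the step. Can be a function with three possible parameters, `StepContext`, `BaseParameters`, and `BaseException`, or a source path to a function of the same specifications (e.g. `module.my_function`). |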
None |
on_success |
Optional[HookSpecification] |
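Callback function in event of success of the step. Can be a function with two possible parameters, `StepContext` and `BaseParameters`, or a source path to a function of the same specifications (e.g. `module.my_function`). |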
None |
merge |
bool |
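If `True`, will merge the given dictionary configurations like `parameters` and `settings` with existing configurations. If `False` the given configurations will overwrite all existing ones. |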
True |
Returns:
Type | Description |
---|---|
BaseStep |
The copied step instance. |
Source code in zenml/steps/base_step.py
def with_options(
    self,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    merge: bool = True,
) -> "BaseStep":
    """Copies the step and applies the given configurations.

    The step this method is called on is left unchanged; only the copy is
    configured.

    Args:
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with three possible parameters, `StepContext`,
            `BaseParameters`, and `BaseException`, or a source path to a
            function of the same specifications (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with two possible parameters, `StepContext` and
            `BaseParameters`, or a source path to a function of the same
            specifications (e.g. `module.my_function`).
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False` the given configurations will
            overwrite all existing ones. See the general description of this
            method for an example.

    Returns:
        The copied step instance.
    """
    step_copy = self.copy()
    step_copy.configure(
        enable_cache=enable_cache,
        enable_artifact_metadata=enable_artifact_metadata,
        enable_artifact_visualization=enable_artifact_visualization,
        experiment_tracker=experiment_tracker,
        step_operator=step_operator,
        parameters=parameters,
        output_materializers=output_materializers,
        settings=settings,
        extra=extra,
        on_failure=on_failure,
        on_success=on_success,
        merge=merge,
    )
    return step_copy
BaseStepMeta (type)
Metaclass for `BaseStep`.
Makes sure that the entrypoint function has valid parameters and type annotations.
Source code in zenml/steps/base_step.py
class BaseStepMeta(type):
    """Metaclass for `BaseStep`.

    Validates the entrypoint function (parameters and type annotations) of
    every step class at class-creation time.
    """

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "BaseStepMeta":
        """Create the class and validate its entrypoint.

        Args:
            name: The name of the class.
            bases: The base classes of the class.
            dct: The attributes of the class.

        Returns:
            The new class.
        """
        new_class = cast(
            Type["BaseStep"], super().__new__(mcs, name, bases, dct)
        )
        # `BaseStep` only declares an abstract entrypoint; `_DecoratedStep`
        # is exempted as well (presumably the base used by `@step`).
        exempt = name in {"BaseStep", "_DecoratedStep"}
        if not exempt:
            validate_entrypoint_function(new_class.entrypoint)
        return new_class
__new__(mcs, name, bases, dct)
special
staticmethod
Set up a new class with a qualified spec.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the class. |
required |
bases |
Tuple[Type[Any], ...] |
The base classes of the class. |
required |
dct |
Dict[str, Any] |
The attributes of the class. |
required |
Returns:
Type | Description |
---|---|
BaseStepMeta |
The new class. |
Source code in zenml/steps/base_step.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseStepMeta":
    """Create the class and validate its entrypoint.

    Args:
        name: The name of the class.
        bases: The base classes of the class.
        dct: The attributes of the class.

    Returns:
        The new class.
    """
    new_class = cast(Type["BaseStep"], super().__new__(mcs, name, bases, dct))
    # `BaseStep` only declares an abstract entrypoint; `_DecoratedStep` is
    # exempted as well (presumably the base used by `@step`).
    exempt = name in {"BaseStep", "_DecoratedStep"}
    if not exempt:
        validate_entrypoint_function(new_class.entrypoint)
    return new_class
entrypoint_function_utils
Util functions for step and pipeline entrypoint functions.
EntrypointFunctionDefinition (tuple)
Class representing a step entrypoint function.
Attributes:
Name | Type | Description |
---|---|---|
inputs |
Dict[str, inspect.Parameter] |
The entrypoint function inputs. |
outputs |
Dict[str, Any] |
The entrypoint function outputs. This dictionary maps output names to output annotations. |
context |
Optional[inspect.Parameter] |
Optional parameter representing the `StepContext` input. |
legacy_params |
Optional[inspect.Parameter] |
Optional parameter representing the `BaseParameters` input. |
Source code in zenml/steps/entrypoint_function_utils.py
class EntrypointFunctionDefinition(NamedTuple):
    """Description of a step entrypoint function.

    Attributes:
        inputs: The entrypoint function inputs, keyed by input name.
        outputs: The entrypoint function outputs. This dictionary maps output
            names to output annotations.
        context: Optional parameter representing the `StepContext` input.
        legacy_params: Optional parameter representing the `BaseParameters`
            input.
    """

    inputs: Dict[str, inspect.Parameter]
    outputs: Dict[str, Any]
    context: Optional[inspect.Parameter]
    legacy_params: Optional[inspect.Parameter]

    def validate_input(self, key: str, value: Any) -> None:
        """Check a single input passed to the step entrypoint function.

        Artifacts are accepted without further checks. Any other value is
        treated as a parameter: it is validated against the input's type
        annotation and must be JSON serializable.

        Args:
            key: The key for which the input was passed
            value: The input value.

        Raises:
            KeyError: If the function has no input for the given key.
            RuntimeError: If a parameter is passed for an input that is
                annotated as an `UnmaterializedArtifact`.
            StepInterfaceError: If the input is a parameter and not JSON
                serializable.
        """
        from zenml.materializers import UnmaterializedArtifact

        if key not in self.inputs:
            raise KeyError(
                f"Received step entrypoint input for invalid key {key}."
            )

        if isinstance(value, (StepArtifact, ExternalArtifact)):
            # Artifacts are deliberately not type-checked here; doing so
            # would prevent pydantic's type coercion (e.g. an `int`
            # artifact feeding a `float` input).
            return

        # Anything that is not an artifact is a parameter.
        input_parameter = self.inputs[key]
        if input_parameter.annotation is UnmaterializedArtifact:
            raise RuntimeError(
                "Passing parameter for input of type `UnmaterializedArtifact` "
                "is not allowed."
            )

        self._validate_input_value(parameter=input_parameter, value=value)

        if yaml_utils.is_json_serializable(value):
            return
        raise StepInterfaceError(
            f"Argument type (`{type(value)}`) for argument "
            f"'{key}' is not JSON "
            "serializable."
        )

    def _validate_input_value(
        self, parameter: inspect.Parameter, value: Any
    ) -> None:
        """Check a parameter value against its type annotation.

        Args:
            parameter: The function parameter for which the value was provided.
            value: The input value.

        Raises:
            RuntimeError: If the input value is not valid for the type
                annotation provided for the function parameter.
        """

        class ModelConfig(BaseConfig):
            arbitrary_types_allowed = False

        # A throwaway pydantic model with a single required `value` field
        # typed like the parameter, so validation includes pydantic's type
        # coercion.
        checker_model = create_model(
            "input_validation_model",
            __config__=ModelConfig,
            value=(parameter.annotation, ...),
        )
        try:
            checker_model(value=value)
        except ValidationError as error:
            raise RuntimeError("Input validation failed.") from error
__getnewargs__(self)
special
Return self as a plain tuple. Used by copy and pickle.
Source code in zenml/steps/entrypoint_function_utils.py
def __getnewargs__(self):
    """Return self as a plain tuple. Used by copy and pickle."""
    # NOTE(review): `_tuple` is not defined in this file; presumably it is the
    # builtin `tuple` alias bound inside the stdlib namedtuple implementation
    # that generated this method.
    return _tuple(self)
__new__(_cls, inputs, outputs, context, legacy_params)
special
staticmethod
Create new instance of EntrypointFunctionDefinition(inputs, outputs, context, legacy_params)
__repr__(self)
special
Return a nicely formatted representation string
Source code in zenml/steps/entrypoint_function_utils.py
def __repr__(self):
    """Return a nicely formatted representation string."""
    # NOTE(review): `repr_fmt` is not defined in this file; presumably it is
    # the per-field `%r` format string built by the stdlib namedtuple
    # implementation that generated this method.
    return self.__class__.__name__ + repr_fmt % self
validate_input(self, key, value)
Validates an input to the step entrypoint function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
The key for which the input was passed |
required |
value |
Any |
The input value. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If the function has no input for the given key. |
RuntimeError |
If a parameter is passed for an input that is annotated as an `UnmaterializedArtifact`. |
StepInterfaceError |
If the input is a parameter and not JSON serializable. |
Source code in zenml/steps/entrypoint_function_utils.py
def validate_input(self, key: str, value: Any) -> None:
    """Check a single input passed to the step entrypoint function.

    Artifacts are accepted without further checks. Any other value is
    treated as a parameter: it is validated against the input's type
    annotation and must be JSON serializable.

    Args:
        key: The key for which the input was passed
        value: The input value.

    Raises:
        KeyError: If the function has no input for the given key.
        RuntimeError: If a parameter is passed for an input that is
            annotated as an `UnmaterializedArtifact`.
        StepInterfaceError: If the input is a parameter and not JSON
            serializable.
    """
    from zenml.materializers import UnmaterializedArtifact

    if key not in self.inputs:
        raise KeyError(
            f"Received step entrypoint input for invalid key {key}."
        )

    if isinstance(value, (StepArtifact, ExternalArtifact)):
        # Artifacts are deliberately not type-checked here; doing so would
        # prevent pydantic's type coercion (e.g. an `int` artifact feeding
        # a `float` input).
        return

    # Anything that is not an artifact is a parameter.
    input_parameter = self.inputs[key]
    if input_parameter.annotation is UnmaterializedArtifact:
        raise RuntimeError(
            "Passing parameter for input of type `UnmaterializedArtifact` "
            "is not allowed."
        )

    self._validate_input_value(parameter=input_parameter, value=value)

    if yaml_utils.is_json_serializable(value):
        return
    raise StepInterfaceError(
        f"Argument type (`{type(value)}`) for argument "
        f"'{key}' is not JSON "
        "serializable."
    )
StepArtifact
Class to represent step output artifacts.
Source code in zenml/steps/entrypoint_function_utils.py
class StepArtifact:
    """An output artifact produced by a step invocation inside a pipeline."""

    def __init__(
        self,
        invocation_id: str,
        output_name: str,
        annotation: Any,
        pipeline: "Pipeline",
    ) -> None:
        """Store where the artifact comes from and what type it has.

        Args:
            invocation_id: The ID of the invocation that produces this artifact.
            output_name: The name of the output that produces this artifact.
            annotation: The output type annotation.
            pipeline: The pipeline which the invocation is part of.
        """
        self.invocation_id, self.output_name = invocation_id, output_name
        self.annotation, self.pipeline = annotation, pipeline
__init__(self, invocation_id, output_name, annotation, pipeline)
special
Initialize a step artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
invocation_id |
str |
The ID of the invocation that produces this artifact. |
required |
output_name |
str |
The name of the output that produces this artifact. |
required |
annotation |
Any |
The output type annotation. |
required |
pipeline |
Pipeline |
The pipeline which the invocation is part of. |
required |
Source code in zenml/steps/entrypoint_function_utils.py
def __init__(
    self,
    invocation_id: str,
    output_name: str,
    annotation: Any,
    pipeline: "Pipeline",
) -> None:
    """Store where the artifact comes from and what type it has.

    Args:
        invocation_id: The ID of the invocation that produces this artifact.
        output_name: The name of the output that produces this artifact.
        annotation: The output type annotation.
        pipeline: The pipeline which the invocation is part of.
    """
    self.invocation_id, self.output_name = invocation_id, output_name
    self.annotation, self.pipeline = annotation, pipeline
get_step_entrypoint_signature(step, include_step_context=False, include_legacy_parameters=False)
Get the entrypoint signature of a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
BaseStep |
The step for which to get the entrypoint signature. |
required |
include_step_context |
bool |
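Whether to include the `StepContext` as a parameter of the returned signature. If `False`, a potential signature parameter of type `StepContext` will be removed before returning the signature. |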
False |
include_legacy_parameters |
bool |
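Whether to include the `BaseParameters` subclass as a parameter of the returned signature. If `False`, a potential signature parameter of type `BaseParameters` will be removed before returning the signature. |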
False |
Returns:
Type | Description |
---|---|
Signature |
The entrypoint function signature. |
Source code in zenml/steps/entrypoint_function_utils.py
def get_step_entrypoint_signature(
    step: "BaseStep",
    include_step_context: bool = False,
    include_legacy_parameters: bool = False,
) -> inspect.Signature:
    """Get the entrypoint signature of a step.

    Args:
        step: The step for which to get the entrypoint signature.
        include_step_context: Whether to include the `StepContext` as a
            parameter of the returned signature. If `False`, a potential
            signature parameter of type `StepContext` will be removed before
            returning the signature.
        include_legacy_parameters: Whether to include the `BaseParameters`
            subclass as a parameter of the returned signature. If `False`, a
            potential signature parameter of type `BaseParameters` will be
            removed before returning the signature.

    Returns:
        The entrypoint function signature.
    """
    from zenml.steps import BaseParameters, StepContext

    signature = inspect.signature(step.entrypoint, follow_wrapped=True)

    # Classes whose (sub)class-annotated parameters get dropped. The order
    # matches the order in which the filters are conceptually applied.
    excluded_classes = []
    if not include_step_context:
        excluded_classes.append(StepContext)
    if not include_legacy_parameters:
        excluded_classes.append(BaseParameters)

    def _keep(param: inspect.Parameter) -> bool:
        # Only real classes can be subclasses; generic aliases, strings and
        # `inspect.Parameter.empty`-like markers that are not classes stay.
        annotation = param.annotation
        if not inspect.isclass(annotation):
            return True
        return not any(
            issubclass(annotation, excluded) for excluded in excluded_classes
        )

    kept_parameters = [
        param for param in signature.parameters.values() if _keep(param)
    ]
    return signature.replace(parameters=kept_parameters)
validate_entrypoint_function(func, reserved_arguments=())
Validates a step entrypoint function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[..., Any] |
The step entrypoint function to validate. |
required |
reserved_arguments |
Sequence[str] |
The reserved arguments for the entrypoint function. |
() |
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
If the entrypoint function has variable arguments or keyword arguments. |
StepInterfaceError |
If the entrypoint function has multiple `BaseParameters` arguments. |
StepInterfaceError |
If the entrypoint function has multiple `StepContext` arguments. |
StepInterfaceError |
If the entrypoint function has no return annotation. |
Returns:
Type | Description |
---|---|
EntrypointFunctionDefinition |
A validated definition of the entrypoint function. |
Source code in zenml/steps/entrypoint_function_utils.py
def validate_entrypoint_function(
    func: Callable[..., Any], reserved_arguments: Sequence[str] = ()
) -> EntrypointFunctionDefinition:
    """Validates a step entrypoint function.

    Each parameter of `func` is classified as follows: a parameter annotated
    with a `BaseParameters` subclass becomes the legacy parameters argument,
    a parameter annotated with a `StepContext` subclass becomes the context
    argument, and every other parameter becomes a step input. A leading
    `self` parameter is ignored, and parameters without a type annotation are
    treated as if they were annotated with `Any`.

    Args:
        func: The step entrypoint function to validate.
        reserved_arguments: The reserved arguments for the entrypoint function.

    Raises:
        StepInterfaceError: If the entrypoint function has variable arguments
            or keyword arguments.
        StepInterfaceError: If the entrypoint function has multiple
            `BaseParameters` arguments.
        StepInterfaceError: If the entrypoint function has multiple
            `StepContext` arguments.
        StepInterfaceError: If the entrypoint function has no return annotation.
        RuntimeError: If the signature contains a reserved argument (raised
            by `validate_reserved_arguments`).

    Returns:
        A validated definition of the entrypoint function.
    """
    from zenml.steps import BaseParameters, StepContext

    signature = inspect.signature(func, follow_wrapped=True)
    validate_reserved_arguments(
        signature=signature, reserved_arguments=reserved_arguments
    )

    # Maps input name -> `inspect.Parameter` for regular step inputs.
    inputs = {}
    context: Optional[inspect.Parameter] = None
    legacy_params: Optional[inspect.Parameter] = None

    signature_parameters = list(signature.parameters.items())
    if signature_parameters and signature_parameters[0][0] == "self":
        # TODO: Once we get rid of the old step decorator, we can also remove
        # the `BaseStepMeta` class which right now calls this function on an
        # unbound instance method when using the class-based API. If we get rid
        # of that, this check and removal of the `self` parameter is not
        # necessary anymore
        signature_parameters = signature_parameters[1:]

    for key, parameter in signature_parameters:
        if parameter.kind in {parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD}:
            raise StepInterfaceError(
                f"Variable args or kwargs not allowed for function "
                f"{func.__name__}."
            )

        annotation = parameter.annotation
        if annotation is parameter.empty:
            # If a type annotation is missing, use `Any` instead
            parameter = parameter.replace(annotation=Any)

        if inspect.isclass(annotation) and issubclass(
            annotation, BaseParameters
        ):
            if legacy_params is not None:
                raise StepInterfaceError(
                    f"Found multiple parameter arguments "
                    f"('{legacy_params.name}' and '{key}') "
                    f"for function {func.__name__}."
                )
            legacy_params = parameter
        elif inspect.isclass(annotation) and issubclass(
            annotation, StepContext
        ):
            if context is not None:
                raise StepInterfaceError(
                    f"Found multiple context arguments "
                    f"('{context.name}' and '{key}') "
                    f"for function {func.__name__}."
                )
            context = parameter
        else:
            inputs[key] = parameter

    if signature.return_annotation is signature.empty:
        raise StepInterfaceError(
            f"Missing return type annotation for function {func.__name__}."
        )

    outputs = parse_return_type_annotations(
        return_annotation=signature.return_annotation
    )

    return EntrypointFunctionDefinition(
        inputs=inputs,
        outputs=outputs,
        context=context,
        legacy_params=legacy_params,
    )
validate_reserved_arguments(signature, reserved_arguments)
Validates that the signature does not contain any reserved arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
signature |
Signature |
The signature to validate. |
required |
reserved_arguments |
Sequence[str] |
The reserved arguments for the signature. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the signature contains a reserved argument. |
Source code in zenml/steps/entrypoint_function_utils.py
def validate_reserved_arguments(
    signature: inspect.Signature, reserved_arguments: Sequence[str]
) -> None:
    """Ensure that no parameter of `signature` uses a reserved name.

    The reserved names are checked in the order given. The first one that
    appears as a parameter name is reported.

    Args:
        signature: The signature to validate.
        reserved_arguments: The reserved arguments for the signature.

    Raises:
        RuntimeError: If the signature contains a reserved argument.
    """
    clashing_name = next(
        (
            reserved
            for reserved in reserved_arguments
            if reserved in signature.parameters
        ),
        None,
    )
    if clashing_name is not None:
        raise RuntimeError(f"Reserved argument name '{clashing_name}'.")
external_artifact
External artifact definition.
ExternalArtifact
External artifacts can be used to provide values as input to ZenML steps.
ZenML steps accept either artifacts (=outputs of other steps), parameters (raw, JSON serializable values) or external artifacts. External artifacts can be used to provide any value as input to a step without needing to write an additional step that returns this value.
Examples:
from zenml import step, pipeline, ExternalArtifact
import numpy as np
@step
def my_step(value: np.ndarray) -> None:
print(value)
my_array = np.array([1, 2, 3])
@pipeline
def my_pipeline():
my_step(value=ExternalArtifact(my_array))
Source code in zenml/steps/external_artifact.py
class ExternalArtifact:
    """External artifacts can be used to provide values as input to ZenML steps.

    ZenML steps accept either artifacts (=outputs of other steps), parameters
    (raw, JSON serializable values) or external artifacts. External artifacts
    can be used to provide any value as input to a step without needing to
    write an additional step that returns this value.

    Example:
    ```
    from zenml import step, pipeline, ExternalArtifact
    import numpy as np

    @step
    def my_step(value: np.ndarray) -> None:
        print(value)

    my_array = np.array([1, 2, 3])

    @pipeline
    def my_pipeline():
        my_step(value=ExternalArtifact(my_array))
    ```
    """

    def __init__(
        self,
        value: Any = None,
        id: Optional[UUID] = None,
        materializer: Optional["MaterializerClassOrSource"] = None,
        store_artifact_metadata: bool = True,
        store_artifact_visualizations: bool = True,
    ) -> None:
        """Initializes an external artifact instance.

        The external artifact needs to have either a value associated with it
        that will be uploaded to the artifact store, or reference an artifact
        that is already registered in ZenML. This could be either from a
        previous pipeline run or a previously uploaded external artifact.

        Args:
            value: The artifact value. Either this or an artifact ID must be
                provided.
            id: The ID of an artifact that should be referenced by this external
                artifact. Either this or an artifact value must be provided.
            materializer: The materializer to use for saving the artifact value
                to the artifact store. Only used when `value` is provided.
            store_artifact_metadata: Whether metadata for the artifact should
                be stored. Only used when `value` is provided.
            store_artifact_visualizations: Whether visualizations for the
                artifact should be stored. Only used when `value` is provided.

        Raises:
            ValueError: If no/multiple values are provided for the `value` and
                `id` arguments.
        """
        if value is not None and id is not None:
            raise ValueError(
                "Only a value or an ID can be provided when creating an "
                "external artifact."
            )
        if value is None and id is None:
            raise ValueError(
                "Either a value or an ID must be provided when creating an "
                "external artifact."
            )

        self._value = value
        self._id = id
        self._materializer = materializer
        self._store_artifact_metadata = store_artifact_metadata
        self._store_artifact_visualizations = store_artifact_visualizations

    def upload_if_necessary(self) -> UUID:
        """Uploads the artifact if necessary.

        This method does one of two things:
        - If an artifact is referenced by ID, it will verify that the artifact
          exists and is in the correct artifact store.
        - Otherwise, the artifact value will be uploaded and published.

        After a successful upload, the instance references the uploaded
        artifact by ID, so repeated calls do not upload the value again.

        Raises:
            RuntimeError: If the artifact store of the referenced artifact
                is not the same as the one in the active stack.
            RuntimeError: If the URI of the artifact already exists.

        Returns:
            The artifact ID.
        """
        # Resolve the active artifact store exactly once. The ID comparison,
        # the upload URI and the artifact store ID registered with the upload
        # then all refer to the same store.
        client = Client()
        artifact_store = client.active_stack.artifact_store
        artifact_store_id = artifact_store.id

        if self._id is not None:
            response = client.get_artifact(artifact_id=self._id)
            if response.artifact_store_id != artifact_store_id:
                raise RuntimeError(
                    f"The artifact {response.name} (ID: {response.id}) "
                    "referenced by an external artifact is not stored in the "
                    "artifact store of the active stack. This will lead to "
                    "issues loading the artifact. Please make sure to only "
                    "reference artifacts stored in your active artifact store."
                )
        else:
            # `__init__` guarantees a value whenever no ID was given.
            assert self._value is not None
            logger.info("Uploading external artifact...")
            artifact_name = f"external_{uuid4()}"
            materializer_class = self._get_materializer_class(
                value=self._value
            )

            uri = os.path.join(
                artifact_store.path,
                "external_artifacts",
                artifact_name,
            )
            # Refuse to write into a pre-existing location rather than
            # mixing data from two artifacts.
            if fileio.exists(uri):
                raise RuntimeError(f"Artifact URI '{uri}' already exists.")
            fileio.makedirs(uri)

            materializer = materializer_class(uri)

            artifact_id = artifact_utils.upload_artifact(
                name=artifact_name,
                data=self._value,
                materializer=materializer,
                artifact_store_id=artifact_store_id,
                extract_metadata=self._store_artifact_metadata,
                include_visualizations=self._store_artifact_visualizations,
            )

            # To avoid duplicate uploads, switch to referencing the uploaded
            # artifact by ID
            self._id = artifact_id
            logger.info(
                "Finished uploading external artifact %s.", artifact_id
            )

        return self._id

    def _get_materializer_class(self, value: Any) -> Type["BaseMaterializer"]:
        """Gets a materializer class for a value.

        If a custom materializer is defined for this artifact it will be
        returned. Otherwise it will get the materializer class from the
        registry, falling back to the Cloudpickle materializer if no concrete
        materializer is registered for the type of value.

        Args:
            value: The value for which to get the materializer class.

        Returns:
            The materializer class.
        """
        if isinstance(self._materializer, type):
            # Already a class: use it as-is.
            return self._materializer
        elif self._materializer:
            # A source reference: load it and check it is a materializer.
            return source_utils.load_and_validate_class(
                self._materializer, expected_class=BaseMaterializer
            )
        else:
            return materializer_registry[type(value)]
__init__(self, value=None, id=None, materializer=None, store_artifact_metadata=True, store_artifact_visualizations=True)
special
Initializes an external artifact instance.
The external artifact needs to have either a value associated with it that will be uploaded to the artifact store, or reference an artifact that is already registered in ZenML. This could be either from a previous pipeline run or a previously uploaded external artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value |
Any |
The artifact value. Either this or an artifact ID must be provided. |
None |
id |
Optional[uuid.UUID] |
The ID of an artifact that should be referenced by this external artifact. Either this or an artifact value must be provided. |
None |
materializer |
Optional[MaterializerClassOrSource] |
The materializer to use for saving the artifact value to the artifact store. Only used when `value` is provided. |
None |
store_artifact_metadata |
bool |
Whether metadata for the artifact should be stored. Only used when `value` is provided. |
True |
store_artifact_visualizations |
bool |
Whether visualizations for the artifact should be stored. Only used when `value` is provided. |
True |
Exceptions:
Type | Description |
---|---|
ValueError |
If no/multiple values are provided for the `value` and `id` arguments. |
Source code in zenml/steps/external_artifact.py
def __init__(
    self,
    value: Any = None,
    id: Optional[UUID] = None,
    materializer: Optional["MaterializerClassOrSource"] = None,
    store_artifact_metadata: bool = True,
    store_artifact_visualizations: bool = True,
) -> None:
    """Initializes an external artifact instance.

    Exactly one of `value` and `id` must be given. A value will be uploaded
    to the artifact store, while an ID references an artifact that is
    already registered in ZenML, for example from a previous pipeline run
    or a previously uploaded external artifact.

    Args:
        value: The artifact value. Either this or an artifact ID must be
            provided.
        id: The ID of an artifact that should be referenced by this external
            artifact. Either this or an artifact value must be provided.
        materializer: The materializer to use for saving the artifact value
            to the artifact store. Only used when `value` is provided.
        store_artifact_metadata: Whether metadata for the artifact should
            be stored. Only used when `value` is provided.
        store_artifact_visualizations: Whether visualizations for the
            artifact should be stored. Only used when `value` is provided.

    Raises:
        ValueError: If no/multiple values are provided for the `value` and
            `id` arguments.
    """
    has_value = value is not None
    has_id = id is not None

    if has_value and has_id:
        raise ValueError(
            "Only a value or an ID can be provided when creating an "
            "external artifact."
        )
    if not (has_value or has_id):
        raise ValueError(
            "Either a value or an ID must be provided when creating an "
            "external artifact."
        )

    # What the artifact is: a concrete value or a reference to an existing one.
    self._value = value
    self._id = id
    # Upload options, only relevant when a value has to be uploaded.
    self._materializer = materializer
    self._store_artifact_metadata = store_artifact_metadata
    self._store_artifact_visualizations = store_artifact_visualizations
upload_if_necessary(self)
Uploads the artifact if necessary.
This method does one of two things:
- If an artifact is referenced by ID, it will verify that the artifact exists and is in the correct artifact store.
- Otherwise, the artifact value will be uploaded and published.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the artifact store of the referenced artifact is not the same as the one in the active stack. |
RuntimeError |
If the URI of the artifact already exists. |
Returns:
Type | Description |
---|---|
UUID |
The artifact ID. |
Source code in zenml/steps/external_artifact.py
def upload_if_necessary(self) -> UUID:
    """Uploads the artifact if necessary.

    This method does one of two things:
    - If an artifact is referenced by ID, it will verify that the artifact
      exists and is in the correct artifact store.
    - Otherwise, the artifact value will be uploaded and published.

    After a successful upload, the instance references the uploaded artifact
    by ID, so repeated calls do not upload the value again.

    Raises:
        RuntimeError: If the artifact store of the referenced artifact
            is not the same as the one in the active stack.
        RuntimeError: If the URI of the artifact already exists.

    Returns:
        The artifact ID.
    """
    # Resolve the active artifact store exactly once. The ID comparison, the
    # upload URI and the artifact store ID registered with the upload then
    # all refer to the same store.
    client = Client()
    artifact_store = client.active_stack.artifact_store
    artifact_store_id = artifact_store.id

    if self._id is not None:
        response = client.get_artifact(artifact_id=self._id)
        if response.artifact_store_id != artifact_store_id:
            raise RuntimeError(
                f"The artifact {response.name} (ID: {response.id}) "
                "referenced by an external artifact is not stored in the "
                "artifact store of the active stack. This will lead to "
                "issues loading the artifact. Please make sure to only "
                "reference artifacts stored in your active artifact store."
            )
    else:
        # The constructor guarantees a value whenever no ID was given.
        assert self._value is not None
        logger.info("Uploading external artifact...")
        artifact_name = f"external_{uuid4()}"
        materializer_class = self._get_materializer_class(value=self._value)

        uri = os.path.join(
            artifact_store.path,
            "external_artifacts",
            artifact_name,
        )
        # Refuse to write into a pre-existing location rather than mixing
        # data from two artifacts.
        if fileio.exists(uri):
            raise RuntimeError(f"Artifact URI '{uri}' already exists.")
        fileio.makedirs(uri)

        materializer = materializer_class(uri)

        artifact_id = artifact_utils.upload_artifact(
            name=artifact_name,
            data=self._value,
            materializer=materializer,
            artifact_store_id=artifact_store_id,
            extract_metadata=self._store_artifact_metadata,
            include_visualizations=self._store_artifact_visualizations,
        )

        # To avoid duplicate uploads, switch to referencing the uploaded
        # artifact by ID
        self._id = artifact_id
        logger.info("Finished uploading external artifact %s.", artifact_id)

    return self._id
step_context
Step context class.
StepContext
Provides additional context inside a step function.
This class is used to access pipelines, materializers, and artifacts inside a step function. To use it, add a `StepContext` object to the signature of your step function like this:
@step
def my_step(context: StepContext, ...)
context.get_output_materializer(...)
You do not need to create a `StepContext` object yourself and pass it when creating the step. As long as you specify it in the signature, ZenML will create the `StepContext` and automatically pass it when executing your step.
Note: When using a `StepContext` inside a step, ZenML disables caching for this step by default, as the context provides access to external resources which might influence the result of your step execution. To enable caching anyway, explicitly enable it in the `@step` decorator or when initializing your custom step class.
Source code in zenml/steps/step_context.py
class StepContext:
    """Provides additional context inside a step function.

    This class is used to access pipelines, materializers, and artifacts
    inside a step function. To use it, add a `StepContext` object
    to the signature of your step function like this:

    ```python
    @step
    def my_step(context: StepContext, ...)
        context.get_output_materializer(...)
    ```

    You do not need to create a `StepContext` object yourself and pass it
    when creating the step, as long as you specify it in the signature ZenML
    will create the `StepContext` and automatically pass it when executing your
    step.

    **Note**: When using a `StepContext` inside a step, ZenML disables caching
    for this step by default as the context provides access to external
    resources which might influence the result of your step execution. To
    enable caching anyway, explicitly enable it in the `@step` decorator or when
    initializing your custom step class.
    """

    def __init__(
        self,
        step_name: str,
        output_materializers: Mapping[str, Sequence[Type["BaseMaterializer"]]],
        output_artifact_uris: Mapping[str, str],
    ):
        """Initializes a StepContext instance.

        Args:
            step_name: The name of the step that this context is used in.
            output_materializers: The output materializers of the step that
                this context is used in.
            output_artifact_uris: The output artifacts of the step that this
                context is used in.

        Raises:
            StepContextError: If the keys of the output materializers and
                output artifacts do not match.
        """
        if output_materializers.keys() != output_artifact_uris.keys():
            raise StepContextError(
                f"Mismatched keys in output materializers and output "
                f"artifacts URIs for step '{step_name}'. Output materializer "
                f"keys: {set(output_materializers)}, output artifact URI "
                f"keys: {set(output_artifact_uris)}"
            )

        self.step_name = step_name
        # One entry per output name, pairing its materializer classes with
        # the URI of the output artifact.
        self._outputs = {
            key: StepContextOutput(
                output_materializers[key], output_artifact_uris[key]
            )
            for key in output_materializers
        }
        # The active stack is captured once, when the context is created.
        self._stack = Client().active_stack

    def _get_output(
        self, output_name: Optional[str] = None
    ) -> "StepContextOutput":
        """Returns the materializer classes and artifact URI for a step output.

        Args:
            output_name: Optional name of the output for which to get the
                materializer classes and URI. Can be omitted (or empty) only
                if the step has exactly one output.

        Returns:
            Tuple containing the materializer classes and artifact URI for the
            given output.

        Raises:
            StepContextError: If the step has no outputs, no output for
                the given `output_name` or if no `output_name` was given but
                the step has multiple outputs.
        """
        output_count = len(self._outputs)

        if output_count == 0:
            raise StepContextError(
                f"Unable to get step output for step '{self.step_name}': "
                f"This step does not have any outputs."
            )

        if not output_name and output_count > 1:
            raise StepContextError(
                f"Unable to get step output for step '{self.step_name}': "
                f"This step has multiple outputs ({set(self._outputs)}), "
                f"please specify which output to return."
            )

        if output_name:
            if output_name not in self._outputs:
                raise StepContextError(
                    f"Unable to get step output '{output_name}' for "
                    f"step '{self.step_name}'. This step does not have an "
                    f"output with the given name, please specify one of the "
                    f"available outputs: {set(self._outputs)}."
                )
            return self._outputs[output_name]
        else:
            # Exactly one output exists at this point.
            return next(iter(self._outputs.values()))

    @property
    def stack(self) -> Optional["Stack"]:
        """Returns the stack that was active when this context was created.

        Returns:
            The active stack captured when this context was created, or None.
        """
        return self._stack

    @property
    def pipeline_name(self) -> Optional[str]:
        """Returns the current pipeline name.

        Returns:
            The current pipeline name or None.
        """
        env = Environment().step_environment
        return env.pipeline_name

    @property
    def run_name(self) -> Optional[str]:
        """Returns the current run name.

        Returns:
            The current run name or None.
        """
        env = Environment().step_environment
        return env.run_name

    @property
    def parameters(self) -> Dict[str, Any]:
        """The step parameters.

        Returns:
            The step parameters.
        """
        return self.step_run_info.config.parameters

    @property
    def step_run_info(self) -> "StepRunInfo":
        """Info about the currently running step.

        Returns:
            Info about the currently running step.
        """
        env = Environment().step_environment
        return env.step_run_info

    @property
    def cache_enabled(self) -> bool:
        """Returns whether cache is enabled for the step.

        Returns:
            True if cache is enabled for the step, otherwise False.
        """
        env = Environment().step_environment
        return env.cache_enabled

    def get_output_materializer(
        self,
        output_name: Optional[str] = None,
        custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
        data_type: Optional[Type[Any]] = None,
    ) -> "BaseMaterializer":
        """Returns a materializer for a given step output.

        Args:
            output_name: Optional name of the output for which to get the
                materializer. If no name is given and the step only has a
                single output, the materializer of this output will be
                returned. If the step has multiple outputs, an exception
                will be raised.
            custom_materializer_class: If given, this `BaseMaterializer`
                subclass will be initialized with the output artifact instead
                of the materializer that was registered for this step output.
            data_type: If the output annotation is of type `Union` and the step
                therefore has multiple materializers configured, you can provide
                a data type for the output which will be used to select the
                correct materializer. If not provided, the first materializer
                will be used.

        Returns:
            A materializer initialized with the output artifact for
            the given output.
        """
        from zenml.utils import materializer_utils

        materializer_classes, artifact_uri = self._get_output(output_name)

        if custom_materializer_class:
            materializer_class = custom_materializer_class
        elif len(materializer_classes) == 1 or not data_type:
            materializer_class = materializer_classes[0]
        else:
            materializer_class = materializer_utils.select_materializer(
                data_type=data_type, materializer_classes=materializer_classes
            )

        return materializer_class(artifact_uri)

    def get_output_artifact_uri(
        self, output_name: Optional[str] = None
    ) -> str:
        """Returns the artifact URI for a given step output.

        Args:
            output_name: Optional name of the output for which to get the URI.
                If no name is given and the step only has a single output,
                the URI of this output will be returned. If the step has
                multiple outputs, an exception will be raised.

        Returns:
            Artifact URI for the given output.
        """
        return self._get_output(output_name).artifact_uri
cache_enabled: bool
property
readonly
Returns whether cache is enabled for the step.
Returns:
Type | Description |
---|---|
bool |
True if cache is enabled for the step, otherwise False. |
parameters: Dict[str, Any]
property
readonly
The step parameters.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The step parameters. |
pipeline_name: Optional[str]
property
readonly
Returns the current pipeline name.
Returns:
Type | Description |
---|---|
Optional[str] |
The current pipeline name or None. |
run_name: Optional[str]
property
readonly
Returns the current run name.
Returns:
Type | Description |
---|---|
Optional[str] |
The current run name or None. |
stack: Optional[Stack]
property
readonly
Returns the current active stack.
Returns:
Type | Description |
---|---|
Optional[Stack] |
The current active stack or None. |
step_run_info: StepRunInfo
property
readonly
Info about the currently running step.
Returns:
Type | Description |
---|---|
StepRunInfo |
Info about the currently running step. |
__init__(self, step_name, output_materializers, output_artifact_uris)
special
Initializes a StepContext instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step that this context is used in. |
required |
output_materializers |
Mapping[str, Sequence[Type[BaseMaterializer]]] |
The output materializers of the step that this context is used in. |
required |
output_artifact_uris |
Mapping[str, str] |
The output artifacts of the step that this context is used in. |
required |
Exceptions:
Type | Description |
---|---|
StepContextError |
If the keys of the output materializers and output artifacts do not match. |
Source code in zenml/steps/step_context.py
def __init__(
    self,
    step_name: str,
    output_materializers: Mapping[str, Sequence[Type["BaseMaterializer"]]],
    output_artifact_uris: Mapping[str, str],
):
    """Create the context for a single step execution.

    Args:
        step_name: The name of the step that this context is used in.
        output_materializers: The output materializers of the step that
            this context is used in.
        output_artifact_uris: The output artifacts of the step that this
            context is used in.

    Raises:
        StepContextError: If the keys of the output materializers and
            output artifacts do not match.
    """
    materializer_keys = output_materializers.keys()
    uri_keys = output_artifact_uris.keys()
    if materializer_keys != uri_keys:
        raise StepContextError(
            f"Mismatched keys in output materializers and output "
            f"artifacts URIs for step '{step_name}'. Output materializer "
            f"keys: {set(output_materializers)}, output artifact URI "
            f"keys: {set(output_artifact_uris)}"
        )

    self.step_name = step_name

    # Pair every output's materializer classes with its artifact URI.
    outputs = {}
    for name in output_materializers.keys():
        outputs[name] = StepContextOutput(
            output_materializers[name], output_artifact_uris[name]
        )
    self._outputs = outputs

    # The active stack is captured once, when the context is created.
    self._stack = Client().active_stack
get_output_artifact_uri(self, output_name=None)
Returns the artifact URI for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the URI. If no name is given and the step only has a single output, the URI of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
Returns:
Type | Description |
---|---|
str |
Artifact URI for the given output. |
Source code in zenml/steps/step_context.py
def get_output_artifact_uri(
    self, output_name: Optional[str] = None
) -> str:
    """Look up the artifact URI of one of the step's outputs.

    Args:
        output_name: Optional name of the output for which to get the URI.
            If no name is given and the step only has a single output,
            the URI of this output will be returned. If the step has
            multiple outputs, an exception will be raised.

    Returns:
        Artifact URI for the given output.
    """
    output = self._get_output(output_name)
    return output.artifact_uri
get_output_materializer(self, output_name=None, custom_materializer_class=None, data_type=None)
Returns a materializer for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the materializer. If no name is given and the step only has a single output, the materializer of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
custom_materializer_class |
Optional[Type[BaseMaterializer]] |
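If given, this `BaseMaterializer` subclass will be initialized with the output artifact instead of the materializer that was registered for this step output. |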
None |
data_type |
Optional[Type[Any]] |
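If the output annotation is of type `Union` and the step therefore has multiple materializers configured, you can provide a data type for the output which will be used to select the correct materializer. If not provided, the first materializer will be used. |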
None |
Returns:
Type | Description |
---|---|
BaseMaterializer |
A materializer initialized with the output artifact for the given output. |
Source code in zenml/steps/step_context.py
def get_output_materializer(
    self,
    output_name: Optional[str] = None,
    custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
    data_type: Optional[Type[Any]] = None,
) -> "BaseMaterializer":
    """Build a materializer bound to the artifact URI of a step output.

    Args:
        output_name: Optional name of the output for which to get the
            materializer. If no name is given and the step only has a
            single output, the materializer of this output will be
            returned. If the step has multiple outputs, an exception
            will be raised.
        custom_materializer_class: If given, this `BaseMaterializer`
            subclass will be initialized with the output artifact instead
            of the materializer that was registered for this step output.
        data_type: If the output annotation is of type `Union` and the step
            therefore has multiple materializers configured, you can provide
            a data type for the output which will be used to select the
            correct materializer. If not provided, the first materializer
            will be used.

    Returns:
        A materializer initialized with the output artifact for
        the given output.
    """
    from zenml.utils import materializer_utils

    output = self._get_output(output_name)
    candidates, uri = output

    if custom_materializer_class:
        # An explicit override always wins.
        chosen_class = custom_materializer_class
    elif len(candidates) != 1 and data_type:
        # Several candidates and a type hint: let the utility pick one.
        chosen_class = materializer_utils.select_materializer(
            data_type=data_type, materializer_classes=candidates
        )
    else:
        chosen_class = candidates[0]

    return chosen_class(uri)
StepContextOutput (tuple)
Tuple containing materializer class and URI for a step output.
Source code in zenml/steps/step_context.py
class StepContextOutput(NamedTuple):
    """Tuple containing materializer classes and URI for a step output."""

    # Materializer classes configured for this output. `StepContext` uses the
    # first one unless a custom class or a data type selects another.
    materializer_classes: Sequence[Type["BaseMaterializer"]]
    # URI of the output artifact that the chosen materializer is created with.
    artifact_uri: str
__getnewargs__(self)
special
Return self as a plain tuple. Used by copy and pickle.
Source code in zenml/steps/step_context.py
def __getnewargs__(self):
    'Return self as a plain tuple. Used by copy and pickle.'
    # NOTE(review): this is the method generated for NamedTuple classes;
    # `_tuple` presumably refers to the builtin `tuple` in the generator's
    # namespace (it is not defined in this module) — confirm against the
    # CPython `collections` source.
    return _tuple(self)
__new__(_cls, materializer_classes, artifact_uri)
special
staticmethod
Create new instance of StepContextOutput(materializer_classes, artifact_uri)
__repr__(self)
special
Return a nicely formatted representation string
Source code in zenml/steps/step_context.py
def __repr__(self):
    # Rendered listing of a method generated by `collections.namedtuple`.
    # `repr_fmt` is not defined in this listing; presumably it is a
    # "(field=%r, ...)" format string built from the field names by the
    # generator. NOTE(review): confirm.
    'Return a nicely formatted representation string'
    return self.__class__.__name__ + repr_fmt % self
step_decorator
Step decorator function.
step(_func=None, *, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, experiment_tracker=None, step_operator=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None)
Outer decorator function for the creation of a ZenML step.
In order to be able to work with parameters such as `name`, it features a nested decorator structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_func |
Optional[~F] |
The decorated function. |
None |
name |
Optional[str] |
The name of the step. If left empty, the name of the decorated function will be used as a fallback. |
None |
enable_cache |
Optional[bool] |
Specify whether caching is enabled for this step. If no
value is passed, caching is enabled by default unless the step
requires a `StepContext` (see `zenml.steps.step_context.StepContext` for more information). |
None |
enable_artifact_metadata |
Optional[bool] |
Specify whether metadata is enabled for this step. If no value is passed, metadata is enabled by default. |
None |
enable_artifact_visualization |
Optional[bool] |
Specify whether visualization is enabled for this step. If no value is passed, visualization is enabled by default. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Dict[str, SettingsOrDict]] |
Settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
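Callback function in event of failure of the step. Can be
a function with three possible parameters, `StepContext`, `BaseParameters`, and `BaseException`, or a source path to a function of the same specifications (e.g. `module.my_function`). |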
None |
on_success |
Optional[HookSpecification] |
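Callback function in event of success of the step. Can be
a function with two possible parameters, `StepContext` and `BaseParameters`, or a source path to a function of the same specifications (e.g. `module.my_function`). |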
None |
Returns:
Type | Description |
---|---|
Union[Type[zenml.steps.base_step.BaseStep], Callable[[~F], Type[zenml.steps.base_step.BaseStep]]] |
The inner decorator which creates the step class based on the ZenML BaseStep |
Source code in zenml/steps/step_decorator.py
def step(
    _func: Optional[F] = None,
    *,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    output_materializers: Optional["OutputMaterializersSpecification"] = None,
    settings: Optional[Dict[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
) -> Union[Type[BaseStep], Callable[[F], Type[BaseStep]]]:
    """Outer decorator function for the creation of a ZenML step.

    In order to be able to work with parameters such as `name`, it features a
    nested decorator structure: both `@step` and `@step(...)` are supported.

    Note: this decorator is deprecated and logs a deprecation warning every
    time it is called.

    Args:
        _func: The decorated function.
        name: The name of the step. If left empty, the name of the decorated
            function will be used as a fallback.
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default unless the step
            requires a `StepContext` (see
            `zenml.steps.step_context.StepContext` for more information).
        enable_artifact_metadata: Specify whether metadata is enabled for this
            step. If no value is passed, metadata is enabled by default.
        enable_artifact_visualization: Specify whether visualization is enabled
            for this step. If no value is passed, visualization is enabled by
            default.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can be
            a function with three possible parameters,
            `StepContext`, `BaseParameters`, and `BaseException`,
            or a source path to a function of the same specifications
            (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can be
            a function with two possible parameters, `StepContext` and
            `BaseParameters`, or a source path to a function of the same
            specifications (e.g. `module.my_function`).

    Returns:
        The inner decorator which creates the step class based on the
        ZenML BaseStep or, when applied directly without parentheses, the
        generated step class itself.
    """
    logger.warning(
        "The `@step` decorator that you use to define your step is "
        "deprecated. Check out our docs https://docs.zenml.io for "
        "information on how to define steps in a more intuitive and "
        "flexible way!"
    )

    def inner_decorator(func: F) -> Type[BaseStep]:
        """Inner decorator function for the creation of a ZenML Step.

        Args:
            func: types.FunctionType, this function will be used as the
                "process" method of the generated Step.

        Returns:
            The class of a newly generated ZenML Step.
        """
        # Dynamically build a `_DecoratedStep` subclass named after the
        # function. The function is stored as a static method and the
        # decorator arguments are stored as the class configuration.
        return type(  # noqa
            func.__name__,
            (_DecoratedStep,),
            {
                STEP_INNER_FUNC_NAME: staticmethod(func),
                CLASS_CONFIGURATION: {
                    PARAM_STEP_NAME: name,
                    PARAM_ENABLE_CACHE: enable_cache,
                    PARAM_ENABLE_ARTIFACT_METADATA: enable_artifact_metadata,
                    PARAM_ENABLE_ARTIFACT_VISUALIZATION: enable_artifact_visualization,
                    PARAM_EXPERIMENT_TRACKER: experiment_tracker,
                    PARAM_STEP_OPERATOR: step_operator,
                    PARAM_OUTPUT_MATERIALIZERS: output_materializers,
                    PARAM_SETTINGS: settings,
                    PARAM_EXTRA_OPTIONS: extra,
                    PARAM_ON_FAILURE: on_failure,
                    PARAM_ON_SUCCESS: on_success,
                },
                "__module__": func.__module__,
                "__doc__": func.__doc__,
            },
        )

    # `@step(...)` calls us without a function: return the decorator.
    # `@step` calls us with the function: decorate it immediately.
    if _func is None:
        return inner_decorator
    else:
        return inner_decorator(_func)
step_environment
Step environment class.
StepEnvironment (BaseEnvironmentComponent)
Added information about a step runtime inside a step function.
This takes the form of an Environment component. This class can be used from within a pipeline step implementation to access additional information about the runtime parameters of a pipeline step, such as the pipeline name, pipeline run ID and other pipeline runtime information. To use it, access it inside your step function like this:
from zenml.environment import Environment
@step
def my_step(...):
env = Environment().step_environment
do_something_with(env.pipeline_name, env.run_name, env.step_name)
Source code in zenml/steps/step_environment.py
class StepEnvironment(BaseEnvironmentComponent):
    """Added information about a step runtime inside a step function.

    This takes the form of an Environment component. This class can be used from
    within a pipeline step implementation to access additional information about
    the runtime parameters of a pipeline step, such as the pipeline name,
    pipeline run ID and other pipeline runtime information. To use it, access it
    inside your step function like this:

    ```python
    from zenml.environment import Environment

    @step
    def my_step(...):
        env = Environment().step_environment
        do_something_with(env.pipeline_name, env.run_name, env.step_name)
    ```
    """

    NAME = STEP_ENVIRONMENT_NAME

    def __init__(
        self,
        step_run_info: "StepRunInfo",
        cache_enabled: bool,
    ):
        """Initialize the environment of the currently running step.

        Args:
            step_run_info: Info about the currently running step.
            cache_enabled: Whether caching is enabled for the current step run.
        """
        super().__init__()
        self._step_run_info = step_run_info
        self._cache_enabled = cache_enabled

    @property
    def pipeline_name(self) -> str:
        """The name of the currently running pipeline.

        Returns:
            The name of the currently running pipeline.
        """
        return self._step_run_info.pipeline.name

    @property
    def run_name(self) -> str:
        """The name of the current pipeline run.

        Returns:
            The name of the current pipeline run.
        """
        return self._step_run_info.run_name

    @property
    def pipeline_run_id(self) -> str:
        """Deprecated alias for `run_name`.

        Logs a deprecation warning on every access.

        Returns:
            The name of the current pipeline run.
        """
        # The warning names this class: the old message pointed users at
        # `StepContext`, which is not the object they accessed here.
        logger.warning(
            "`StepEnvironment.pipeline_run_id` is deprecated. Use "
            "`StepEnvironment.run_name` instead."
        )
        return self.run_name

    @property
    def step_name(self) -> str:
        """The name of the currently running step.

        Returns:
            The name of the currently running step.
        """
        return self._step_run_info.pipeline_step_name

    @property
    def step_run_info(self) -> "StepRunInfo":
        """Info about the currently running step.

        Returns:
            Info about the currently running step.
        """
        return self._step_run_info

    @property
    def cache_enabled(self) -> bool:
        """Returns whether cache is enabled for the step.

        Returns:
            True if cache is enabled for the step, otherwise False.
        """
        return self._cache_enabled
cache_enabled: bool
property
readonly
Returns whether cache is enabled for the step.
Returns:
Type | Description |
---|---|
bool |
True if cache is enabled for the step, otherwise False. |
pipeline_name: str
property
readonly
The name of the currently running pipeline.
Returns:
Type | Description |
---|---|
str |
The name of the currently running pipeline. |
pipeline_run_id: str
property
readonly
Deprecated alias for `run_name`: the name of the current pipeline run. Accessing it logs a deprecation warning.
Returns:
Type | Description |
---|---|
str |
The ID of the current pipeline run. |
run_name: str
property
readonly
The name of the current pipeline run.
Returns:
Type | Description |
---|---|
str |
The name of the current pipeline run. |
step_name: str
property
readonly
The name of the currently running step.
Returns:
Type | Description |
---|---|
str |
The name of the currently running step. |
step_run_info: StepRunInfo
property
readonly
Info about the currently running step.
Returns:
Type | Description |
---|---|
StepRunInfo |
Info about the currently running step. |
__init__(self, step_run_info, cache_enabled)
special
Initialize the environment of the currently running step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_run_info |
StepRunInfo |
Info about the currently running step. |
required |
cache_enabled |
bool |
Whether caching is enabled for the current step run. |
required |
Source code in zenml/steps/step_environment.py
def __init__(
    self,
    step_run_info: "StepRunInfo",
    cache_enabled: bool,
):
    """Set up the environment component for the step that is currently running.

    Args:
        step_run_info: Information describing the running step.
        cache_enabled: Whether the current step run has caching turned on.
    """
    # Base component setup first, then the step-specific state.
    super().__init__()
    self._step_run_info, self._cache_enabled = step_run_info, cache_enabled
step_invocation
Step invocation class definition.
StepInvocation
Step invocation class.
Source code in zenml/steps/step_invocation.py
class StepInvocation:
    """Step invocation class.

    Represents one invocation of a step inside a pipeline, together with the
    artifacts, parameters and upstream steps of that invocation.
    """

    def __init__(
        self,
        id: str,
        step: "BaseStep",
        input_artifacts: Dict[str, "StepArtifact"],
        external_artifacts: Dict[str, "ExternalArtifact"],
        parameters: Dict[str, Any],
        upstream_steps: Set[str],
        pipeline: "Pipeline",
    ) -> None:
        """Initialize a step invocation.

        Args:
            id: The invocation ID.
            step: The step that is represented by the invocation.
            input_artifacts: The input artifacts for the invocation.
            external_artifacts: The external artifacts for the invocation.
            parameters: The parameters for the invocation.
            upstream_steps: The upstream steps for the invocation.
            pipeline: The parent pipeline of the invocation.
        """
        self.id = id
        self.step = step
        self.input_artifacts = input_artifacts
        self.external_artifacts = external_artifacts
        self.parameters = parameters
        # Stored under a separate name because `upstream_steps` is a read-only
        # property that merges these with upstream steps set via `.after()`.
        self.invocation_upstream_steps = upstream_steps
        self.pipeline = pipeline

    @property
    def upstream_steps(self) -> Set[str]:
        """The upstream steps of the invocation.

        Combines the upstream steps passed to this invocation with the
        invocation IDs of upstream steps set on the step instance via
        `step.after(...)`. The latter are validated on every access.

        Returns:
            The upstream steps of the invocation.
        """
        return self.invocation_upstream_steps.union(
            self._get_and_validate_step_upstream_steps()
        )

    def _get_and_validate_step_upstream_steps(self) -> Set[str]:
        """Validates the upstream steps defined on the step instance.

        This is only allowed in legacy pipelines when calling `step.after(...)`
        and we need to make sure that both the upstream and downstream steps
        of such a relationship are only invoked once inside a pipeline.

        Returns:
            The invocation IDs of the upstream steps defined on the step
            instance.

        Raises:
            RuntimeError: If the step or one of its upstream steps is invoked
                more than once in the pipeline.
        """

        def _verify_single_invocation(step: "BaseStep") -> str:
            # Steps are matched by identity: the same step object backs every
            # invocation created by calling it repeatedly.
            invocations = {
                invocation
                for invocation in self.pipeline.invocations.values()
                if invocation.step is step
            }
            if len(invocations) > 1:
                raise RuntimeError(
                    "Setting upstream steps for a step using "
                    "`step_1.after(step_2)` is not allowed in combination "
                    "with calling one of the two steps multiple times."
                )
            # NOTE(review): if `step` has no invocation in the pipeline, the
            # empty set's `pop()` raises KeyError here.
            return invocations.pop().id

        if self.step.upstream_steps:
            # If the step has upstream steps, make sure it only got invoked once
            _verify_single_invocation(step=self.step)

        return {
            _verify_single_invocation(step=upstream_step)
            for upstream_step in self.step.upstream_steps
        }

    def finalize(self) -> "StepConfiguration":
        """Finalizes a step invocation.

        This will validate the upstream steps and run final configurations on
        the step that is represented by the invocation.

        Returns:
            The finalized step configuration.
        """
        # Validate the upstream steps for legacy .after() calls
        self._get_and_validate_step_upstream_steps()
        self.step.configure(parameters=self.parameters)

        # External artifacts are uploaded (if needed) and referenced by the
        # value `upload_if_necessary()` returns.
        external_artifact_ids = {
            key: artifact.upload_if_necessary()
            for key, artifact in self.external_artifacts.items()
        }

        return self.step._finalize_configuration(
            input_artifacts=self.input_artifacts,
            external_artifacts=external_artifact_ids,
        )
upstream_steps: Set[str]
property
readonly
The upstream steps of the invocation.
Returns:
Type | Description |
---|---|
Set[str] |
The upstream steps of the invocation. |
__init__(self, id, step, input_artifacts, external_artifacts, parameters, upstream_steps, pipeline)
special
Initialize a step invocation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id |
str |
The invocation ID. |
required |
step |
BaseStep |
The step that is represented by the invocation. |
required |
input_artifacts |
Dict[str, StepArtifact] |
The input artifacts for the invocation. |
required |
external_artifacts |
Dict[str, ExternalArtifact] |
The external artifacts for the invocation. |
required |
parameters |
Dict[str, Any] |
The parameters for the invocation. |
required |
upstream_steps |
Set[str] |
The upstream steps for the invocation. |
required |
pipeline |
Pipeline |
The parent pipeline of the invocation. |
required |
Source code in zenml/steps/step_invocation.py
def __init__(
    self,
    id: str,
    step: "BaseStep",
    input_artifacts: Dict[str, "StepArtifact"],
    external_artifacts: Dict[str, "ExternalArtifact"],
    parameters: Dict[str, Any],
    upstream_steps: Set[str],
    pipeline: "Pipeline",
) -> None:
    """Record a single invocation of a step within a pipeline.

    Args:
        id: Identifier of this invocation.
        step: Step object that this invocation runs.
        input_artifacts: Artifacts feeding the invocation's inputs.
        external_artifacts: External artifacts used by the invocation.
        parameters: Parameter values for the invocation.
        upstream_steps: Upstream steps declared for this invocation.
        pipeline: Pipeline that owns the invocation.
    """
    # Identity of the invocation and the step behind it.
    self.id, self.step = id, step
    # Inputs of the invocation.
    self.input_artifacts = input_artifacts
    self.external_artifacts = external_artifacts
    self.parameters = parameters
    # Not stored as `upstream_steps`: on StepInvocation that name is a
    # read-only property that merges in legacy `.after()` relationships.
    self.invocation_upstream_steps = upstream_steps
    self.pipeline = pipeline
finalize(self)
Finalizes a step invocation.
This will validate the upstream steps and run final configurations on the step that is represented by the invocation.
Returns:
Type | Description |
---|---|
StepConfiguration |
The finalized step configuration. |
Source code in zenml/steps/step_invocation.py
def finalize(self) -> "StepConfiguration":
    """Finalizes a step invocation.

    This will validate the upstream steps and run final configurations on the
    step that is represented by the invocation.

    Returns:
        The finalized step configuration.
    """
    # Validate the upstream steps for legacy .after() calls
    self._get_and_validate_step_upstream_steps()
    self.step.configure(parameters=self.parameters)

    # External artifacts are uploaded (if needed) and referenced by the value
    # `upload_if_necessary()` returns.
    external_artifact_ids = {
        key: artifact.upload_if_necessary()
        for key, artifact in self.external_artifacts.items()
    }

    return self.step._finalize_configuration(
        input_artifacts=self.input_artifacts,
        external_artifacts=external_artifact_ids,
    )
step_output
Step output class.
Output
A named tuple with a default name that cannot be overridden.
Source code in zenml/steps/step_output.py
class Output(object):
    """A named tuple with a default name that cannot be overridden."""

    def __init__(self, **kwargs: Type[Any]):
        """Initializes the output.

        Args:
            **kwargs: The outputs, as output names mapped to output types.
        """
        # TODO [ENG-161]: do we even need the named tuple here or is
        # a list of tuples (name, Type) sufficient?
        # Pass the fields as a list: the keyword-argument form of
        # `typing.NamedTuple` is deprecated since Python 3.13.
        self.outputs = NamedTuple("ZenOutput", list(kwargs.items()))  # type: ignore[misc]

    def items(self) -> Iterator[Tuple[str, Type[Any]]]:
        """Yields a tuple of type (output_name, output_type).

        Yields:
            A tuple of type (output_name, output_type).
        """
        yield from self.outputs.__annotations__.items()
__init__(self, **kwargs)
special
Initializes the output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Type[Any] |
The output values. |
{} |
Source code in zenml/steps/step_output.py
def __init__(self, **kwargs: Type[Any]):
    """Initializes the output.

    Args:
        **kwargs: The outputs, as output names mapped to output types.
    """
    # TODO [ENG-161]: do we even need the named tuple here or is
    # a list of tuples (name, Type) sufficient?
    # Pass the fields as a list: the keyword-argument form of
    # `typing.NamedTuple` is deprecated since Python 3.13.
    self.outputs = NamedTuple("ZenOutput", list(kwargs.items()))  # type: ignore[misc]
items(self)
Yields a tuple of type (output_name, output_type).
Yields:
Type | Description |
---|---|
Iterator[Tuple[str, Type[Any]]] |
A tuple of type (output_name, output_type). |
Source code in zenml/steps/step_output.py
def items(self) -> Iterator[Tuple[str, Type[Any]]]:
    """Iterate over the declared outputs as ``(output_name, output_type)`` pairs.

    Yields:
        One ``(output_name, output_type)`` tuple per declared output.
    """
    annotations = self.outputs.__annotations__
    for output_name, output_type in annotations.items():
        yield output_name, output_type
utils
Utility functions and classes to run ZenML steps.
get_args(obj)
Get arguments of a Union type annotation.
Examples:
get_args(Union[int, str]) == (int, str)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any |
The annotation. |
required |
Returns:
Type | Description |
---|---|
Tuple[Any, ...] |
The args of the Union annotation. |
Source code in zenml/steps/utils.py
def get_args(obj: Any) -> Tuple[Any, ...]:
    """Return the members of a `Union` annotation with generics unwrapped.

    Each member for which `pydantic_typing.get_origin` reports an origin is
    replaced by that origin; all other members are kept as they are.

    Example:
        `get_args(Union[int, str]) == (int, str)`

    Args:
        obj: The annotation.

    Returns:
        The args of the Union annotation.
    """
    resolved = []
    for member in pydantic_typing.get_args(obj):
        origin = pydantic_typing.get_origin(member)
        resolved.append(origin if origin else member)
    return tuple(resolved)
parse_return_type_annotations(return_annotation)
Parse the returns of a step function into a dict of resolved types.
Called within `BaseStepMeta.__new__()` to define `cls.OUTPUT_SIGNATURE`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
return_annotation |
Any |
Return annotation of the step function. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Output signature of the new step class. |
Source code in zenml/steps/utils.py
def parse_return_type_annotations(return_annotation: Any) -> Dict[str, Any]:
    """Build the output signature of a step from its return annotation.

    Called within `BaseStepMeta.__new__()` to define `cls.OUTPUT_SIGNATURE`.

    Args:
        return_annotation: Return annotation of the step function.

    Returns:
        Mapping of output name to resolved output type; empty when the
        annotation is `None`.
    """
    if return_annotation is None:
        return {}

    # Anything that is not already an `Output` is a single output stored
    # under the default output name.
    outputs = (
        return_annotation
        if isinstance(return_annotation, Output)
        else Output(**{SINGLE_RETURN_OUT_NAME: return_annotation})
    )

    signature: Dict[str, Any] = {}
    for output_name, output_type in outputs.items():
        signature[output_name] = resolve_type_annotation(output_type)
    return signature
resolve_type_annotation(obj)
Returns the non-generic class for generic aliases of the typing module.
If the input is no generic typing alias, the input itself is returned.
Example: if the input object is `typing.Dict`, this method will return the concrete class `dict`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any |
The object to resolve. |
required |
Returns:
Type | Description |
---|---|
Any |
The non-generic class for generic aliases of the typing module. |
Source code in zenml/steps/utils.py
def resolve_type_annotation(obj: Any) -> Any:
    """Map a generic `typing` alias to its concrete (non-generic) class.

    Inputs without an origin are returned unchanged, and so are annotations
    whose origin `pydantic_typing.is_union` recognizes as a `Union`.

    Example: `typing.Dict` resolves to the concrete class `dict`.

    Args:
        obj: The object to resolve.

    Returns:
        The non-generic class for generic aliases of the typing module.
    """
    origin = pydantic_typing.get_origin(obj)
    if not origin:
        origin = obj
    return obj if pydantic_typing.is_union(origin) else origin