Steps
zenml.steps
special
Initializer for ZenML steps.
A step is a single piece or stage of a ZenML pipeline. Think of each step as being one of the nodes of a Directed Acyclic Graph (or DAG). Steps are responsible for one aspect of processing or interacting with the data / artifacts in the pipeline.
Conceptually, a Step is a discrete and independent part of a pipeline that is responsible for one particular aspect of data manipulation inside a ZenML pipeline.
Steps can be subclassed from the `BaseStep` class, or used via our `@step` decorator.
base_step
Base Step for ZenML.
BaseStep
Abstract base class for all ZenML steps.
Source code in zenml/steps/base_step.py
class BaseStep:
    """Abstract base class for all ZenML steps.

    Subclasses implement the abstract `entrypoint` method, which holds the
    step logic. Each instance keeps a `PartialStepConfiguration` (exposed as
    `configuration`) that is created in `__init__` and updated through
    `configure(...)` or, on a copy, `with_options(...)`.
    """
def __init__(
    self,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional[Dict[str, Any]] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    model: Optional["Model"] = None,
    retry: Optional[StepRetryConfig] = None,
) -> None:
    """Create a step and apply its initial configuration.

    Args:
        name: Name of the step. Defaults to the class name.
        enable_cache: Whether caching is enabled for this step.
        enable_artifact_metadata: Whether artifact metadata is enabled
            for this step.
        enable_artifact_visualization: Whether artifact visualization is
            enabled for this step.
        enable_step_logs: Whether step logs are enabled for this step.
        experiment_tracker: Experiment tracker used by this step.
        step_operator: Step operator used by this step.
        parameters: Values for the entrypoint function parameters.
        output_materializers: Materializers for the step outputs. A dict
            maps output names (a subset of the step's outputs) to
            materializers. A single value (type or string) applies to
            every output.
        settings: Settings for this step.
        extra: Extra configuration values for this step.
        on_failure: Hook called when the step fails: a function taking a
            single `BaseException` argument, or the source path of such a
            function (e.g. `module.my_function`).
        on_success: Hook called when the step succeeds: a function without
            arguments, or the source path of such a function.
        model: Model version configuration for the Model Control Plane.
        retry: Retry configuration for failed executions of this step.
    """
    from zenml.config.step_configurations import PartialStepConfiguration

    # `__call__` takes `id` and `after` as its own keyword arguments, which
    # is presumably why the entrypoint may not use those names.
    self.entrypoint_definition = validate_entrypoint_function(
        self.entrypoint, reserved_arguments=["after", "id"]
    )

    step_name = name or self.__class__.__name__

    # Anything other than an explicit `False` is logged as "enabled".
    toggle_messages = (
        ("Step `%s`: Caching %s.", enable_cache),
        ("Step `%s`: Artifact metadata %s.", enable_artifact_metadata),
        (
            "Step `%s`: Artifact visualization %s.",
            enable_artifact_visualization,
        ),
        ("Step `%s`: logs %s.", enable_step_logs),
    )
    for message, flag in toggle_messages:
        state = "disabled" if flag is False else "enabled"
        logger.debug(message, step_name, state)

    if model is not None:
        model_context = {"model": model.name, "version": model.version}
        logger.debug(
            "Step `%s`: Is in Model context %s.", step_name, model_context
        )

    self._configuration = PartialStepConfiguration(
        name=step_name,
        enable_cache=enable_cache,
        enable_artifact_metadata=enable_artifact_metadata,
        enable_artifact_visualization=enable_artifact_visualization,
        enable_step_logs=enable_step_logs,
    )
    self.configure(
        experiment_tracker=experiment_tracker,
        step_operator=step_operator,
        output_materializers=output_materializers,
        parameters=parameters,
        settings=settings,
        extra=extra,
        on_failure=on_failure,
        on_success=on_success,
        model=model,
        retry=retry,
    )
    notebook_utils.try_to_save_notebook_cell_code(self.source_object)
@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Abstract method for core step logic.

    Subclasses override this with the actual step function. Its signature
    is inspected in `__init__` (via `validate_entrypoint_function`) and the
    resulting definition provides the step's inputs and outputs.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.

    Returns:
        The output of the step.
    """
@classmethod
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
    """Load a step from a source or import path.

    Args:
        source: Source (or import path string) of a step instance or of a
            step class.

    Returns:
        The loaded step instance. A step class is instantiated without
        arguments.

    Raises:
        ValueError: If the source resolves to neither a step instance nor
            a step class.
    """
    loaded = source_utils.load(source)
    if isinstance(loaded, BaseStep):
        return loaded
    if isinstance(loaded, type) and issubclass(loaded, BaseStep):
        return loaded()
    raise ValueError("Invalid step source.")
def resolve(self) -> Source:
    """Resolve the source of this step's class.

    Returns:
        The source of the step's class.
    """
    step_class = self.__class__
    return source_utils.resolve(step_class)
@property
def source_object(self) -> Any:
    """Object whose source represents this step.

    Returns:
        The class of this step instance.
    """
    step_class = self.__class__
    return step_class
@property
def source_code(self) -> str:
    """Source code of this step.

    Returns:
        The source text of `self.source_object`.
    """
    target = self.source_object
    return inspect.getsource(target)
@property
def docstring(self) -> Optional[str]:
    """Docstring of this step.

    Returns:
        The `__doc__` of this step instance, or `None` if there is none.
    """
    doc = self.__doc__
    return doc
@property
def caching_parameters(self) -> Dict[str, Any]:
    """Values that feed into the cache key of this step.

    Contains the hashed source code of the step and, for every output with
    configured materializers, an MD5 digest over the hashed source code of
    those materializers.

    Returns:
        A dictionary containing the caching parameters.
    """
    result: Dict[str, Any] = {
        CODE_HASH_PARAMETER_NAME: source_code_utils.get_hashed_source_code(
            self.source_object
        )
    }
    for output_name, output_config in self.configuration.outputs.items():
        materializer_sources = output_config.materializer_source
        if not materializer_sources:
            continue
        digest = hashlib.md5()  # nosec
        for materializer_source in materializer_sources:
            materializer_class = source_utils.load(materializer_source)
            materializer_hash = source_code_utils.get_hashed_source_code(
                materializer_class
            )
            digest.update(materializer_hash.encode())
        result[f"{output_name}_materializer_source"] = digest.hexdigest()
    return result
def _parse_call_args(
    self, *args: Any, **kwargs: Any
) -> Tuple[
    Dict[str, "StepArtifact"],
    Dict[str, "ExternalArtifact"],
    Dict[str, "ModelVersionDataLazyLoader"],
    Dict[str, "ClientLazyLoader"],
    Dict[str, Any],
    Dict[str, Any],
]:
    """Parses the call args for the step entrypoint.

    Every bound argument is validated against the entrypoint definition
    and sorted by the type of its value: `StepArtifact`s, `ExternalArtifact`s,
    lazy model-version artifacts/metadata (wrapped in a
    `ModelVersionDataLazyLoader`), `ClientLazyLoader`s, and everything else
    as a plain parameter. Afterwards, entrypoint default values are
    collected for arguments that are not covered otherwise.

    Args:
        *args: Entrypoint function arguments.
        **kwargs: Entrypoint function keyword arguments.

    Raises:
        StepInterfaceError: If invalid function arguments were passed.

    Returns:
        A tuple of six dictionaries keyed by argument name: the step
        artifacts, the external artifacts, the model version
        artifact/metadata loaders, the client lazy loaders, the parameters
        passed in this call, and the default parameters.
    """
    from zenml.artifacts.external_artifact import ExternalArtifact
    from zenml.model.lazy_load import ModelVersionDataLazyLoader
    from zenml.models.v2.core.artifact_version import (
        LazyArtifactVersionResponse,
    )
    from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse

    signature = inspect.signature(self.entrypoint, follow_wrapped=True)

    try:
        # `bind_partial` allows missing arguments; they may be provided
        # through the step configuration instead.
        bound_args = signature.bind_partial(*args, **kwargs)
    except TypeError as e:
        raise StepInterfaceError(
            f"Wrong arguments when calling step '{self.name}': {e}"
        ) from e

    artifacts = {}
    external_artifacts = {}
    model_artifacts_or_metadata = {}
    client_lazy_loaders = {}
    parameters = {}
    default_parameters = {}

    # First pass: only the arguments that were explicitly passed.
    for key, value in bound_args.arguments.items():
        self.entrypoint_definition.validate_input(key=key, value=value)

        if isinstance(value, StepArtifact):
            artifacts[key] = value
            if key in self.configuration.parameters:
                logger.warning(
                    "Got duplicate value for step input %s, using value "
                    "provided as artifact.",
                    key,
                )
        elif isinstance(value, ExternalArtifact):
            external_artifacts[key] = value
            if not value.id:
                # If the external artifact references a fixed artifact by
                # ID, caching behaves as expected.
                logger.warning(
                    "Using an external artifact as step input currently "
                    "invalidates caching for the step and all downstream "
                    "steps. Future releases will introduce hashing of "
                    "artifacts which will improve this behavior."
                )
        elif isinstance(value, LazyArtifactVersionResponse):
            # Lazy reference to a model version artifact: wrapped in a
            # loader without a metadata name.
            model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader(
                model_name=value.lazy_load_model_name,
                model_version=value.lazy_load_model_version,
                artifact_name=value.lazy_load_name,
                artifact_version=value.lazy_load_version,
                metadata_name=None,
            )
        elif isinstance(value, LazyRunMetadataResponse):
            # Lazy reference to run metadata: wrapped in a loader that also
            # carries the metadata name.
            model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader(
                model_name=value.lazy_load_model_name,
                model_version=value.lazy_load_model_version,
                artifact_name=value.lazy_load_artifact_name,
                artifact_version=value.lazy_load_artifact_version,
                metadata_name=value.lazy_load_metadata_name,
            )
        elif isinstance(value, ClientLazyLoader):
            client_lazy_loaders[key] = value
        else:
            # Any other value is treated as a plain parameter.
            parameters[key] = value

    # Above we iterated over the provided arguments which should overwrite
    # any parameters previously defined on the step instance. Now we apply
    # the default values on the entrypoint function and add those as
    # parameters for any argument that has no value yet. If we were to do
    # that in the above loop, we would overwrite previously configured
    # parameters with the default values.
    bound_args.apply_defaults()

    # Second pass: `bound_args.arguments` now holds the passed arguments
    # plus the defaults. Passed arguments are validated a second time here.
    for key, value in bound_args.arguments.items():
        self.entrypoint_definition.validate_input(key=key, value=value)
        # NOTE(review): keys collected in `parameters` (plain values passed
        # in this call) are not excluded here, so they are also copied into
        # `default_parameters` - confirm this is what the consumer of both
        # dicts expects.
        if (
            key not in artifacts
            and key not in external_artifacts
            and key not in model_artifacts_or_metadata
            and key not in self.configuration.parameters
            and key not in client_lazy_loaders
        ):
            default_parameters[key] = value

    return (
        artifacts,
        external_artifacts,
        model_artifacts_or_metadata,
        client_lazy_loaders,
        parameters,
        default_parameters,
    )
def __call__(
    self,
    *args: Any,
    id: Optional[str] = None,
    after: Union[str, Sequence[str], None] = None,
    **kwargs: Any,
) -> Any:
    """Handle a call of the step.

    This method does one of two things:
    * If there is an active pipeline context, it adds an invocation of the
      step instance to the pipeline and returns references to the outputs.
    * If no pipeline is active, the step runs right away:
        - If `ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK` is enabled, or the
          call happens while another step is executing, the entrypoint
          function is called directly.
        - If `constants.SHOULD_PREVENT_PIPELINE_EXECUTION` is set, nothing
          runs and `None` is returned.
        - Otherwise the step is run as a single-step pipeline.

    Args:
        *args: Entrypoint function arguments.
        id: Invocation ID to use.
        after: Upstream steps for the invocation.
        **kwargs: Entrypoint function keyword arguments.

    Returns:
        Inside a pipeline: a `StepArtifact` if the entrypoint declares
        exactly one output, otherwise a list of `StepArtifact`s (one per
        output). Outside a pipeline: the result of the direct entrypoint
        call or of `run_as_single_step_pipeline`, or `None` if execution
        is prevented.
    """
    from zenml.pipelines.pipeline_definition import Pipeline

    if not Pipeline.ACTIVE_PIPELINE:
        from zenml import constants, get_step_context

        # If the environment variable was set to explicitly not run on the
        # stack, we do that.
        run_without_stack = handle_bool_env_var(
            ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK, default=False
        )
        if run_without_stack:
            return self.call_entrypoint(*args, **kwargs)

        try:
            get_step_context()
        except RuntimeError:
            # No step context available: we are not inside another step.
            pass
        else:
            # We're currently inside the execution of a different step
            # -> We don't want to launch another single step pipeline here,
            # but instead just call the step function
            return self.call_entrypoint(*args, **kwargs)

        if constants.SHOULD_PREVENT_PIPELINE_EXECUTION:
            logger.info(
                "Preventing execution of step '%s'.",
                self.name,
            )
            # Nothing runs; the caller receives `None`.
            return

        return run_as_single_step_pipeline(self, *args, **kwargs)

    # A pipeline is active: record an invocation instead of executing.
    (
        input_artifacts,
        external_artifacts,
        model_artifacts_or_metadata,
        client_lazy_loaders,
        parameters,
        default_parameters,
    ) = self._parse_call_args(*args, **kwargs)

    # Steps producing input artifacts are upstream dependencies; `after`
    # adds explicit ones. `str` is checked first because a string is also
    # a `Sequence`.
    upstream_steps = {
        artifact.invocation_id for artifact in input_artifacts.values()
    }
    if isinstance(after, str):
        upstream_steps.add(after)
    elif isinstance(after, Sequence):
        upstream_steps = upstream_steps.union(after)

    invocation_id = Pipeline.ACTIVE_PIPELINE.add_step_invocation(
        step=self,
        input_artifacts=input_artifacts,
        external_artifacts=external_artifacts,
        model_artifacts_or_metadata=model_artifacts_or_metadata,
        client_lazy_loaders=client_lazy_loaders,
        parameters=parameters,
        default_parameters=default_parameters,
        upstream_steps=upstream_steps,
        custom_id=id,
        # `allow_id_suffix` is only enabled when no explicit `id` was given.
        allow_id_suffix=not id,
    )

    # One `StepArtifact` reference per declared entrypoint output.
    outputs = []
    for key, annotation in self.entrypoint_definition.outputs.items():
        output = StepArtifact(
            invocation_id=invocation_id,
            output_name=key,
            annotation=annotation,
            pipeline=Pipeline.ACTIVE_PIPELINE,
        )
        outputs.append(output)
    return outputs[0] if len(outputs) == 1 else outputs
def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Validate the arguments with pydantic and run the entrypoint.

    Args:
        *args: Positional arguments for the entrypoint function.
        **kwargs: Keyword arguments for the entrypoint function.

    Returns:
        Whatever the entrypoint function returns.

    Raises:
        StepInterfaceError: If pydantic rejects the arguments.
    """
    validation_config = ConfigDict(arbitrary_types_allowed=True)
    try:
        checked_kwargs = pydantic_utils.validate_function_args(
            self.entrypoint, validation_config, *args, **kwargs
        )
    except ValidationError as err:
        raise StepInterfaceError(
            "Invalid step function entrypoint arguments. Check out the "
            "pydantic error above for more details."
        ) from err
    return self.entrypoint(**checked_kwargs)
@property
def name(self) -> str:
    """Name of the step.

    Returns:
        The `name` stored in the step configuration.
    """
    config = self.configuration
    return config.name
@property
def enable_cache(self) -> Optional[bool]:
    """Caching flag of the step.

    Returns:
        The `enable_cache` value from the step configuration (may be
        `None` when not set).
    """
    config = self.configuration
    return config.enable_cache
@property
def configuration(self) -> "PartialStepConfiguration":
    """Current configuration of the step.

    Returns:
        The `PartialStepConfiguration` held by this step.
    """
    current = self._configuration
    return current
def configure(
    self: T,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional[Dict[str, Any]] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    model: Optional["Model"] = None,
    merge: bool = True,
    retry: Optional[StepRetryConfig] = None,
) -> T:
    """Update the configuration of this step in place.

    Only arguments that are not `None` become part of the update. With
    `merge=True`, dictionary configurations are merged into the existing
    ones; with `merge=False` they replace them:

    * `merge==True`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=True)
        step.configuration.extra # {"key1": 1, "key2": 2}
    * `merge==False`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=False)
        step.configuration.extra # {"key2": 2}

    Args:
        enable_cache: Whether caching is enabled for this step.
        enable_artifact_metadata: Whether artifact metadata is enabled
            for this step.
        enable_artifact_visualization: Whether artifact visualization is
            enabled for this step.
        enable_step_logs: Whether step logs are enabled for this step.
        experiment_tracker: Experiment tracker used by this step.
        step_operator: Step operator used by this step.
        parameters: Values for the entrypoint function parameters.
        output_materializers: Materializers for the step outputs. A dict
            maps output names (a subset of the step's outputs) to
            materializers. A single value (type or string) applies to
            every output.
        settings: Settings for this step.
        extra: Extra configuration values for this step.
        on_failure: Hook called when the step fails: a function taking a
            single `BaseException` argument, or the source path of such a
            function (e.g. `module.my_function`).
        on_success: Hook called when the step succeeds: a function without
            arguments, or the source path of such a function.
        model: Model version configuration for the Model Control Plane.
        merge: Merge dictionary configurations such as `parameters` and
            `settings` into the existing ones (`True`) or overwrite them
            (`False`). See the examples above.
        retry: Retry configuration for failed executions of this step.

    Returns:
        The step instance this method was called on, allowing chaining.
    """
    from zenml.config.step_configurations import StepConfigurationUpdate
    from zenml.hooks.hook_validators import resolve_and_validate_hook

    def _as_source(value: Union[str, Source, Type[Any]]) -> Source:
        # Strings are import paths, `Source` objects are kept as they are
        # and anything else (e.g. a materializer class) gets resolved.
        if isinstance(value, str):
            return Source.from_import_path(value)
        if isinstance(value, Source):
            return value
        return source_utils.resolve(value)

    def _as_source_tuple(value: Any) -> Tuple[Source, ...]:
        # A string is one value, not a sequence of characters.
        if isinstance(value, Sequence) and not isinstance(value, str):
            return tuple(_as_source(item) for item in value)
        return (_as_source(value),)

    outputs: Dict[str, Dict[str, Tuple[Source, ...]]] = defaultdict(dict)
    output_names = set(self.entrypoint_definition.outputs)

    if output_materializers:
        if not isinstance(output_materializers, Mapping):
            # A single materializer specification is shared by all outputs.
            output_materializers = dict.fromkeys(
                output_names, _as_source_tuple(output_materializers)
            )
        for output_name, materializer_spec in output_materializers.items():
            outputs[output_name]["materializer_source"] = _as_source_tuple(
                materializer_spec
            )

    # Hooks are resolved (and validated) only when provided.
    failure_hook_source = (
        resolve_and_validate_hook(on_failure) if on_failure else None
    )
    success_hook_source = (
        resolve_and_validate_hook(on_success) if on_success else None
    )

    update_values = dict_utils.remove_none_values(
        {
            "enable_cache": enable_cache,
            "enable_artifact_metadata": enable_artifact_metadata,
            "enable_artifact_visualization": enable_artifact_visualization,
            "enable_step_logs": enable_step_logs,
            "experiment_tracker": experiment_tracker,
            "step_operator": step_operator,
            "parameters": parameters,
            "settings": settings,
            "outputs": outputs or None,
            "extra": extra,
            "failure_hook_source": failure_hook_source,
            "success_hook_source": success_hook_source,
            "model": model,
            "retry": retry,
        }
    )
    self._apply_configuration(
        StepConfigurationUpdate(**update_values), merge=merge
    )
    return self
def with_options(
    self,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional[Dict[str, Any]] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    model: Optional["Model"] = None,
    merge: bool = True,
    retry: Optional["StepRetryConfig"] = None,
) -> "BaseStep":
    """Copies the step and applies the given configurations.

    The original step is left untouched; all options are forwarded to
    `configure(...)` on the copy.

    Args:
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        enable_step_logs: If step logs should be enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with a single argument of type `BaseException`, or
            a source path to such a function (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with no arguments, or a source path to such a
            function (e.g. `module.my_function`).
        model: configuration of the model version in the Model Control Plane.
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False` the given configurations will
            overwrite all existing ones. See `BaseStep.configure(...)` for
            an example.
        retry: Configuration for retrying the step in case of failure.
            Added as the last parameter so existing positional calls keep
            working.

    Returns:
        The copied step instance.
    """
    step_copy = self.copy()
    step_copy.configure(
        enable_cache=enable_cache,
        enable_artifact_metadata=enable_artifact_metadata,
        enable_artifact_visualization=enable_artifact_visualization,
        enable_step_logs=enable_step_logs,
        experiment_tracker=experiment_tracker,
        step_operator=step_operator,
        parameters=parameters,
        output_materializers=output_materializers,
        settings=settings,
        extra=extra,
        on_failure=on_failure,
        on_success=on_success,
        model=model,
        merge=merge,
        retry=retry,
    )
    return step_copy
def copy(self) -> "BaseStep":
    """Create an independent copy of this step.

    Returns:
        A deep copy of the step, including its configuration.
    """
    duplicate = copy.deepcopy(self)
    return duplicate
def _apply_configuration(
    self,
    config: "StepConfigurationUpdate",
    merge: bool = True,
    runtime_parameters: Optional[Dict[str, Any]] = None,
) -> None:
    """Applies an update to the step configuration.

    The update is validated first, then merged into (or applied over) the
    current configuration via `pydantic_utils.update_model`.

    Args:
        config: The configuration update.
        merge: Whether to merge the updates with the existing configuration
            or not. See the `BaseStep.configure(...)` method for a detailed
            explanation.
        runtime_parameters: Dictionary of parameters passed to a step at
            runtime. `None` (the default) means no runtime parameters.
    """
    # A fresh dict per call instead of a shared mutable `{}` default.
    if runtime_parameters is None:
        runtime_parameters = {}

    self._validate_configuration(config, runtime_parameters)
    self._configuration = pydantic_utils.update_model(
        self._configuration, update=config, recursive=merge
    )
    logger.debug("Updated step configuration:")
    logger.debug(self._configuration)
def _validate_configuration(
    self,
    config: "StepConfigurationUpdate",
    runtime_parameters: Dict[str, Any],
) -> None:
    """Check a configuration update before it is applied.

    Validates the setting keys, the function parameters (including
    conflicts with runtime values) and the output configuration.

    Args:
        config: The configuration update to check.
        runtime_parameters: Parameters passed to the step at runtime.
    """
    setting_keys = list(config.settings)
    settings_utils.validate_setting_keys(setting_keys)

    self._validate_function_parameters(
        parameters=config.parameters,
        runtime_parameters=runtime_parameters,
    )
    self._validate_outputs(outputs=config.outputs)
def _validate_function_parameters(
    self,
    parameters: Dict[str, Any],
    runtime_parameters: Dict[str, Any],
) -> None:
    """Validates step function parameters.

    Every configured parameter must name an input of the step entrypoint
    and pass the entrypoint's input validation. Configured values that
    differ from the value passed for the same parameter at runtime are
    collected and reported together in one error.

    Args:
        parameters: The parameters to validate.
        runtime_parameters: Dictionary of parameters passed to a step at
            runtime.

    Raises:
        StepInterfaceError: If a configured parameter is not part of the
            step function signature.
        RuntimeError: If the step has parameters configured differently in
            configuration file and code.
    """
    if not parameters:
        return

    conflicting_parameters = {}
    for key, value in parameters.items():
        if key in runtime_parameters:
            runtime_value = runtime_parameters[key]
            if runtime_value != value:
                conflicting_parameters[key] = (value, runtime_value)

        if key in self.entrypoint_definition.inputs:
            self.entrypoint_definition.validate_input(key=key, value=value)
        else:
            raise StepInterfaceError(
                f"Unable to find parameter '{key}' in step function "
                "signature."
            )

    if conflicting_parameters:
        is_plural = "s" if len(conflicting_parameters) > 1 else ""
        # Subject/verb agreement: "parameter ... conflicts" (singular) vs.
        # "parameters ... conflict" (plural).
        verb_suffix = "" if is_plural else "s"
        msg = (
            f"Configured parameter{is_plural} for the step '{self.name}' "
            f"conflict{verb_suffix} with parameter{is_plural} passed in "
            "runtime:\n"
        )
        for key, values in conflicting_parameters.items():
            msg += f"`{key}`: config=`{values[0]}` | runtime=`{values[1]}`\n"
        # The example is built line by line so its YAML/Python nesting is
        # explicit and independent of the source indentation.
        msg += (
            "This happens, if you define values for step parameters in "
            "configuration file and pass same parameters from the code. "
            "Example:\n"
            "```\n"
            "# config.yaml\n"
            "steps:\n"
            "  step_name:\n"
            "    parameters:\n"
            "      param_name: value1\n"
            "\n"
            "# pipeline.py\n"
            "@pipeline\n"
            "def pipeline_():\n"
            '    step_name(param_name="other_value")\n'
            "```\n"
            "To avoid this consider setting step parameters only in one "
            "place (config or code).\n"
        )
        raise RuntimeError(msg)
def _validate_outputs(
    self, outputs: Mapping[str, "PartialArtifactConfiguration"]
) -> None:
    """Check the configured outputs of the step.

    Args:
        outputs: Output configurations keyed by output name.

    Raises:
        StepInterfaceError: If an output name does not exist on the step,
            or if a configured materializer source does not resolve to a
            `BaseMaterializer` subclass.
    """
    valid_names = set(self.entrypoint_definition.outputs)
    for output_name, artifact_config in outputs.items():
        if output_name not in valid_names:
            raise StepInterfaceError(
                f"Got unexpected materializers for non-existent "
                f"output '{output_name}' in step '{self.name}'. "
                f"Only materializers for the outputs "
                f"{valid_names} of this step can"
                f" be registered."
            )

        for source in artifact_config.materializer_source or ():
            if source_utils.validate_source_class(
                source, expected_class=BaseMaterializer
            ):
                continue
            raise StepInterfaceError(
                f"Materializer source `{source}` "
                f"for output '{output_name}' of step '{self.name}' "
                "does not resolve to a `BaseMaterializer` subclass."
            )
def _validate_inputs(
    self,
    input_artifacts: Dict[str, "StepArtifact"],
    external_artifacts: Dict[str, "ExternalArtifactConfiguration"],
    model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"],
    client_lazy_loaders: Dict[str, "ClientLazyLoader"],
) -> None:
    """Ensure every entrypoint input has a value from some source.

    An input counts as provided if it is an input artifact, a configured
    parameter, an external artifact, a model artifact/metadata loader or a
    client lazy loader.

    Args:
        input_artifacts: Artifacts produced by upstream steps.
        external_artifacts: External input artifacts.
        model_artifacts_or_metadata: Model artifacts or metadata.
        client_lazy_loaders: Client lazy loaders.

    Raises:
        StepInterfaceError: For the first entrypoint input without a value.
    """
    for key in self.entrypoint_definition.inputs.keys():
        providers = (
            input_artifacts,
            self.configuration.parameters,
            external_artifacts,
            model_artifacts_or_metadata,
            client_lazy_loaders,
        )
        if not any(key in provider for provider in providers):
            raise StepInterfaceError(
                f"Missing entrypoint input '{key}' in step '{self.name}'."
            )
def _finalize_configuration(
    self,
    input_artifacts: Dict[str, "StepArtifact"],
    external_artifacts: Dict[str, "ExternalArtifactConfiguration"],
    model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"],
    client_lazy_loaders: Dict[str, "ClientLazyLoader"],
) -> "StepConfiguration":
    """Finalizes the configuration after the step was called.

    Once the step was called, we know the outputs of previous steps
    and that no additional user configurations will be made. That means
    we can now collect the remaining artifact and materializer types
    as well as check for the completeness of the step function parameters.

    Note that this updates `self._configuration` in place before returning
    the validated `StepConfiguration`.

    Args:
        input_artifacts: The input artifacts of this step.
        external_artifacts: The external artifacts of this step.
        model_artifacts_or_metadata: The model artifacts or metadata of
            this step.
        client_lazy_loaders: The client lazy loaders of this step.

    Raises:
        StepInterfaceError: If explicit materializers were specified for an
            output but they do not work for the data type(s) defined by
            the type annotation.

    Returns:
        The finalized step configuration.
    """
    from zenml.config.step_configurations import (
        PartialArtifactConfiguration,
        StepConfiguration,
        StepConfigurationUpdate,
    )

    # Per-output updates (artifact config and materializer sources) that
    # are applied to the configuration further below.
    outputs: Dict[str, Dict[str, Any]] = defaultdict(dict)

    for (
        output_name,
        output_annotation,
    ) in self.entrypoint_definition.outputs.items():
        output = self._configuration.outputs.get(
            output_name, PartialArtifactConfiguration()
        )
        # Artifact config declared on the output annotation, if any.
        if artifact_config := output_annotation.artifact_config:
            outputs[output_name]["artifact_config"] = artifact_config

        if output.materializer_source:
            # The materializer source was configured by the user. We
            # validate that their configured materializer supports the
            # output type. If the output annotation is a Union, we check
            # that at least one of the specified materializers works with at
            # least one of the types in the Union. If that's not the case,
            # it would be a guaranteed failure at runtime and we fail early
            # here.
            if output_annotation.resolved_annotation is Any:
                # Nothing to check against for `Any` outputs.
                continue

            materializer_classes: List[Type["BaseMaterializer"]] = [
                source_utils.load(materializer_source)
                for materializer_source in output.materializer_source
            ]

            for data_type in output_annotation.get_output_types():
                try:
                    materializer_utils.select_materializer(
                        data_type=data_type,
                        materializer_classes=materializer_classes,
                    )
                    break
                except RuntimeError:
                    pass
            else:
                # `for ... else`: only reached when the loop finished
                # without `break`, i.e. no configured materializer could be
                # selected for any of the output types.
                materializer_strings = [
                    materializer_source.import_path
                    for materializer_source in output.materializer_source
                ]
                raise StepInterfaceError(
                    "Invalid materializers specified for output "
                    f"{output_name} of step {self.name}. None of the "
                    f"materializers ({materializer_strings}) are "
                    "able to save or load data of the type that is defined "
                    "for the output "
                    f"({output_annotation.resolved_annotation})."
                )
        else:
            # No materializer configured by the user.
            if output_annotation.resolved_annotation is Any:
                # `Any` outputs get an empty `materializer_source` and only
                # the registry's default materializer as
                # `default_materializer_source`.
                outputs[output_name]["materializer_source"] = ()
                outputs[output_name]["default_materializer_source"] = (
                    source_utils.resolve(
                        materializer_registry.get_default_materializer()
                    )
                )
                continue

            # One registry materializer per possible output type.
            materializer_sources = []
            for output_type in output_annotation.get_output_types():
                materializer_class = materializer_registry[output_type]
                materializer_sources.append(
                    source_utils.resolve(materializer_class)
                )

            outputs[output_name]["materializer_source"] = tuple(
                materializer_sources
            )

    # Replace the configured parameters with their finalized values
    # (`merge=False` overwrites instead of merging, see `configure`).
    parameters = self._finalize_parameters()
    self.configure(parameters=parameters, merge=False)
    # Every entrypoint input must be provided by now.
    self._validate_inputs(
        input_artifacts=input_artifacts,
        external_artifacts=external_artifacts,
        model_artifacts_or_metadata=model_artifacts_or_metadata,
        client_lazy_loaders=client_lazy_loaders,
    )

    values = dict_utils.remove_none_values({"outputs": outputs or None})
    config = StepConfigurationUpdate(**values)
    self._apply_configuration(config)

    # Attach the values collected at call time to the configuration.
    self._configuration = self._configuration.model_copy(
        update={
            "caching_parameters": self.caching_parameters,
            "external_input_artifacts": external_artifacts,
            "model_artifacts_or_metadata": model_artifacts_or_metadata,
            "client_lazy_loaders": client_lazy_loaders,
        }
    )
    # Dump and re-validate so the result is a complete StepConfiguration.
    return StepConfiguration.model_validate(
        self._configuration.model_dump()
    )
def _finalize_parameters(self) -> Dict[str, Any]:
    """Compute the final parameter values used to run this step.

    Configured parameters that are not entrypoint inputs are dropped.
    Values for inputs annotated with a pydantic model are instantiated
    once (to make sure all required fields are present) and stored in
    dumped form.

    Returns:
        All parameter values for running this step.
    """
    finalized: Dict[str, Any] = {}
    entrypoint_inputs = self.entrypoint_definition.inputs
    for key, value in self.configuration.parameters.items():
        if key not in entrypoint_inputs:
            continue

        annotation = resolve_type_annotation(
            entrypoint_inputs[key].annotation
        )
        is_pydantic_model = inspect.isclass(annotation) and issubclass(
            annotation, BaseModel
        )
        if is_pydantic_model:
            finalized[key] = annotation(**value).model_dump()
        else:
            finalized[key] = value
    return finalized
caching_parameters: Dict[str, Any]
property
readonly
Caching parameters for this step.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the caching parameters |
configuration: PartialStepConfiguration
property
readonly
The configuration of the step.
Returns:
Type | Description |
---|---|
PartialStepConfiguration |
The configuration of the step. |
docstring: Optional[str]
property
readonly
The docstring of this step.
Returns:
Type | Description |
---|---|
Optional[str] |
The docstring of this step. |
enable_cache: Optional[bool]
property
readonly
If caching is enabled for the step.
Returns:
Type | Description |
---|---|
Optional[bool] |
If caching is enabled for the step. |
name: str
property
readonly
The name of the step.
Returns:
Type | Description |
---|---|
str |
The name of the step. |
source_code: str
property
readonly
The source code of this step.
Returns:
Type | Description |
---|---|
str |
The source code of this step. |
source_object: Any
property
readonly
The source object of this step.
Returns:
Type | Description |
---|---|
Any |
The source object of this step. |
__call__(self, *args, id=None, after=None, **kwargs)
special
Handle a call of the step.
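This method does one of two things:

* If there is an active pipeline context, it adds an invocation of the step instance to the pipeline.
* If no pipeline is active, it runs the step. It calls the step entrypoint function directly if `ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK` is enabled or if it is called from inside another step. Otherwise it runs the step as a single-step pipeline, unless pipeline execution is being prevented, in which case it returns `None`.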
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Entrypoint function arguments. |
() |
id |
Optional[str] |
Invocation ID to use. |
None |
after |
Union[str, Sequence[str]] |
Upstream steps for the invocation. |
None |
**kwargs |
Any |
Entrypoint function keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
Any |
The outputs of the entrypoint function call. |
Source code in zenml/steps/base_step.py
def __call__(
    self,
    *args: Any,
    id: Optional[str] = None,
    after: Union[str, Sequence[str], None] = None,
    **kwargs: Any,
) -> Any:
    """Handle a call of the step.

    Without an active pipeline context the step runs immediately:
    its entrypoint function is called directly if the
    `ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK` environment variable is
    truthy or if a step context is available (i.e. we are inside another
    step's execution). Otherwise it is run as a single-step pipeline,
    unless `constants.SHOULD_PREVENT_PIPELINE_EXECUTION` is set.

    With an active pipeline context, an invocation of this step is added
    to that pipeline and `StepArtifact` placeholders for its outputs are
    returned.

    Args:
        *args: Entrypoint function arguments.
        id: Invocation ID to use.
        after: Upstream steps for the invocation.
        **kwargs: Entrypoint function keyword arguments.

    Returns:
        The outputs of the entrypoint function call. Inside a pipeline
        context: a single `StepArtifact` if the step has exactly one
        output, otherwise a list of them.
    """
    from zenml.pipelines.pipeline_definition import Pipeline

    if not Pipeline.ACTIVE_PIPELINE:
        from zenml import constants, get_step_context

        def _inside_running_step() -> bool:
            # A `RuntimeError` from `get_step_context` is taken to mean
            # that no step is currently executing.
            try:
                get_step_context()
            except RuntimeError:
                return False
            return True

        # The environment variable check is evaluated first; the step
        # context is only probed when it is not set. In either case the
        # plain function is called instead of launching a single-step
        # pipeline.
        if (
            handle_bool_env_var(
                ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK, default=False
            )
            or _inside_running_step()
        ):
            return self.call_entrypoint(*args, **kwargs)

        if constants.SHOULD_PREVENT_PIPELINE_EXECUTION:
            logger.info(
                "Preventing execution of step '%s'.",
                self.name,
            )
            return None

        return run_as_single_step_pipeline(self, *args, **kwargs)

    (
        input_artifacts,
        external_artifacts,
        model_artifacts_or_metadata,
        client_lazy_loaders,
        parameters,
        default_parameters,
    ) = self._parse_call_args(*args, **kwargs)

    # Invocations producing our input artifacts are upstream of this one;
    # `after` adds explicit upstream invocation IDs (one or several).
    upstream_steps = {
        input_artifact.invocation_id
        for input_artifact in input_artifacts.values()
    }
    if isinstance(after, str):
        upstream_steps.add(after)
    elif isinstance(after, Sequence):
        upstream_steps.update(after)

    invocation_id = Pipeline.ACTIVE_PIPELINE.add_step_invocation(
        step=self,
        input_artifacts=input_artifacts,
        external_artifacts=external_artifacts,
        model_artifacts_or_metadata=model_artifacts_or_metadata,
        client_lazy_loaders=client_lazy_loaders,
        parameters=parameters,
        default_parameters=default_parameters,
        upstream_steps=upstream_steps,
        custom_id=id,
        # An ID suffix is only allowed when no custom ID was given.
        allow_id_suffix=not id,
    )

    step_outputs = [
        StepArtifact(
            invocation_id=invocation_id,
            output_name=output_name,
            annotation=output_annotation,
            pipeline=Pipeline.ACTIVE_PIPELINE,
        )
        for output_name, output_annotation in (
            self.entrypoint_definition.outputs.items()
        )
    ]
    if len(step_outputs) == 1:
        return step_outputs[0]
    return step_outputs
__init__(self, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, enable_step_logs=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, model=None, retry=None)
special
Initializes a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
Optional[str] |
The name of the step. |
None |
enable_cache |
Optional[bool] |
If caching should be enabled for this step. |
None |
enable_artifact_metadata |
Optional[bool] |
If artifact metadata should be enabled for this step. |
None |
enable_artifact_visualization |
Optional[bool] |
If artifact visualization should be enabled for this step. |
None |
enable_step_logs |
Optional[bool] |
Enable step logs for this step. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
parameters |
Optional[Dict[str, Any]] |
Function parameters for this step |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Mapping[str, SettingsOrDict]] |
settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
Callback function in event of failure of the step. Can be a function with a single argument of type `BaseException`, or a source path to such a function (e.g. `module.my_function`). |
None |
on_success |
Optional[HookSpecification] |
Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). |
None |
model |
Optional[Model] |
configuration of the model version in the Model Control Plane. |
None |
retry |
Optional[zenml.config.retry_config.StepRetryConfig] |
Configuration for retrying the step in case of failure. |
None |
Source code in zenml/steps/base_step.py
def __init__(
    self,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional[Dict[str, Any]] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    model: Optional["Model"] = None,
    retry: Optional[StepRetryConfig] = None,
) -> None:
    """Initializes a step.

    Validates the entrypoint signature, stores the basic toggles in a
    `PartialStepConfiguration` and applies every remaining option via
    `configure`.

    Args:
        name: The name of the step. Defaults to the class name.
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        enable_step_logs: Enable step logs for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with a single argument of type `BaseException`, or
            a source path to such a function (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with no arguments, or a source path to such a
            function (e.g. `module.my_function`).
        model: configuration of the model version in the Model Control Plane.
        retry: Configuration for retrying the step in case of failure.
    """
    from zenml.config.step_configurations import PartialStepConfiguration

    # `after` and `id` are passed as reserved arguments because
    # `__call__` accepts them as its own keyword arguments.
    self.entrypoint_definition = validate_entrypoint_function(
        self.entrypoint, reserved_arguments=["after", "id"]
    )

    step_name = name or self.__class__.__name__

    # Only an explicit `False` is logged as "disabled"; `None` is logged
    # as "enabled".
    toggle_messages = (
        ("Step `%s`: Caching %s.", enable_cache),
        ("Step `%s`: Artifact metadata %s.", enable_artifact_metadata),
        (
            "Step `%s`: Artifact visualization %s.",
            enable_artifact_visualization,
        ),
        ("Step `%s`: logs %s.", enable_step_logs),
    )
    for template, flag in toggle_messages:
        logger.debug(
            template,
            step_name,
            "disabled" if flag is False else "enabled",
        )

    if model is not None:
        model_context = {"model": model.name, "version": model.version}
        logger.debug(
            "Step `%s`: Is in Model context %s.", step_name, model_context
        )

    self._configuration = PartialStepConfiguration(
        name=step_name,
        enable_cache=enable_cache,
        enable_artifact_metadata=enable_artifact_metadata,
        enable_artifact_visualization=enable_artifact_visualization,
        enable_step_logs=enable_step_logs,
    )
    self.configure(
        experiment_tracker=experiment_tracker,
        step_operator=step_operator,
        output_materializers=output_materializers,
        parameters=parameters,
        settings=settings,
        extra=extra,
        on_failure=on_failure,
        on_success=on_success,
        model=model,
        retry=retry,
    )
    # NOTE(review): judging by the helper's name this is a best-effort
    # capture of notebook cell code for the step's source object.
    notebook_utils.try_to_save_notebook_cell_code(self.source_object)
call_entrypoint(self, *args, **kwargs)
Calls the entrypoint function of the step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Entrypoint function arguments. |
() |
**kwargs |
Any |
Entrypoint function keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
Any |
The return value of the entrypoint function. |
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
If the arguments to the entrypoint function are invalid. |
Source code in zenml/steps/base_step.py
def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Run the entrypoint function with pydantic-validated arguments.

    The positional and keyword arguments are validated (and possibly
    coerced) against the entrypoint signature, allowing arbitrary types.
    The entrypoint is then called with the validated values as keyword
    arguments.

    Args:
        *args: Entrypoint function arguments.
        **kwargs: Entrypoint function keyword arguments.

    Returns:
        The return value of the entrypoint function.

    Raises:
        StepInterfaceError: If the arguments to the entrypoint function are
            invalid.
    """
    validation_config = ConfigDict(arbitrary_types_allowed=True)
    try:
        bound_arguments = pydantic_utils.validate_function_args(
            self.entrypoint, validation_config, *args, **kwargs
        )
    except ValidationError as validation_error:
        raise StepInterfaceError(
            "Invalid step function entrypoint arguments. Check out the "
            "pydantic error above for more details."
        ) from validation_error

    return self.entrypoint(**bound_arguments)
configure(self, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, enable_step_logs=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, model=None, merge=True, retry=None)
Configures the step.
Configuration merging example:
* `merge==True`:
step.configure(extra={"key1": 1})
step.configure(extra={"key2": 2}, merge=True)
step.configuration.extra # {"key1": 1, "key2": 2}
* `merge==False`:
step.configure(extra={"key1": 1})
step.configure(extra={"key2": 2}, merge=False)
step.configuration.extra # {"key2": 2}
Parameters:
Name | Type | Description | Default |
---|---|---|---|
enable_cache |
Optional[bool] |
If caching should be enabled for this step. |
None |
enable_artifact_metadata |
Optional[bool] |
If artifact metadata should be enabled for this step. |
None |
enable_artifact_visualization |
Optional[bool] |
If artifact visualization should be enabled for this step. |
None |
enable_step_logs |
Optional[bool] |
If step logs should be enabled for this step. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
parameters |
Optional[Dict[str, Any]] |
Function parameters for this step |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Mapping[str, SettingsOrDict]] |
settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
Callback function in event of failure of the step. Can be a function with a single argument of type `BaseException`, or a source path to such a function (e.g. `module.my_function`). |
None |
on_success |
Optional[HookSpecification] |
Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). |
None |
model |
Optional[Model] |
configuration of the model version in the Model Control Plane. |
None |
merge |
bool |
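If `True`, will merge the given dictionary configurations like `parameters` and `settings` with existing configurations. If `False` the given configurations will overwrite all existing ones. See the general description of this method for an example. |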
True |
retry |
Optional[zenml.config.retry_config.StepRetryConfig] |
Configuration for retrying the step in case of failure. |
None |
Returns:
Type | Description |
---|---|
~T |
The step instance that this method was called on. |
Source code in zenml/steps/base_step.py
def configure(
    self: T,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional[Dict[str, Any]] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    model: Optional["Model"] = None,
    merge: bool = True,
    retry: Optional[StepRetryConfig] = None,
) -> T:
    """Configures the step.

    Options left as `None` are dropped before building the
    `StepConfigurationUpdate`, so they do not change the existing
    configuration.

    Configuration merging example:
    * `merge==True`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=True)
        step.configuration.extra # {"key1": 1, "key2": 2}
    * `merge==False`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=False)
        step.configuration.extra # {"key2": 2}

    Args:
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        enable_step_logs: If step logs should be enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with a single argument of type `BaseException`, or
            a source path to such a function (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with no arguments, or a source path to such a
            function (e.g. `module.my_function`).
        model: configuration of the model version in the Model Control Plane.
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False` the given configurations will
            overwrite all existing ones. See the general description of this
            method for an example.
        retry: Configuration for retrying the step in case of failure.

    Returns:
        The step instance that this method was called on.
    """
    from zenml.config.step_configurations import StepConfigurationUpdate
    from zenml.hooks.hook_validators import resolve_and_validate_hook

    def _as_source(value: Union[str, Source, Type[Any]]) -> Source:
        # Strings are import paths, `Source` objects are used unchanged and
        # anything else (e.g. a class) is resolved to its source.
        if isinstance(value, str):
            return Source.from_import_path(value)
        if isinstance(value, Source):
            return value
        return source_utils.resolve(value)

    def _as_source_tuple(value: Any) -> Tuple[Source, ...]:
        # A string is a `Sequence` too, but must count as a single entry.
        if isinstance(value, Sequence) and not isinstance(value, str):
            return tuple(_as_source(item) for item in value)
        return (_as_source(value),)

    output_names = set(self.entrypoint_definition.outputs)
    outputs: Dict[str, Dict[str, Tuple[Source, ...]]] = defaultdict(dict)

    if output_materializers:
        if isinstance(output_materializers, Mapping):
            materializers_per_output = output_materializers
        else:
            # A single specification applies to every output of the step.
            shared_sources = _as_source_tuple(output_materializers)
            materializers_per_output = dict.fromkeys(
                output_names, shared_sources
            )

        for output_name, specification in materializers_per_output.items():
            outputs[output_name]["materializer_source"] = _as_source_tuple(
                specification
            )

    failure_hook_source = (
        resolve_and_validate_hook(on_failure) if on_failure else None
    )
    success_hook_source = (
        resolve_and_validate_hook(on_success) if on_success else None
    )

    requested_values = {
        "enable_cache": enable_cache,
        "enable_artifact_metadata": enable_artifact_metadata,
        "enable_artifact_visualization": enable_artifact_visualization,
        "enable_step_logs": enable_step_logs,
        "experiment_tracker": experiment_tracker,
        "step_operator": step_operator,
        "parameters": parameters,
        "settings": settings,
        "outputs": outputs or None,
        "extra": extra,
        "failure_hook_source": failure_hook_source,
        "success_hook_source": success_hook_source,
        "model": model,
        "retry": retry,
    }
    update = StepConfigurationUpdate(
        **dict_utils.remove_none_values(requested_values)
    )
    self._apply_configuration(update, merge=merge)
    return self
copy(self)
Copies the step.
Returns:
Type | Description |
---|---|
BaseStep |
The step copy. |
Source code in zenml/steps/base_step.py
def copy(self) -> "BaseStep":
    """Create an independent deep copy of this step.

    Returns:
        The step copy.
    """
    # Inside the method body `copy` is the module-level `copy` import:
    # class attributes (such as this method) are not in function scope.
    duplicate = copy.deepcopy(self)
    return duplicate
entrypoint(self, *args, **kwargs)
Abstract method for core step logic.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments passed to the step. |
() |
**kwargs |
Any |
Keyword arguments passed to the step. |
{} |
Returns:
Type | Description |
---|---|
Any |
The output of the step. |
Source code in zenml/steps/base_step.py
@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Abstract method for core step logic.

    Subclasses put the step's actual work here. `__init__` inspects this
    method's signature via `validate_entrypoint_function` (with `after`
    and `id` passed as reserved argument names), and `call_entrypoint`
    calls it with pydantic-validated keyword arguments.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.

    Returns:
        The output of the step.
    """
load_from_source(source)
classmethod
Loads a step from source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source |
Union[zenml.config.source.Source, str] |
The path to the step source. |
required |
Returns:
Type | Description |
---|---|
BaseStep |
The loaded step. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the source is not a valid step source. |
Source code in zenml/steps/base_step.py
@classmethod
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
    """Load a step from a source object or import path.

    A source pointing at a step instance yields that instance; a source
    pointing at a `BaseStep` subclass yields a new instance created
    without arguments.

    Args:
        source: The path to the step source.

    Returns:
        The loaded step.

    Raises:
        ValueError: If the source is not a valid step source.
    """
    loaded = source_utils.load(source)
    if isinstance(loaded, BaseStep):
        return loaded

    is_step_class = isinstance(loaded, type) and issubclass(loaded, BaseStep)
    if not is_step_class:
        raise ValueError("Invalid step source.")
    return loaded()
resolve(self)
Resolves the step.
Returns:
Type | Description |
---|---|
Source |
The step source. |
Source code in zenml/steps/base_step.py
def resolve(self) -> Source:
    """Resolve the source of this step's class.

    Returns:
        The step source.
    """
    step_class = self.__class__
    return source_utils.resolve(step_class)
with_options(self, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, enable_step_logs=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, model=None, merge=True)
Copies the step and applies the given configurations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
enable_cache |
Optional[bool] |
If caching should be enabled for this step. |
None |
enable_artifact_metadata |
Optional[bool] |
If artifact metadata should be enabled for this step. |
None |
enable_artifact_visualization |
Optional[bool] |
If artifact visualization should be enabled for this step. |
None |
enable_step_logs |
Optional[bool] |
If step logs should be enabled for this step. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
parameters |
Optional[Dict[str, Any]] |
Function parameters for this step |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Mapping[str, SettingsOrDict]] |
settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
Callback function in event of failure of the step. Can be a function with a single argument of type `BaseException`, or a source path to such a function (e.g. `module.my_function`). |
None |
on_success |
Optional[HookSpecification] |
Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). |
None |
model |
Optional[Model] |
configuration of the model version in the Model Control Plane. |
None |
merge |
bool |
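If `True`, will merge the given dictionary configurations like `parameters` and `settings` with existing configurations. If `False` the given configurations will overwrite all existing ones. See the general description of this method for an example. |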
True |
Returns:
Type | Description |
---|---|
BaseStep |
The copied step instance. |
Source code in zenml/steps/base_step.py
def with_options(
    self,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional[Dict[str, Any]] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    model: Optional["Model"] = None,
    merge: bool = True,
    retry: Optional["StepRetryConfig"] = None,
) -> "BaseStep":
    """Copies the step and applies the given configurations.

    The original step is left untouched: all options are applied to a
    copy obtained from `self.copy()` via `configure`. This accepts the
    same options as `configure`, including `retry`.

    Args:
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        enable_step_logs: If step logs should be enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with a single argument of type `BaseException`, or
            a source path to such a function (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with no arguments, or a source path to such a
            function (e.g. `module.my_function`).
        model: configuration of the model version in the Model Control Plane.
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False` the given configurations will
            overwrite all existing ones. See the general description of this
            method for an example.
        retry: Configuration for retrying the step in case of failure.
            Added last so existing positional callers are unaffected.

    Returns:
        The copied step instance.
    """
    step_copy = self.copy()
    step_copy.configure(
        enable_cache=enable_cache,
        enable_artifact_metadata=enable_artifact_metadata,
        enable_artifact_visualization=enable_artifact_visualization,
        enable_step_logs=enable_step_logs,
        experiment_tracker=experiment_tracker,
        step_operator=step_operator,
        parameters=parameters,
        output_materializers=output_materializers,
        settings=settings,
        extra=extra,
        on_failure=on_failure,
        on_success=on_success,
        model=model,
        merge=merge,
        retry=retry,
    )
    return step_copy
decorated_step
Internal BaseStep subclass used by the step decorator.
entrypoint_function_utils
Util functions for step and pipeline entrypoint functions.
EntrypointFunctionDefinition (tuple)
Class representing a step entrypoint function.
Attributes:
Name | Type | Description |
---|---|---|
inputs |
Dict[str, inspect.Parameter] |
The entrypoint function inputs. |
outputs |
Dict[str, zenml.steps.utils.OutputSignature] |
The entrypoint function outputs. This dictionary maps output names to output annotations. |
Source code in zenml/steps/entrypoint_function_utils.py
class EntrypointFunctionDefinition(NamedTuple):
    """Signature information of a step entrypoint function.

    Attributes:
        inputs: The entrypoint function inputs.
        outputs: The entrypoint function outputs. This dictionary maps output
            names to output annotations.
    """

    inputs: Dict[str, inspect.Parameter]
    outputs: Dict[str, OutputSignature]

    def validate_input(self, key: str, value: Any) -> None:
        """Check one value passed for an input of the entrypoint function.

        Artifact-like values (step artifacts, external artifacts, artifact
        versions, run metadata and client lazy loaders) are accepted as-is.
        Every other value is treated as a parameter: it must not target an
        `UnmaterializedArtifact` input, must be JSON serializable and must
        validate against the input's type annotation.

        Args:
            key: The key for which the input was passed
            value: The input value.

        Raises:
            KeyError: If the function has no input for the given key.
            RuntimeError: If a parameter is passed for an input that is
                annotated as an `UnmaterializedArtifact`.
            RuntimeError: If the input value is not valid for the type
                annotation provided for the function parameter.
            StepInterfaceError: If the input is a parameter and not JSON
                serializable.
        """
        from zenml.artifacts.external_artifact import ExternalArtifact
        from zenml.artifacts.unmaterialized_artifact import (
            UnmaterializedArtifact,
        )
        from zenml.client_lazy_loader import ClientLazyLoader
        from zenml.models import (
            ArtifactVersionResponse,
            RunMetadataResponse,
        )

        if key not in self.inputs:
            raise KeyError(
                f"Received step entrypoint input for invalid key {key}."
            )
        input_parameter = self.inputs[key]

        artifact_like_types = (
            StepArtifact,
            ExternalArtifact,
            ArtifactVersionResponse,
            RunMetadataResponse,
            ClientLazyLoader,
        )
        if isinstance(value, artifact_like_types):
            # No type check for artifacts: it would rule out pydantic's type
            # coercion (e.g. an `int` artifact for a `float` input).
            return

        # Past this point `value` is a parameter, not an artifact.
        if input_parameter.annotation is UnmaterializedArtifact:
            raise RuntimeError(
                "Passing parameter for input of type `UnmaterializedArtifact` "
                "is not allowed."
            )

        if not yaml_utils.is_json_serializable(value):
            raise StepInterfaceError(
                f"Argument type (`{type(value)}`) for argument "
                f"'{key}' is not JSON serializable and can not be passed as "
                "a parameter. This input can either be provided by the "
                "output of another step or as an external artifact: "
                "https://docs.zenml.io/user-guide/starter-guide/manage-artifacts#managing-artifacts-not-produced-by-zenml-pipelines"
            )

        try:
            self._validate_input_value(parameter=input_parameter, value=value)
        except ValidationError as validation_error:
            raise RuntimeError(
                f"Input validation failed for input '{input_parameter.name}': "
                f"Expected type `{input_parameter.annotation}` but received type "
                f"`{type(value)}`."
            ) from validation_error

    def _validate_input_value(
        self, parameter: inspect.Parameter, value: Any
    ) -> None:
        """Validate a value against a parameter's annotation with pydantic.

        A throwaway pydantic model with a single required field `value`,
        typed with the parameter's annotation, is created and instantiated,
        so pydantic's type coercion applies. A pydantic `ValidationError`
        propagates to the caller if the value does not fit.

        Args:
            parameter: The function parameter for which the value was provided.
            value: The input value.
        """
        single_field_model = create_model(
            "input_validation_model",
            __config__=ConfigDict(arbitrary_types_allowed=False),
            value=(parameter.annotation, ...),
        )
        single_field_model(value=value)
__getnewargs__(self)
special
Return self as a plain tuple. Used by copy and pickle.
Source code in zenml/steps/entrypoint_function_utils.py
def __getnewargs__(self):
    """Return self as a plain tuple. Used by copy and pickle.

    NOTE(review): this appears to be the helper synthesized for the
    `NamedTuple` class by the standard library's namedtuple factory.
    `_tuple` is not defined in this file and is presumably the builtin
    `tuple` bound inside that factory; confirm before relying on it.
    """
    return _tuple(self)
__new__(_cls, inputs, outputs)
special
staticmethod
Create new instance of EntrypointFunctionDefinition(inputs, outputs)
__repr__(self)
special
Return a nicely formatted representation string
Source code in zenml/steps/entrypoint_function_utils.py
def __repr__(self):
    """Return a nicely formatted representation string.

    The result is the class name followed by `repr_fmt` %-formatted with
    the tuple's fields.

    NOTE(review): `repr_fmt` is not defined in this file; presumably it
    is the field-format template built by the standard library's
    namedtuple factory. Confirm before relying on it.
    """
    return self.__class__.__name__ + repr_fmt % self
validate_input(self, key, value)
Validates an input to the step entrypoint function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
str |
The key for which the input was passed |
required |
value |
Any |
The input value. |
required |
Exceptions:
Type | Description |
---|---|
KeyError |
If the function has no input for the given key. |
RuntimeError |
If a parameter is passed for an input that is annotated as an `UnmaterializedArtifact`. |
RuntimeError |
If the input value is not valid for the type annotation provided for the function parameter. |
StepInterfaceError |
If the input is a parameter and not JSON serializable. |
Source code in zenml/steps/entrypoint_function_utils.py
def validate_input(self, key: str, value: Any) -> None:
    """Check one value passed for an input of the entrypoint function.

    Artifact-like values (step artifacts, external artifacts, artifact
    versions, run metadata and client lazy loaders) are accepted as-is.
    Every other value is treated as a parameter: it must not target an
    `UnmaterializedArtifact` input, must be JSON serializable and must
    validate against the input's type annotation.

    Args:
        key: The key for which the input was passed
        value: The input value.

    Raises:
        KeyError: If the function has no input for the given key.
        RuntimeError: If a parameter is passed for an input that is
            annotated as an `UnmaterializedArtifact`.
        RuntimeError: If the input value is not valid for the type
            annotation provided for the function parameter.
        StepInterfaceError: If the input is a parameter and not JSON
            serializable.
    """
    from zenml.artifacts.external_artifact import ExternalArtifact
    from zenml.artifacts.unmaterialized_artifact import (
        UnmaterializedArtifact,
    )
    from zenml.client_lazy_loader import ClientLazyLoader
    from zenml.models import (
        ArtifactVersionResponse,
        RunMetadataResponse,
    )

    if key not in self.inputs:
        raise KeyError(
            f"Received step entrypoint input for invalid key {key}."
        )
    input_parameter = self.inputs[key]

    artifact_like_types = (
        StepArtifact,
        ExternalArtifact,
        ArtifactVersionResponse,
        RunMetadataResponse,
        ClientLazyLoader,
    )
    if isinstance(value, artifact_like_types):
        # No type check for artifacts: it would rule out pydantic's type
        # coercion (e.g. an `int` artifact for a `float` input).
        return

    # Past this point `value` is a parameter, not an artifact.
    if input_parameter.annotation is UnmaterializedArtifact:
        raise RuntimeError(
            "Passing parameter for input of type `UnmaterializedArtifact` "
            "is not allowed."
        )

    if not yaml_utils.is_json_serializable(value):
        raise StepInterfaceError(
            f"Argument type (`{type(value)}`) for argument "
            f"'{key}' is not JSON serializable and can not be passed as "
            "a parameter. This input can either be provided by the "
            "output of another step or as an external artifact: "
            "https://docs.zenml.io/user-guide/starter-guide/manage-artifacts#managing-artifacts-not-produced-by-zenml-pipelines"
        )

    try:
        self._validate_input_value(parameter=input_parameter, value=value)
    except ValidationError as validation_error:
        raise RuntimeError(
            f"Input validation failed for input '{input_parameter.name}': "
            f"Expected type `{input_parameter.annotation}` but received type "
            f"`{type(value)}`."
        ) from validation_error
StepArtifact
Class to represent step output artifacts.
Source code in zenml/steps/entrypoint_function_utils.py
class StepArtifact:
    """Placeholder for an output artifact of a step invocation.

    Instances identify one output (`output_name`) of one invocation
    (`invocation_id`) within a pipeline and carry the output's type
    annotation.
    """

    def __init__(
        self,
        invocation_id: str,
        output_name: str,
        annotation: Any,
        pipeline: "Pipeline",
    ) -> None:
        """Create a reference to one output of a step invocation.

        Args:
            invocation_id: The ID of the invocation that produces this artifact.
            output_name: The name of the output that produces this artifact.
            annotation: The output type annotation.
            pipeline: The pipeline which the invocation is part of.
        """
        # Producing invocation and which of its outputs this refers to.
        self.invocation_id = invocation_id
        self.output_name = output_name
        # Declared output type and the pipeline the invocation belongs to.
        self.annotation = annotation
        self.pipeline = pipeline

    def __iter__(self) -> NoReturn:
        """Refuse iteration with an explanatory error.

        Iterating typically happens when unpacking the result of a step
        call, which fails here because the step has a single output.

        Raises:
            StepInterfaceError: If trying to iterate this object.
        """
        unpack_error_message = (
            "Unable to unpack step artifact. This error is probably because "
            "you're trying to unpack the return value of your step but the "
            "step only returns a single artifact. For more information on how "
            "to add type annotations to your step to indicate multiple "
            "artifacts visit https://docs.zenml.io/how-to/build-pipelines/step-output-typing-and-annotation#type-annotations."
        )
        raise StepInterfaceError(unpack_error_message)
__init__(self, invocation_id, output_name, annotation, pipeline)
special
Initialize a step artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
invocation_id |
str |
The ID of the invocation that produces this artifact. |
required |
output_name |
str |
The name of the output that produces this artifact. |
required |
annotation |
Any |
The output type annotation. |
required |
pipeline |
Pipeline |
The pipeline which the invocation is part of. |
required |
Source code in zenml/steps/entrypoint_function_utils.py
def __init__(
    self,
    invocation_id: str,
    output_name: str,
    annotation: Any,
    pipeline: "Pipeline",
) -> None:
    """Initialize a step artifact.

    Args:
        invocation_id: ID of the invocation that produces this artifact.
        output_name: Name of the output that produces this artifact.
        annotation: Type annotation of that output.
        pipeline: Pipeline which the invocation is part of.
    """
    # Producer identification.
    self.invocation_id = invocation_id
    self.output_name = output_name
    # Output type and owning pipeline.
    self.annotation = annotation
    self.pipeline = pipeline
__iter__(self)
special
Raise a custom error if someone is trying to iterate this object.
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
If trying to iterate this object. |
Source code in zenml/steps/entrypoint_function_utils.py
def __iter__(self) -> NoReturn:
    """Reject iteration, which typically means a mistaken tuple-unpack.

    Raises:
        StepInterfaceError: Always.
    """
    hint = (
        "Unable to unpack step artifact. This error is probably because "
        "you're trying to unpack the return value of your step but the "
        "step only returns a single artifact. For more information on how "
        "to add type annotations to your step to indicate multiple "
        "artifacts visit https://docs.zenml.io/how-to/build-pipelines/step-output-typing-and-annotation#type-annotations."
    )
    raise StepInterfaceError(hint)
validate_entrypoint_function(func, reserved_arguments=())
Validates a step entrypoint function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[..., Any] |
The step entrypoint function to validate. |
required |
reserved_arguments |
Sequence[str] |
The reserved arguments for the entrypoint function. |
() |
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
If the entrypoint function has variable arguments or keyword arguments. |
RuntimeError |
If type annotations should be enforced and a type annotation is missing. |
Returns:
Type | Description |
---|---|
EntrypointFunctionDefinition |
A validated definition of the entrypoint function. |
Source code in zenml/steps/entrypoint_function_utils.py
def validate_entrypoint_function(
    func: Callable[..., Any], reserved_arguments: Sequence[str] = ()
) -> EntrypointFunctionDefinition:
    """Validates a step entrypoint function.

    Args:
        func: The step entrypoint function to validate.
        reserved_arguments: Argument names the entrypoint must not declare.

    Raises:
        StepInterfaceError: If the entrypoint function has variable arguments
            or keyword arguments.
        RuntimeError: If type annotations should be enforced and a type
            annotation is missing.

    Returns:
        A validated definition of the entrypoint function.
    """
    signature = inspect.signature(func, follow_wrapped=True)
    validate_reserved_arguments(
        signature=signature, reserved_arguments=reserved_arguments
    )

    # `*args` / `**kwargs` cannot be mapped to named step inputs.
    variadic_kinds = {
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    }

    inputs = {}
    for name, param in signature.parameters.items():
        if param.kind in variadic_kinds:
            raise StepInterfaceError(
                f"Variable args or kwargs not allowed for function "
                f"{func.__name__}."
            )

        annotation = param.annotation
        if annotation is param.empty:
            if ENFORCE_TYPE_ANNOTATIONS:
                raise RuntimeError(
                    f"Missing type annotation for input '{name}' of step "
                    f"function '{func.__name__}'."
                )
            # Unannotated inputs are treated as `Any`.
            param = param.replace(annotation=Any)

        # NOTE(review): the resolved annotation is never used afterwards; the
        # call is kept so behavior (including any error it raises) is
        # unchanged. Confirm whether it should be stored on the parameter.
        resolve_type_annotation(annotation)
        inputs[name] = param

    outputs = parse_return_type_annotations(
        func=func, enforce_type_annotations=ENFORCE_TYPE_ANNOTATIONS
    )
    return EntrypointFunctionDefinition(inputs=inputs, outputs=outputs)
validate_reserved_arguments(signature, reserved_arguments)
Validates that the signature does not contain any reserved arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
signature |
Signature |
The signature to validate. |
required |
reserved_arguments |
Sequence[str] |
The reserved arguments for the signature. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the signature contains a reserved argument. |
Source code in zenml/steps/entrypoint_function_utils.py
def validate_reserved_arguments(
    signature: inspect.Signature, reserved_arguments: Sequence[str]
) -> None:
    """Validates that the signature does not contain any reserved arguments.

    The reserved names are checked in the given order, and the first one found
    in the signature is reported.

    Args:
        signature: The signature to validate.
        reserved_arguments: The reserved arguments for the signature.

    Raises:
        RuntimeError: If the signature contains a reserved argument.
    """
    declared = signature.parameters
    clash = next(
        (name for name in reserved_arguments if name in declared), None
    )
    if clash is not None:
        raise RuntimeError(f"Reserved argument name '{clash}'.")
step_context
Step context class.
StepContext
Provides additional context inside a step function.
This singleton class is used to access information about the current run, step run, or its outputs inside a step function.
Usage example:
from zenml.steps import get_step_context
@step
def my_trainer_step() -> Any:
    context = get_step_context()
    # get info about the current pipeline run
    current_pipeline_run = context.pipeline_run
    # get info about the current step run
    current_step_run = context.step_run
    # get info about the future output artifacts of this step
    output_artifact_uri = context.get_output_artifact_uri()
    ...
Source code in zenml/steps/step_context.py
class StepContext(metaclass=SingletonMetaClass):
    """Provides additional context inside a step function.

    This singleton class is used to access information about the current run,
    step run, or its outputs inside a step function.

    Usage example:

    ```python
    from zenml.steps import get_step_context

    @step
    def my_trainer_step() -> Any:
        context = get_step_context()

        # get info about the current pipeline run
        current_pipeline_run = context.pipeline_run

        # get info about the current step run
        current_step_run = context.step_run

        # get info about the future output artifacts of this step
        output_artifact_uri = context.get_output_artifact_uri()

        ...
    ```
    """

    def __init__(
        self,
        pipeline_run: "PipelineRunResponse",
        step_run: "StepRunResponse",
        output_materializers: Mapping[str, Sequence[Type["BaseMaterializer"]]],
        output_artifact_uris: Mapping[str, str],
        output_artifact_configs: Mapping[str, Optional["ArtifactConfig"]],
    ) -> None:
        """Initialize the context of the currently running step.

        Args:
            pipeline_run: The model of the current pipeline run.
            step_run: The model of the current step run.
            output_materializers: The output materializers of the step that
                this context is used in.
            output_artifact_uris: The output artifacts of the step that this
                context is used in.
            output_artifact_configs: The outputs' ArtifactConfigs of the step
                that this context is used in.

        Raises:
            StepContextError: If the keys of the output materializers and
                output artifacts do not match.
        """
        from zenml.client import Client

        # Best effort: re-fetch the run models through the client, falling
        # back to the models passed in if the lookup raises `KeyError`.
        try:
            pipeline_run = Client().get_pipeline_run(pipeline_run.id)
        except KeyError:
            pass
        self.pipeline_run = pipeline_run
        try:
            step_run = Client().get_run_step(step_run.id)
        except KeyError:
            pass
        self.step_run = step_run
        # The step run's model version takes precedence over the pipeline's.
        self.model_version = (
            step_run.model_version or pipeline_run.model_version
        )

        # Get the stack that we are running in
        self._stack = Client().active_stack

        self.step_name = self.step_run.name

        # set outputs
        if output_materializers.keys() != output_artifact_uris.keys():
            raise StepContextError(
                f"Mismatched keys in output materializers and output artifact "
                f"URIs for step `{self.step_name}`. Output materializer "
                f"keys: {set(output_materializers)}, output artifact URI "
                f"keys: {set(output_artifact_uris)}"
            )
        self._outputs = {
            key: StepContextOutput(
                materializer_classes=output_materializers[key],
                artifact_uri=output_artifact_uris[key],
                artifact_config=output_artifact_configs[key],
            )
            for key in output_materializers.keys()
        }

    @property
    def pipeline(self) -> "PipelineResponse":
        """Returns the current pipeline.

        Returns:
            The pipeline associated with the current pipeline run.

        Raises:
            StepContextError: If the pipeline run does not have a pipeline.
        """
        if self.pipeline_run.pipeline:
            return self.pipeline_run.pipeline
        raise StepContextError(
            f"Unable to get pipeline in step `{self.step_name}` of pipeline "
            f"run '{self.pipeline_run.id}': This pipeline run does not have "
            f"a pipeline associated with it."
        )

    @property
    def model(self) -> "Model":
        """Returns configured Model.

        Order of resolution to search for Model is:
            1. Model from the step context
            2. Model from the pipeline context

        Returns:
            The `Model` object associated with the current step.

        Raises:
            StepContextError: If no `Model` object was specified for the step
                or pipeline.
        """
        if not self.model_version:
            raise StepContextError(
                f"Unable to get Model in step `{self.step_name}` of pipeline "
                f"run '{self.pipeline_run.id}': No model has been specified "
                "for the step or pipeline."
            )

        return self.model_version.to_model_class()

    @property
    def inputs(self) -> Dict[str, "ArtifactVersionResponse"]:
        """Returns the input artifacts of the current step.

        Returns:
            The input artifacts of the current step.
        """
        return self.step_run.inputs

    def _get_output(
        self, output_name: Optional[str] = None
    ) -> "StepContextOutput":
        """Returns the `StepContextOutput` for a given step output.

        Args:
            output_name: Optional name of the output to look up. May be
                omitted only if the step has exactly one output.

        Returns:
            The `StepContextOutput` holding the materializer classes, artifact
            URI and artifact config of the given output.

        Raises:
            StepContextError: If the step has no outputs, no output for
                the given `output_name` or if no `output_name` was given but
                the step has multiple outputs.
        """
        output_count = len(self._outputs)

        if output_count == 0:
            raise StepContextError(
                f"Unable to get step output for step `{self.step_name}`: "
                f"This step does not have any outputs."
            )

        if not output_name and output_count > 1:
            raise StepContextError(
                f"Unable to get step output for step `{self.step_name}`: "
                f"This step has multiple outputs ({set(self._outputs)}), "
                f"please specify which output to return."
            )

        if output_name:
            if output_name not in self._outputs:
                raise StepContextError(
                    f"Unable to get step output '{output_name}' for "
                    f"step `{self.step_name}`. This step does not have an "
                    f"output with the given name, please specify one of the "
                    f"available outputs: {set(self._outputs)}."
                )
            return self._outputs[output_name]
        else:
            return next(iter(self._outputs.values()))

    def get_output_materializer(
        self,
        output_name: Optional[str] = None,
        custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
        data_type: Optional[Type[Any]] = None,
    ) -> "BaseMaterializer":
        """Returns a materializer for a given step output.

        Args:
            output_name: Optional name of the output for which to get the
                materializer. If no name is given and the step only has a
                single output, the materializer of this output will be
                returned. If the step has multiple outputs, an exception
                will be raised.
            custom_materializer_class: If given, this `BaseMaterializer`
                subclass will be initialized with the output artifact instead
                of the materializer that was registered for this step output.
            data_type: If the output annotation is of type `Union` and the step
                therefore has multiple materializers configured, you can provide
                a data type for the output which will be used to select the
                correct materializer. If not provided, the first materializer
                will be used.

        Returns:
            A materializer initialized with the output artifact for
            the given output.
        """
        from zenml.utils import materializer_utils

        output = self._get_output(output_name)
        materializer_classes = output.materializer_classes
        artifact_uri = output.artifact_uri

        if custom_materializer_class:
            materializer_class = custom_materializer_class
        elif len(materializer_classes) == 1 or not data_type:
            materializer_class = materializer_classes[0]
        else:
            materializer_class = materializer_utils.select_materializer(
                data_type=data_type, materializer_classes=materializer_classes
            )

        return materializer_class(artifact_uri)

    def get_output_artifact_uri(
        self, output_name: Optional[str] = None
    ) -> str:
        """Returns the artifact URI for a given step output.

        Args:
            output_name: Optional name of the output for which to get the URI.
                If no name is given and the step only has a single output,
                the URI of this output will be returned. If the step has
                multiple outputs, an exception will be raised.

        Returns:
            Artifact URI for the given output.
        """
        return self._get_output(output_name).artifact_uri

    def get_output_metadata(
        self, output_name: Optional[str] = None
    ) -> Dict[str, "MetadataType"]:
        """Returns the metadata for a given step output.

        Metadata added at runtime is combined with the metadata of the output's
        artifact config; on key collisions the artifact config value wins.
        A new dict is returned and the stored metadata is left untouched.

        Args:
            output_name: Optional name of the output for which to get the
                metadata. If no name is given and the step only has a single
                output, the metadata of this output will be returned. If the
                step has multiple outputs, an exception will be raised.

        Returns:
            Metadata for the given output.
        """
        output = self._get_output(output_name)
        # Copy: merging into `output.run_metadata` directly would leak the
        # artifact config metadata into the stored runtime metadata and hand
        # callers a reference to internal state.
        custom_metadata: Dict[str, "MetadataType"] = dict(
            output.run_metadata or {}
        )
        if output.artifact_config:
            custom_metadata.update(output.artifact_config.run_metadata or {})
        return custom_metadata

    def get_output_tags(self, output_name: Optional[str] = None) -> List[str]:
        """Returns the tags for a given step output.

        Args:
            output_name: Optional name of the output for which to get the
                tags. If no name is given and the step only has a single
                output, the tags of this output will be returned. If the
                step has multiple outputs, an exception will be raised.

        Returns:
            Tags for the given output without duplicates: the artifact config
            tags first, then the tags added at runtime, each in first-seen
            order.
        """
        output = self._get_output(output_name)
        config_tags = (
            (output.artifact_config.tags or [])
            if output.artifact_config
            else []
        )
        # `dict.fromkeys` de-duplicates while keeping a deterministic order
        # (a set's iteration order varies between runs for strings).
        return list(dict.fromkeys([*config_tags, *(output.tags or [])]))

    def add_output_metadata(
        self,
        metadata: Dict[str, "MetadataType"],
        output_name: Optional[str] = None,
    ) -> None:
        """Adds metadata for a given step output.

        Args:
            metadata: The metadata to add.
            output_name: Optional name of the output for which to add the
                metadata. If no name is given and the step only has a single
                output, the metadata of this output will be added. If the
                step has multiple outputs, an exception will be raised.
        """
        output = self._get_output(output_name)
        if not output.run_metadata:
            output.run_metadata = {}
        output.run_metadata.update(**metadata)

    def add_output_tags(
        self,
        tags: List[str],
        output_name: Optional[str] = None,
    ) -> None:
        """Adds tags for a given step output.

        Args:
            tags: The tags to add.
            output_name: Optional name of the output for which to add the
                tags. If no name is given and the step only has a single
                output, the tags of this output will be added. If the
                step has multiple outputs, an exception will be raised.
        """
        output = self._get_output(output_name)
        if not output.tags:
            output.tags = []
        output.tags += tags
inputs: Dict[str, ArtifactVersionResponse]
property
readonly
Returns the input artifacts of the current step.
Returns:
Type | Description |
---|---|
Dict[str, ArtifactVersionResponse] |
The input artifacts of the current step. |
model: Model
property
readonly
Returns configured Model.
Order of resolution to search for Model is: (1) Model from the step context, (2) Model from the pipeline context.
Returns:
Type | Description |
---|---|
Model |
The `Model` object associated with the current step. |
Exceptions:
Type | Description |
---|---|
StepContextError |
If no `Model` object was specified for the step or pipeline. |
pipeline: PipelineResponse
property
readonly
Returns the current pipeline.
Returns:
Type | Description |
---|---|
PipelineResponse |
The current pipeline (a `StepContextError` is raised instead of returning None if the run has no pipeline). |
Exceptions:
Type | Description |
---|---|
StepContextError |
If the pipeline run does not have a pipeline. |
__init__(self, pipeline_run, step_run, output_materializers, output_artifact_uris, output_artifact_configs)
special
Initialize the context of the currently running step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run |
PipelineRunResponse |
The model of the current pipeline run. |
required |
step_run |
StepRunResponse |
The model of the current step run. |
required |
output_materializers |
Mapping[str, Sequence[Type[BaseMaterializer]]] |
The output materializers of the step that this context is used in. |
required |
output_artifact_uris |
Mapping[str, str] |
The output artifacts of the step that this context is used in. |
required |
output_artifact_configs |
Mapping[str, Optional[ArtifactConfig]] |
The outputs' ArtifactConfigs of the step that this context is used in. |
required |
Exceptions:
Type | Description |
---|---|
StepContextError |
If the keys of the output materializers and output artifacts do not match. |
Source code in zenml/steps/step_context.py
def __init__(
    self,
    pipeline_run: "PipelineRunResponse",
    step_run: "StepRunResponse",
    output_materializers: Mapping[str, Sequence[Type["BaseMaterializer"]]],
    output_artifact_uris: Mapping[str, str],
    output_artifact_configs: Mapping[str, Optional["ArtifactConfig"]],
) -> None:
    """Set up the context for the step that is currently executing.

    Args:
        pipeline_run: Model of the running pipeline run.
        step_run: Model of the running step run.
        output_materializers: Materializer classes per output of the step.
        output_artifact_uris: Artifact URI per output of the step.
        output_artifact_configs: Optional `ArtifactConfig` per output of the
            step.

    Raises:
        StepContextError: If the output materializers and the output artifact
            URIs are keyed by different output names.
    """
    from zenml.client import Client

    # Prefer the client's copy of each run model; a `KeyError` means we keep
    # the model that was handed in.
    try:
        pipeline_run = Client().get_pipeline_run(pipeline_run.id)
    except KeyError:
        pass
    self.pipeline_run = pipeline_run

    try:
        step_run = Client().get_run_step(step_run.id)
    except KeyError:
        pass
    self.step_run = step_run

    # Step-level model version wins over the pipeline-level one.
    self.model_version = step_run.model_version or pipeline_run.model_version

    # Stack the step is executing on.
    self._stack = Client().active_stack
    self.step_name = self.step_run.name

    materializer_keys = output_materializers.keys()
    uri_keys = output_artifact_uris.keys()
    if materializer_keys != uri_keys:
        raise StepContextError(
            f"Mismatched keys in output materializers and output artifact "
            f"URIs for step `{self.step_name}`. Output materializer "
            f"keys: {set(output_materializers)}, output artifact URI "
            f"keys: {set(output_artifact_uris)}"
        )

    outputs = {}
    for key, materializers in output_materializers.items():
        outputs[key] = StepContextOutput(
            materializer_classes=materializers,
            artifact_uri=output_artifact_uris[key],
            artifact_config=output_artifact_configs[key],
        )
    self._outputs = outputs
add_output_metadata(self, metadata, output_name=None)
Adds metadata for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to add. |
required |
output_name |
Optional[str] |
Optional name of the output for which to add the metadata. If no name is given and the step only has a single output, the metadata of this output will be added. If the step has multiple outputs, an exception will be raised. |
None |
Source code in zenml/steps/step_context.py
def add_output_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
    output_name: Optional[str] = None,
) -> None:
    """Merge additional metadata into a step output's runtime metadata.

    Args:
        metadata: Key/value pairs to merge in; existing keys are overwritten.
        output_name: Name of the target output. It may be omitted only when
            the step has exactly one output; otherwise an exception is raised.
    """
    target = self._get_output(output_name)
    existing = target.run_metadata
    if not existing:
        # Start from a fresh dict when nothing (or an empty dict) is stored.
        existing = target.run_metadata = {}
    existing.update(**metadata)
add_output_tags(self, tags, output_name=None)
Adds tags for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tags |
List[str] |
The tags to add. |
required |
output_name |
Optional[str] |
Optional name of the output for which to add the tags. If no name is given and the step only has a single output, the tags of this output will be added. If the step has multiple outputs, an exception will be raised. |
None |
Source code in zenml/steps/step_context.py
def add_output_tags(
    self,
    tags: List[str],
    output_name: Optional[str] = None,
) -> None:
    """Append tags to a step output.

    Args:
        tags: Tags to append (duplicates are not filtered here).
        output_name: Name of the target output. It may be omitted only when
            the step has exactly one output; otherwise an exception is raised.
    """
    target = self._get_output(output_name)
    merged = target.tags if target.tags else []
    merged += tags
    target.tags = merged
get_output_artifact_uri(self, output_name=None)
Returns the artifact URI for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the URI. If no name is given and the step only has a single output, the URI of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
Returns:
Type | Description |
---|---|
str |
Artifact URI for the given output. |
Source code in zenml/steps/step_context.py
def get_output_artifact_uri(
    self, output_name: Optional[str] = None
) -> str:
    """Look up the artifact URI of a step output.

    Args:
        output_name: Name of the output. It may be omitted only when the step
            has exactly one output; otherwise an exception is raised.

    Returns:
        The artifact URI of the selected output.
    """
    selected = self._get_output(output_name)
    return selected.artifact_uri
get_output_materializer(self, output_name=None, custom_materializer_class=None, data_type=None)
Returns a materializer for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the materializer. If no name is given and the step only has a single output, the materializer of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
custom_materializer_class |
Optional[Type[BaseMaterializer]] |
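If given, this `BaseMaterializer` subclass will be initialized with the output artifact instead of the materializer that was registered for this step output. |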
None |
data_type |
Optional[Type[Any]] |
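If the output annotation is of type `Union` and the step therefore has multiple materializers configured, you can provide a data type for the output which will be used to select the correct materializer. If not provided, the first materializer will be used. |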
None |
Returns:
Type | Description |
---|---|
BaseMaterializer |
A materializer initialized with the output artifact for the given output. |
Source code in zenml/steps/step_context.py
def get_output_materializer(
    self,
    output_name: Optional[str] = None,
    custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
    data_type: Optional[Type[Any]] = None,
) -> "BaseMaterializer":
    """Build a materializer bound to the artifact URI of a step output.

    Selection order:
        1. `custom_materializer_class`, if given.
        2. The single registered materializer, or the first one when no
           `data_type` is given.
        3. Otherwise the materializer that `select_materializer` picks for
           `data_type` among the registered ones.

    Args:
        output_name: Name of the output. It may be omitted only when the step
            has exactly one output; otherwise an exception is raised.
        custom_materializer_class: Optional `BaseMaterializer` subclass to use
            instead of the materializer registered for the output.
        data_type: Optional type used to choose among several registered
            materializers (e.g. for a `Union` output annotation).

    Returns:
        A materializer instance created with the output's artifact URI.
    """
    from zenml.utils import materializer_utils

    output = self._get_output(output_name)
    candidates = output.materializer_classes
    uri = output.artifact_uri

    if custom_materializer_class:
        chosen = custom_materializer_class
    elif len(candidates) != 1 and data_type:
        chosen = materializer_utils.select_materializer(
            data_type=data_type, materializer_classes=candidates
        )
    else:
        chosen = candidates[0]

    return chosen(uri)
get_output_metadata(self, output_name=None)
Returns the metadata for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the metadata. If no name is given and the step only has a single output, the metadata of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
Metadata for the given output. |
Source code in zenml/steps/step_context.py
def get_output_metadata(
    self, output_name: Optional[str] = None
) -> Dict[str, "MetadataType"]:
    """Returns the metadata for a given step output.

    Metadata added at runtime is combined with the metadata of the output's
    artifact config; on key collisions the artifact config value wins.
    A new dict is returned and the stored metadata is left untouched.

    Args:
        output_name: Optional name of the output for which to get the
            metadata. If no name is given and the step only has a single
            output, the metadata of this output will be returned. If the
            step has multiple outputs, an exception will be raised.

    Returns:
        Metadata for the given output.
    """
    output = self._get_output(output_name)
    # Copy: merging into `output.run_metadata` directly would leak the
    # artifact config metadata into the stored runtime metadata and hand
    # callers a reference to internal state.
    custom_metadata: Dict[str, "MetadataType"] = dict(
        output.run_metadata or {}
    )
    if output.artifact_config:
        custom_metadata.update(output.artifact_config.run_metadata or {})
    return custom_metadata
get_output_tags(self, output_name=None)
Returns the tags for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the tags. If no name is given and the step only has a single output, the tags of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
Returns:
Type | Description |
---|---|
List[str] |
Tags for the given output. |
Source code in zenml/steps/step_context.py
def get_output_tags(self, output_name: Optional[str] = None) -> List[str]:
    """Returns the tags for a given step output.

    Args:
        output_name: Optional name of the output for which to get the
            tags. If no name is given and the step only has a single
            output, the tags of this output will be returned. If the
            step has multiple outputs, an exception will be raised.

    Returns:
        Tags for the given output without duplicates: the artifact config
        tags first, then the tags added at runtime, each in first-seen
        order.
    """
    output = self._get_output(output_name)
    config_tags = (
        (output.artifact_config.tags or []) if output.artifact_config else []
    )
    # `dict.fromkeys` de-duplicates while keeping a deterministic order
    # (a set's iteration order varies between runs for strings).
    return list(dict.fromkeys([*config_tags, *(output.tags or [])]))
StepContextOutput
Represents a step output in the step context.
Source code in zenml/steps/step_context.py
class StepContextOutput:
    """Represents a step output in the step context."""

    # Materializer classes registered for this output.
    materializer_classes: Sequence[Type["BaseMaterializer"]]
    # URI of the output artifact.
    artifact_uri: str
    # Runtime metadata; stays at the class-level `None` until set on the
    # instance (e.g. by `StepContext.add_output_metadata`).
    run_metadata: Optional[Dict[str, "MetadataType"]] = None
    # Optional artifact configuration of the output.
    artifact_config: Optional["ArtifactConfig"]
    # Runtime tags; stays at the class-level `None` until set on the instance.
    tags: Optional[List[str]] = None

    def __init__(
        self,
        materializer_classes: Sequence[Type["BaseMaterializer"]],
        artifact_uri: str,
        artifact_config: Optional["ArtifactConfig"],
    ):
        """Store the materializers, URI and config of one step output.

        Args:
            materializer_classes: Materializer classes for the output.
            artifact_uri: Artifact URI of the output.
            artifact_config: `ArtifactConfig` of the output, if any.
        """
        # `run_metadata` and `tags` are intentionally not set here; they fall
        # back to the class-level defaults above.
        self.materializer_classes = materializer_classes
        self.artifact_uri = artifact_uri
        self.artifact_config = artifact_config
__init__(self, materializer_classes, artifact_uri, artifact_config)
special
Initialize the step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
materializer_classes |
Sequence[Type[BaseMaterializer]] |
The materializer classes for the output. |
required |
artifact_uri |
str |
The artifact URI for the output. |
required |
artifact_config |
Optional[ArtifactConfig] |
The ArtifactConfig object of the output. |
required |
Source code in zenml/steps/step_context.py
def __init__(
    self,
    materializer_classes: Sequence[Type["BaseMaterializer"]],
    artifact_uri: str,
    artifact_config: Optional["ArtifactConfig"],
):
    """Store the materializers, URI and config of one step output.

    Args:
        materializer_classes: Materializer classes for the output.
        artifact_uri: Artifact URI of the output.
        artifact_config: `ArtifactConfig` of the output, if any.
    """
    # How the artifact is (de)serialized.
    self.materializer_classes = materializer_classes
    # Where the artifact is stored and how it is configured.
    self.artifact_uri = artifact_uri
    self.artifact_config = artifact_config
get_step_context()
Get the context of the currently running step.
Returns:
Type | Description |
---|---|
StepContext |
The context of the currently running step. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If no step is currently running. |
Source code in zenml/steps/step_context.py
def get_step_context() -> "StepContext":
    """Return the `StepContext` of the step that is currently running.

    Returns:
        The existing `StepContext` singleton instance.

    Raises:
        RuntimeError: If no `StepContext` instance exists, i.e. when called
            outside of a step function.
    """
    if not StepContext._exists():
        raise RuntimeError(
            "The step context is only available inside a step function."
        )
    return StepContext()  # type: ignore
step_decorator
Step decorator function.
step(_func=None, *, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, enable_step_logs=None, experiment_tracker=None, step_operator=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, model=None, retry=None)
Decorator to create a ZenML step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_func |
Optional[F] |
The decorated function. |
None |
name |
Optional[str] |
The name of the step. If left empty, the name of the decorated function will be used as a fallback. |
None |
enable_cache |
Optional[bool] |
Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default. |
None |
enable_artifact_metadata |
Optional[bool] |
Specify whether metadata is enabled for this step. If no value is passed, metadata is enabled by default. |
None |
enable_artifact_visualization |
Optional[bool] |
Specify whether visualization is enabled for this step. If no value is passed, visualization is enabled by default. |
None |
enable_step_logs |
Optional[bool] |
Specify whether step logs are enabled for this step. |
None |
experiment_tracker |
Optional[str] |
The experiment tracker to use for this step. |
None |
step_operator |
Optional[str] |
The step operator to use for this step. |
None |
output_materializers |
Optional[OutputMaterializersSpecification] |
Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs. |
None |
settings |
Optional[Dict[str, SettingsOrDict]] |
Settings for this step. |
None |
extra |
Optional[Dict[str, Any]] |
Extra configurations for this step. |
None |
on_failure |
Optional[HookSpecification] |
Callback function in event of failure of the step. Can be a function with a single argument of type `BaseException`, or a source path to such a function (e.g. `module.my_function`). |
None |
on_success |
Optional[HookSpecification] |
Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. `module.my_function`). |
None |
model |
Optional[Model] |
configuration of the model in the Model Control Plane. |
None |
retry |
Optional[StepRetryConfig] |
configuration of step retry in case of step failure. |
None |
Returns:
Type | Description |
---|---|
Union[BaseStep, Callable[[F], BaseStep]] |
The step instance. |
Source code in zenml/steps/step_decorator.py
def step(
    _func: Optional["F"] = None,
    *,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    output_materializers: Optional["OutputMaterializersSpecification"] = None,
    settings: Optional[Dict[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    model: Optional["Model"] = None,
    retry: Optional["StepRetryConfig"] = None,
) -> Union["BaseStep", Callable[["F"], "BaseStep"]]:
    """Turn a plain function into a ZenML step.

    Works both as a bare decorator (`@step`) and as a decorator factory
    (`@step(name=..., enable_cache=...)`).

    Args:
        _func: The decorated function (only set for the bare `@step` form).
        name: The name of the step. Falls back to the decorated function's
            name when empty.
        enable_cache: Whether caching is enabled for this step. If no value
            is passed, caching is enabled by default.
        enable_artifact_metadata: Whether artifact metadata is enabled for
            this step. If no value is passed, metadata is enabled by default.
        enable_artifact_visualization: Whether artifact visualization is
            enabled for this step. If no value is passed, visualization is
            enabled by default.
        enable_step_logs: Whether step logs are enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        output_materializers: Output materializers for this step. If given
            as a dict, the keys must be a subset of the output names of this
            step. If a single value (type or string) is given, that
            materializer is used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback invoked when the step fails. Either a function
            taking a single `BaseException` argument, or a source path to
            such a function (e.g. `module.my_function`).
        on_success: Callback invoked when the step succeeds. Either a
            function without arguments, or a source path to such a function
            (e.g. `module.my_function`).
        model: Configuration of the model in the Model Control Plane.
        retry: Configuration of step retry in case of step failure.

    Returns:
        The step instance when applied directly to a function, otherwise a
        decorator that builds the step instance.
    """
    # All options except `name` are forwarded verbatim to the step class.
    step_options: Dict[str, Any] = dict(
        enable_cache=enable_cache,
        enable_artifact_metadata=enable_artifact_metadata,
        enable_artifact_visualization=enable_artifact_visualization,
        enable_step_logs=enable_step_logs,
        experiment_tracker=experiment_tracker,
        step_operator=step_operator,
        output_materializers=output_materializers,
        settings=settings,
        extra=extra,
        on_failure=on_failure,
        on_success=on_success,
        model=model,
        retry=retry,
    )

    def _build_step(func: "F") -> "BaseStep":
        # Deferred import: only needed once a function is actually decorated.
        from zenml.steps.decorated_step import _DecoratedStep

        # Create a dedicated subclass whose entrypoint is the user function.
        step_class: Type["BaseStep"] = type(
            func.__name__,
            (_DecoratedStep,),
            {
                "entrypoint": staticmethod(func),
                "__module__": func.__module__,
                "__doc__": func.__doc__,
            },
        )
        return step_class(name=name or func.__name__, **step_options)

    return _build_step if _func is None else _build_step(_func)
step_invocation
Step invocation class definition.
StepInvocation
Step invocation class.
Source code in zenml/steps/step_invocation.py
class StepInvocation:
    """A single invocation of a step inside a pipeline.

    Bundles everything that belongs to one call of a step: its inputs
    (step artifacts, external artifacts, model data and client lazy
    loaders), the explicit and default parameters, the names of upstream
    steps and the owning pipeline.
    """

    def __init__(
        self,
        id: str,
        step: "BaseStep",
        input_artifacts: Dict[str, "StepArtifact"],
        external_artifacts: Dict[str, "ExternalArtifact"],
        model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"],
        client_lazy_loaders: Dict[str, "ClientLazyLoader"],
        parameters: Dict[str, Any],
        default_parameters: Dict[str, Any],
        upstream_steps: Set[str],
        pipeline: "Pipeline",
    ) -> None:
        """Store the state of a step invocation.

        Args:
            id: The invocation ID.
            step: The step that is represented by the invocation.
            input_artifacts: The input artifacts for the invocation.
            external_artifacts: The external artifacts for the invocation.
            model_artifacts_or_metadata: The model artifacts or metadata for
                the invocation.
            client_lazy_loaders: The client lazy loaders for the invocation.
            parameters: The parameters for the invocation.
            default_parameters: The default parameters for the invocation.
            upstream_steps: The upstream steps for the invocation.
            pipeline: The parent pipeline of the invocation.
        """
        # Identity and context of the invocation.
        self.id = id
        self.step = step
        self.pipeline = pipeline
        self.upstream_steps = upstream_steps
        # Inputs of the invocation, grouped by their source.
        self.input_artifacts = input_artifacts
        self.external_artifacts = external_artifacts
        self.model_artifacts_or_metadata = model_artifacts_or_metadata
        self.client_lazy_loaders = client_lazy_loaders
        # Explicit parameters take precedence over the defaults in `finalize`.
        self.parameters = parameters
        self.default_parameters = default_parameters

    def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration":
        """Apply this invocation to its step and finalize the configuration.

        Explicitly passed parameters win over default parameters. Any
        parameter whose name is in `parameters_to_ignore` is not applied.
        External artifacts that carry a value are uploaded before their
        configurations are passed on.

        Args:
            parameters_to_ignore: Set of parameters that should not be applied
                to the step instance.

        Returns:
            The finalized step configuration.
        """
        from zenml.artifacts.external_artifact_config import (
            ExternalArtifactConfiguration,
        )

        merged_parameters: Dict[str, Any] = dict(self.parameters)
        for param_name, default_value in self.default_parameters.items():
            merged_parameters.setdefault(param_name, default_value)
        self.step.configure(
            parameters={
                param_name: value
                for param_name, value in merged_parameters.items()
                if param_name not in parameters_to_ignore
            }
        )

        artifact_configs: Dict[str, ExternalArtifactConfiguration] = {}
        for artifact_name, external_artifact in self.external_artifacts.items():
            # Value-based artifacts must be uploaded before their config is used.
            if external_artifact.value is not None:
                external_artifact.upload_by_value()
            artifact_configs[artifact_name] = external_artifact.config

        return self.step._finalize_configuration(
            input_artifacts=self.input_artifacts,
            external_artifacts=artifact_configs,
            model_artifacts_or_metadata=self.model_artifacts_or_metadata,
            client_lazy_loaders=self.client_lazy_loaders,
        )
__init__(self, id, step, input_artifacts, external_artifacts, model_artifacts_or_metadata, client_lazy_loaders, parameters, default_parameters, upstream_steps, pipeline)
special
Initialize a step invocation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id |
str |
The invocation ID. |
required |
step |
BaseStep |
The step that is represented by the invocation. |
required |
input_artifacts |
Dict[str, StepArtifact] |
The input artifacts for the invocation. |
required |
external_artifacts |
Dict[str, ExternalArtifact] |
The external artifacts for the invocation. |
required |
model_artifacts_or_metadata |
Dict[str, ModelVersionDataLazyLoader] |
The model artifacts or metadata for the invocation. |
required |
client_lazy_loaders |
Dict[str, ClientLazyLoader] |
The client lazy loaders for the invocation. |
required |
parameters |
Dict[str, Any] |
The parameters for the invocation. |
required |
default_parameters |
Dict[str, Any] |
The default parameters for the invocation. |
required |
upstream_steps |
Set[str] |
The upstream steps for the invocation. |
required |
pipeline |
Pipeline |
The parent pipeline of the invocation. |
required |
Source code in zenml/steps/step_invocation.py
def __init__(
    self,
    id: str,
    step: "BaseStep",
    input_artifacts: Dict[str, "StepArtifact"],
    external_artifacts: Dict[str, "ExternalArtifact"],
    model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"],
    client_lazy_loaders: Dict[str, "ClientLazyLoader"],
    parameters: Dict[str, Any],
    default_parameters: Dict[str, Any],
    upstream_steps: Set[str],
    pipeline: "Pipeline",
) -> None:
    """Store the state of a step invocation.

    Args:
        id: The invocation ID.
        step: The step that is represented by the invocation.
        input_artifacts: The input artifacts for the invocation.
        external_artifacts: The external artifacts for the invocation.
        model_artifacts_or_metadata: The model artifacts or metadata for
            the invocation.
        client_lazy_loaders: The client lazy loaders for the invocation.
        parameters: The parameters for the invocation.
        default_parameters: The default parameters for the invocation.
        upstream_steps: The upstream steps for the invocation.
        pipeline: The parent pipeline of the invocation.
    """
    # Identity and context of the invocation.
    self.id = id
    self.step = step
    self.pipeline = pipeline
    self.upstream_steps = upstream_steps
    # Inputs of the invocation, grouped by their source.
    self.input_artifacts = input_artifacts
    self.external_artifacts = external_artifacts
    self.model_artifacts_or_metadata = model_artifacts_or_metadata
    self.client_lazy_loaders = client_lazy_loaders
    # Explicit and default parameters, merged later during finalization.
    self.parameters = parameters
    self.default_parameters = default_parameters
finalize(self, parameters_to_ignore)
Finalizes a step invocation.
It will validate the upstream steps and run final configurations on the step that is represented by the invocation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parameters_to_ignore |
Set[str] |
Set of parameters that should not be applied to the step instance. |
required |
Returns:
Type | Description |
---|---|
StepConfiguration |
The finalized step configuration. |
Source code in zenml/steps/step_invocation.py
def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration":
    """Apply this invocation to its step and finalize the configuration.

    Explicitly passed parameters win over default parameters. Any parameter
    whose name is in `parameters_to_ignore` is not applied. External
    artifacts that carry a value are uploaded before their configurations
    are passed on.

    Args:
        parameters_to_ignore: Set of parameters that should not be applied
            to the step instance.

    Returns:
        The finalized step configuration.
    """
    from zenml.artifacts.external_artifact_config import (
        ExternalArtifactConfiguration,
    )

    merged_parameters: Dict[str, Any] = dict(self.parameters)
    for param_name, default_value in self.default_parameters.items():
        merged_parameters.setdefault(param_name, default_value)
    self.step.configure(
        parameters={
            param_name: value
            for param_name, value in merged_parameters.items()
            if param_name not in parameters_to_ignore
        }
    )

    artifact_configs: Dict[str, ExternalArtifactConfiguration] = {}
    for artifact_name, external_artifact in self.external_artifacts.items():
        # Value-based artifacts must be uploaded before their config is used.
        if external_artifact.value is not None:
            external_artifact.upload_by_value()
        artifact_configs[artifact_name] = external_artifact.config

    return self.step._finalize_configuration(
        input_artifacts=self.input_artifacts,
        external_artifacts=artifact_configs,
        model_artifacts_or_metadata=self.model_artifacts_or_metadata,
        client_lazy_loaders=self.client_lazy_loaders,
    )
utils
Utility functions and classes to run ZenML steps.
OnlyNoneReturnsVisitor (ReturnVisitor)
Checks whether a function AST contains only `None` returns.
Source code in zenml/steps/utils.py
class OnlyNoneReturnsVisitor(ReturnVisitor):
    """Checks whether a function AST contains only `None` returns.

    After visiting, `has_only_none_returns` is `True` if every visited
    `return` statement was either bare (`return`) or `return None`.
    """

    def __init__(self) -> None:
        """Initializes a visitor instance."""
        super().__init__()
        # Cleared by the first return of a value other than `None`.
        self.has_only_none_returns = True

    def visit_Return(self, node: ast.Return) -> None:
        """Visit a return statement.

        Args:
            node: The return statement to visit.
        """
        returned = node.value
        # Bare `return` has no value node.
        if returned is None:
            return
        # Since Python 3.8, `return None` parses to `ast.Constant(value=None)`.
        # The former `ast.NameConstant` alias is deprecated and was removed
        # in Python 3.14, so only `ast.Constant` is checked.
        if isinstance(returned, ast.Constant) and returned.value is None:
            return
        self.has_only_none_returns = False
__init__(self)
special
Initializes a visitor instance.
Source code in zenml/steps/utils.py
def __init__(self) -> None:
    """Initializes a visitor instance.

    Sets up the base `ReturnVisitor` state (nested-function tracking), then
    starts with `has_only_none_returns = True`. The flag is cleared once a
    return of a non-`None` value is visited.
    """
    super().__init__()
    # Optimistic default: a function without any return statement is
    # treated as only returning `None`.
    self.has_only_none_returns = True
visit_Return(self, node)
Visit a return statement.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
Return |
The return statement to visit. |
required |
Source code in zenml/steps/utils.py
def visit_Return(self, node: ast.Return) -> None:
    """Visit a return statement.

    Clears `has_only_none_returns` unless the statement is a bare `return`
    or `return None`.

    Args:
        node: The return statement to visit.
    """
    returned = node.value
    # Bare `return` has no value node.
    if returned is None:
        return
    # Since Python 3.8, `return None` parses to `ast.Constant(value=None)`.
    # The former `ast.NameConstant` alias is deprecated and was removed in
    # Python 3.14, so only `ast.Constant` is checked.
    if isinstance(returned, ast.Constant) and returned.value is None:
        return
    self.has_only_none_returns = False
OutputSignature (BaseModel)
The signature of an output artifact.
Source code in zenml/steps/utils.py
class OutputSignature(BaseModel):
    """The signature of a single step output artifact.

    Attributes:
        resolved_annotation: The resolved (non-generic) type annotation of
            the output.
        artifact_config: The artifact config taken from the annotation
            metadata, if any.
        has_custom_name: Whether the output name came from the annotation
            instead of a generated default name.
    """

    resolved_annotation: Any = None
    artifact_config: Optional[ArtifactConfig] = None
    has_custom_name: bool = False

    def get_output_types(self) -> Tuple[Any, ...]:
        """Return the concrete types an output value may have.

        Returns:
            An empty tuple for `Any` annotations. For union annotations, the
            union members, with `None` normalized to `NoneType`. Otherwise,
            a one-element tuple containing the resolved annotation.
        """
        annotation = self.resolved_annotation
        if annotation is Any:
            return ()
        union_candidate = typing_utils.get_origin(annotation) or annotation
        if not typing_utils.is_union(union_candidate):
            return (annotation,)
        return tuple(
            type(None) if typing_utils.is_none_type(member) else member
            for member in get_args(annotation)
        )
get_output_types(self)
Get all output types that match the type annotation.
Returns:
Type | Description |
---|---|
Tuple[Any, ...] |
All output types that match the type annotation. |
Source code in zenml/steps/utils.py
def get_output_types(self) -> Tuple[Any, ...]:
    """Return the concrete types an output value may have.

    Returns:
        An empty tuple for `Any` annotations. For union annotations, the
        union members, with `None` normalized to `NoneType`. Otherwise, a
        one-element tuple containing the resolved annotation.
    """
    annotation = self.resolved_annotation
    if annotation is Any:
        return ()
    union_candidate = typing_utils.get_origin(annotation) or annotation
    if not typing_utils.is_union(union_candidate):
        return (annotation,)
    members = []
    for member in get_args(annotation):
        members.append(
            type(None) if typing_utils.is_none_type(member) else member
        )
    return tuple(members)
ReturnVisitor (NodeVisitor)
AST visitor class that can be subclassed to visit function returns.
Source code in zenml/steps/utils.py
class ReturnVisitor(ast.NodeVisitor):
    """Base AST visitor for inspecting the `return` statements of a function.

    Subclasses implement `visit_Return`. By default, only the first function
    definition encountered is descended into, so `return` statements of
    functions defined inside it are not visited.
    """

    def __init__(self, ignore_nested_functions: bool = True) -> None:
        """Set up the nesting-tracking state of the visitor.

        Args:
            ignore_nested_functions: If `True`, will skip visiting nested
                functions.
        """
        self._inside_function = False
        self._ignore_nested_functions = ignore_nested_functions

    def _visit_function(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Descend into a (possibly async) function definition.

        Args:
            node: The node to visit.
        """
        must_skip = self._ignore_nested_functions and self._inside_function
        if not must_skip:
            # Mark that a function has been entered, so any function defined
            # inside it is skipped when nested functions are ignored.
            self._inside_function = True
            self.generic_visit(node)

    # Sync and async function definitions are handled identically.
    visit_FunctionDef = _visit_function
    visit_AsyncFunctionDef = _visit_function
__init__(self, ignore_nested_functions=True)
special
Initializes a return visitor instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ignore_nested_functions |
bool |
If `True`, will skip visiting nested functions. |
True |
Source code in zenml/steps/utils.py
def __init__(self, ignore_nested_functions: bool = True) -> None:
    """Set up the nesting-tracking state of a return visitor.

    Args:
        ignore_nested_functions: If `True`, will skip visiting nested
            functions.
    """
    # No function has been entered yet.
    self._inside_function = False
    self._ignore_nested_functions = ignore_nested_functions
visit_AsyncFunctionDef(self, node)
Visit a (async) function definition node.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
Union[ast.FunctionDef, ast.AsyncFunctionDef] |
The node to visit. |
required |
Source code in zenml/steps/utils.py
def _visit_function(
    self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
    """Descend into a (possibly async) function definition.

    Once a function has been entered, further function definitions are
    skipped when the visitor is configured to ignore nested functions.

    Args:
        node: The node to visit.
    """
    if self._ignore_nested_functions:
        if self._inside_function:
            # Nested definition: do not recurse into it.
            return
    self._inside_function = True
    self.generic_visit(node)
visit_FunctionDef(self, node)
Visit a (async) function definition node.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
Union[ast.FunctionDef, ast.AsyncFunctionDef] |
The node to visit. |
required |
Source code in zenml/steps/utils.py
def _visit_function(
    self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
    """Descend into a (possibly async) function definition.

    Once a function has been entered, further function definitions are
    skipped when the visitor is configured to ignore nested functions.

    Args:
        node: The node to visit.
    """
    must_skip = self._ignore_nested_functions and self._inside_function
    if not must_skip:
        self._inside_function = True
        self.generic_visit(node)
TupleReturnVisitor (ReturnVisitor)
Checks whether a function AST contains tuple returns.
Source code in zenml/steps/utils.py
class TupleReturnVisitor(ReturnVisitor):
    """Checks whether a function AST contains tuple returns.

    `has_tuple_return` becomes `True` once a `return` statement is visited
    whose value is a literal tuple with at least two elements.
    """

    def __init__(self) -> None:
        """Initializes a visitor instance."""
        super().__init__()
        self.has_tuple_return = False

    def visit_Return(self, node: ast.Return) -> None:
        """Visit a return statement.

        Args:
            node: The return statement to visit.
        """
        returned = node.value
        if not isinstance(returned, ast.Tuple):
            return
        # Single-element tuples do not count as multiple return values.
        if len(returned.elts) > 1:
            self.has_tuple_return = True
__init__(self)
special
Initializes a visitor instance.
Source code in zenml/steps/utils.py
def __init__(self) -> None:
    """Initializes a visitor instance.

    Sets up the base `ReturnVisitor` state (nested-function tracking) and
    starts with `has_tuple_return = False`. The flag is set once a return of
    a tuple literal with more than one element is visited.
    """
    super().__init__()
    # Stays False unless a `return a, b, ...` style statement is seen.
    self.has_tuple_return = False
visit_Return(self, node)
Visit a return statement.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
Return |
The return statement to visit. |
required |
Source code in zenml/steps/utils.py
def visit_Return(self, node: ast.Return) -> None:
    """Record whether a return statement returns several values.

    Sets `has_tuple_return` when the returned value is a tuple literal with
    more than one element. The flag is never reset.

    Args:
        node: The return statement to visit.
    """
    returned = node.value
    is_multi_value = isinstance(returned, ast.Tuple) and len(returned.elts) > 1
    if is_multi_value:
        self.has_tuple_return = True
get_args(obj)
Get arguments of a type annotation.
Examples:
get_args(Union[int, str]) == (int, str)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any |
The annotation. |
required |
Returns:
Type | Description |
---|---|
Tuple[Any, ...] |
The args of the annotation. |
Source code in zenml/steps/utils.py
def get_args(obj: Any) -> Tuple[Any, ...]:
    """Get the arguments of a type annotation, reduced to their origins.

    Each argument that is itself a generic alias is replaced by its origin
    type; other arguments are returned unchanged.

    Example:
        `get_args(Union[int, str]) == (int, str)`

    Args:
        obj: The annotation.

    Returns:
        The args of the annotation.
    """
    resolved = []
    for arg in typing_utils.get_args(obj):
        origin = typing_utils.get_origin(arg)
        resolved.append(origin or arg)
    return tuple(resolved)
get_artifact_config_from_annotation_metadata(annotation)
Get the artifact config from the annotation metadata of a step output.
Examples:
get_artifact_config_from_annotation_metadata(int) # None
get_artifact_config_from_annotation_metadata(Annotated[int, "name"]) # ArtifactConfig(name="name")
get_artifact_config_from_annotation_metadata(Annotated[int, ArtifactConfig(name="name", model_name="foo")]) # ArtifactConfig(name="name", model_name="foo")
Parameters:
Name | Type | Description | Default |
---|---|---|---|
annotation |
Any |
The type annotation. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the annotation is not following the expected format or if the name was specified multiple times or is an empty string. |
Returns:
Type | Description |
---|---|
Optional[zenml.artifacts.artifact_config.ArtifactConfig] |
The artifact config. |
Source code in zenml/steps/utils.py
def get_artifact_config_from_annotation_metadata(
    annotation: Any,
) -> Optional[ArtifactConfig]:
    """Get the artifact config from the annotation metadata of a step output.

    Example:
    ```python
    get_artifact_config_from_annotation_metadata(int)  # None
    get_artifact_config_from_annotation_metadata(Annotated[int, "name"])  # ArtifactConfig(name="name")
    get_artifact_config_from_annotation_metadata(Annotated[int, ArtifactConfig(name="name", model_name="foo")])  # ArtifactConfig(name="name", model_name="foo")
    ```

    Args:
        annotation: The type annotation.

    Raises:
        ValueError: If the annotation is not following the expected format
            or if the name was specified multiple times or is an empty string.

    Returns:
        The artifact config, or `None` if the annotation is not `Annotated`
        or carries no name/config metadata.
    """
    if (typing_utils.get_origin(annotation) or annotation) is not Annotated:
        return None
    annotation, *metadata = typing_utils.get_args(annotation)
    error_message = (
        "Artifact annotation should only contain two elements: the artifact "
        "type, and either an output name or an `ArtifactConfig`, e.g.: "
        "`Annotated[int, 'output_name']` or "
        "`Annotated[int, ArtifactConfig(name='output_name'), ...]`."
    )
    if len(metadata) > 2:
        raise ValueError(error_message)
    # Loop over all values to also support legacy annotations of the form
    # `Annotated[int, 'output_name', ArtifactConfig(...)]`
    output_name = None
    artifact_config = None
    for metadata_instance in metadata:
        if isinstance(metadata_instance, str):
            if output_name is not None:
                raise ValueError(error_message)
            output_name = metadata_instance
        elif isinstance(metadata_instance, ArtifactConfig):
            if artifact_config is not None:
                raise ValueError(error_message)
            artifact_config = metadata_instance
        else:
            raise ValueError(error_message)
    # Consolidate output name
    if artifact_config and artifact_config.name:
        # The name was given twice: once as a string, once in the config.
        if output_name is not None:
            raise ValueError(error_message)
    elif output_name:
        if not artifact_config:
            artifact_config = ArtifactConfig(name=output_name)
        elif not artifact_config.name:
            artifact_config = artifact_config.model_copy()
            artifact_config.name = output_name
    # An empty string name (`Annotated[int, ""]`) is falsy and skipped by the
    # consolidation above, so check it explicitly instead of silently
    # returning no config.
    if output_name == "" or (artifact_config and artifact_config.name == ""):
        raise ValueError("Output name cannot be an empty string.")
    return artifact_config
has_only_none_returns(func)
Checks whether a function contains only `None` returns.
A `None` return could be either an explicit `return None` or an empty `return` statement.
Examples:
def f1():
return None
def f2():
return
def f3(condition):
if condition:
return None
else:
return 1
has_only_none_returns(f1) # True
has_only_none_returns(f2) # True
has_only_none_returns(f3) # False
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[..., Any] |
The function to check. |
required |
Returns:
Type | Description |
---|---|
bool |
Whether the function contains only `None` returns. |
Source code in zenml/steps/utils.py
def has_only_none_returns(func: Callable[..., Any]) -> bool:
    """Checks whether a function contains only `None` returns.

    Both an explicit `return None` and a bare `return` count as `None`
    returns. Returns of nested functions are not considered.

    Example:
    ```python
    def f1():
        return None

    def f2():
        return

    def f3(condition):
        if condition:
            return None
        else:
            return 1

    has_only_none_returns(f1)  # True
    has_only_none_returns(f2)  # True
    has_only_none_returns(f3)  # False
    ```

    Args:
        func: The function to check.

    Returns:
        Whether the function contains only `None` returns.
    """
    func_source = source_code_utils.get_source_code(func)
    # Dedent so that indented definitions (e.g. methods) parse cleanly.
    module_ast = ast.parse(textwrap.dedent(func_source))
    checker = OnlyNoneReturnsVisitor()
    checker.visit(module_ast)
    return checker.has_only_none_returns
has_tuple_return(func)
Checks whether a function returns multiple values.
Multiple values means that the `return` statement is followed by a tuple (with or without brackets).
Examples:
def f1():
return 1, 2
def f2():
return (1, 2)
def f3():
var = (1, 2)
return var
has_tuple_return(f1) # True
has_tuple_return(f2) # True
has_tuple_return(f3) # False
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[..., Any] |
The function to check. |
required |
Returns:
Type | Description |
---|---|
bool |
Whether the function returns multiple values. |
Source code in zenml/steps/utils.py
def has_tuple_return(func: Callable[..., Any]) -> bool:
    """Checks whether a function returns multiple values.

    A function returns multiple values when one of its `return` statements
    is followed by a tuple literal of two or more elements, with or without
    brackets. Returning a variable that holds a tuple does not count.

    Example:
    ```python
    def f1():
        return 1, 2

    def f2():
        return (1, 2)

    def f3():
        var = (1, 2)
        return var

    has_tuple_return(f1)  # True
    has_tuple_return(f2)  # True
    has_tuple_return(f3)  # False
    ```

    Args:
        func: The function to check.

    Returns:
        Whether the function returns multiple values.
    """
    func_source = source_code_utils.get_source_code(func)
    # Dedent so that indented definitions (e.g. methods) parse cleanly.
    module_ast = ast.parse(textwrap.dedent(func_source))
    checker = TupleReturnVisitor()
    checker.visit(module_ast)
    return checker.has_tuple_return
log_step_metadata(metadata, step_name=None, pipeline_name_id_or_prefix=None, run_id=None)
Logs step metadata.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metadata |
Dict[str, MetadataType] |
The metadata to log. |
required |
step_name |
Optional[str] |
The name of the step to log metadata for. Can be omitted when being called inside a step. |
None |
pipeline_name_id_or_prefix |
Optional[Union[str, uuid.UUID]] |
The name of the pipeline to log metadata for. Can be omitted when being called inside a step. |
None |
run_id |
Optional[str] |
The ID of the run to log metadata for. Can be omitted when being called inside a step. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
If no step name is provided and the function is not called from within a step or if no pipeline name or ID is provided and the function is not called from within a step. |
Source code in zenml/steps/utils.py
def log_step_metadata(
    metadata: Dict[str, "MetadataType"],
    step_name: Optional[str] = None,
    pipeline_name_id_or_prefix: Optional[Union[str, UUID]] = None,
    run_id: Optional[str] = None,
) -> None:
    """Logs step metadata.

    The target step run is resolved in this order:
    1. The currently running step, when called inside a step without a
       `step_name`.
    2. The step named `step_name` in the pipeline run identified by `run_id`.
    3. The step named `step_name` in the latest run of the pipeline
       identified by `pipeline_name_id_or_prefix`.

    Args:
        metadata: The metadata to log.
        step_name: The name of the step to log metadata for. Can be omitted
            when being called inside a step.
        pipeline_name_id_or_prefix: The name of the pipeline to log metadata
            for. Can be omitted when being called inside a step.
        run_id: The ID of the run to log metadata for. Can be omitted when
            being called inside a step.

    Raises:
        ValueError: If no step name is provided and the function is not called
            from within a step or if no pipeline name or ID is provided and
            the function is not called from within a step.
    """
    step_context = None
    if not step_name:
        with contextlib.suppress(RuntimeError):
            step_context = get_step_context()
            step_name = step_context.step_name
    # not running within a step and no user-provided step name
    if not step_name:
        raise ValueError(
            "No step name provided and you are not running "
            "within a step. Please provide a step name."
        )
    client = Client()
    if step_context:
        step_run_id = step_context.step_run.id
    elif run_id:
        # `run_id` identifies a pipeline run, not a step run. It is usually a
        # UUID string, which `int()` cannot parse. Look the run up and select
        # the step by name, like the pipeline branch below does.
        pipeline_run = client.get_pipeline_run(run_id)
        step_run_id = pipeline_run.steps[step_name].id
    else:
        if not pipeline_name_id_or_prefix:
            raise ValueError(
                "No pipeline name or ID provided and you are not running "
                "within a step. Please provide a pipeline name or ID, or "
                "provide a run ID."
            )
        pipeline_run = client.get_pipeline(
            name_id_or_prefix=pipeline_name_id_or_prefix,
        ).last_run
        step_run_id = pipeline_run.steps[step_name].id
    client.create_run_metadata(
        metadata=metadata,
        resource_id=step_run_id,
        resource_type=MetadataResourceTypes.STEP_RUN,
    )
parse_return_type_annotations(func, enforce_type_annotations=False)
Parse the return type annotation of a step function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func |
Callable[..., Any] |
The step function. |
required |
enforce_type_annotations |
bool |
If `True`, raises an exception if a type annotation is missing. |
False |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the output annotation has variable length or contains duplicate output names. |
RuntimeError |
If type annotations should be enforced and a type annotation is missing. |
Returns:
Type | Description |
---|---|
Dict[str, zenml.steps.utils.OutputSignature] |
A dictionary mapping output names to their output signatures. |
Source code in zenml/steps/utils.py
def parse_return_type_annotations(
    func: Callable[..., Any], enforce_type_annotations: bool = False
) -> Dict[str, OutputSignature]:
    """Parse the return type annotation of a step function.

    Outputs are derived as follows:
    - `-> None`, or a missing annotation on a function whose returns are all
      `None`: no outputs.
    - A `Tuple[...]` annotation on a function that returns a literal tuple of
      several values: one output per tuple element, named `output_<index>`
      unless the element annotation supplies a name.
    - Anything else (a missing annotation is treated as `Any`): one output,
      named `SINGLE_RETURN_OUT_NAME` unless the annotation supplies a name.

    Args:
        func: The step function.
        enforce_type_annotations: If `True`, raises an exception if a type
            annotation is missing.

    Raises:
        RuntimeError: If the output annotation has variable length or contains
            duplicate output names.
        RuntimeError: If type annotations should be enforced and a type
            annotation is missing.

    Returns:
        A dictionary mapping output names to their output signatures.
    """
    signature = inspect.signature(func, follow_wrapped=True)
    annotation = signature.return_annotation

    def _named_signature(
        element: Any, fallback_name: str
    ) -> Tuple[str, OutputSignature]:
        # A name from the annotation metadata wins over the fallback name.
        resolved = resolve_type_annotation(element)
        config = get_artifact_config_from_annotation_metadata(element)
        custom_name = config.name if config else None
        output = OutputSignature(
            resolved_annotation=resolved,
            artifact_config=config,
            has_custom_name=custom_name is not None,
        )
        return custom_name or fallback_name, output

    # Explicit `-> None`: the step has no outputs.
    if annotation is None:
        return {}

    # No annotation: infer "no outputs" from the body, else fall back to Any.
    if annotation is signature.empty:
        if enforce_type_annotations:
            raise RuntimeError(
                "Missing return type annotation for step function "
                f"'{func.__name__}'."
            )
        if has_only_none_returns(func):
            return {}
        annotation = Any

    # A tuple annotation means multiple outputs only if the body actually
    # returns a tuple literal; otherwise the tuple is a single output.
    if typing_utils.get_origin(annotation) is tuple and has_tuple_return(func):
        element_annotations = typing_utils.get_args(annotation)
        if element_annotations[-1] is Ellipsis:
            raise RuntimeError(
                "Variable length output annotations are not allowed."
            )
        outputs: Dict[str, Any] = {}
        for position, element in enumerate(element_annotations):
            name, output = _named_signature(element, f"output_{position}")
            if name in outputs:
                raise RuntimeError(f"Duplicate output name {name}.")
            outputs[name] = output
        return outputs

    name, output = _named_signature(annotation, SINGLE_RETURN_OUT_NAME)
    return {name: output}
resolve_type_annotation(obj)
Returns the non-generic class for generic aliases of the typing module.
Example: if the input object is `typing.Dict`, this method will return the concrete class `dict`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any |
The object to resolve. |
required |
Returns:
Type | Description |
---|---|
Any |
The non-generic class for generic aliases of the typing module. |
Source code in zenml/steps/utils.py
def resolve_type_annotation(obj: Any) -> Any:
    """Returns the non-generic class for generic aliases of the typing module.

    `Annotated[...]` wrappers are stripped first. Union annotations are
    returned unchanged; any other generic alias is reduced to its origin.
    Example: for `typing.Dict`, the concrete class `dict` is returned.

    Args:
        obj: The object to resolve.

    Returns:
        The non-generic class for generic aliases of the typing module.
    """
    while True:
        origin = typing_utils.get_origin(obj) or obj
        if origin is not Annotated:
            break
        # Drop the `Annotated` metadata and keep resolving the wrapped type.
        obj, *_ = typing_utils.get_args(obj)
    if typing_utils.is_union(origin):
        return obj
    return origin
run_as_single_step_pipeline(__step, *args, **kwargs)
Runs the step as a single step pipeline.
- All inputs that are not JSON serializable will be uploaded to the artifact store before the pipeline is being executed.
- All output artifacts of the step will be loaded using the materializer that was used to store them.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
__step |
BaseStep |
The step to run. |
required |
*args |
Any |
Entrypoint function arguments. |
() |
**kwargs |
Any |
Entrypoint function keyword arguments. |
{} |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the step execution failed. |
StepInterfaceError |
If the arguments to the entrypoint function are invalid. |
Returns:
Type | Description |
---|---|
Any |
The output of the step entrypoint function. |
Source code in zenml/steps/utils.py
def run_as_single_step_pipeline(
    __step: "BaseStep", *args: Any, **kwargs: Any
) -> Any:
    """Runs the step as a single step pipeline.

    - Inputs that the step's entrypoint definition rejects as plain
      parameters are wrapped in `ExternalArtifact`s, so they are uploaded to
      the artifact store before the pipeline is executed.
    - All output artifacts of the step are loaded using the materializer
      that was used to store them.

    Args:
        __step: The step to run.
        *args: Entrypoint function arguments.
        **kwargs: Entrypoint function keyword arguments.

    Raises:
        RuntimeError: If the step execution failed.
        StepInterfaceError: If the arguments to the entrypoint function are
            invalid.

    Returns:
        The output of the step entrypoint function: `None` if the step has no
        outputs, the value itself if it has one output, otherwise a tuple of
        all output values.
    """
    from zenml import ExternalArtifact, pipeline
    from zenml.config.base_settings import BaseSettings
    from zenml.pipelines.run_utils import (
        wait_for_pipeline_run_to_finish,
    )

    logger.info(
        "Running single step pipeline to execute step `%s`", __step.name
    )

    # 1. Validate the call arguments against the entrypoint signature.
    try:
        validated_arguments = (
            inspect.signature(__step.entrypoint)
            .bind(*args, **kwargs)
            .arguments
        )
    except TypeError as e:
        raise StepInterfaceError(
            "Invalid step function entrypoint arguments. Check out the "
            "error above for more details."
        ) from e

    # 2. Pass valid parameters directly; wrap everything else as an
    # external artifact.
    inputs: Dict[str, Any] = {}
    for key, value in validated_arguments.items():
        try:
            __step.entrypoint_definition.validate_input(key=key, value=value)
            inputs[key] = value
        except Exception:
            inputs[key] = ExternalArtifact(value=value)

    orchestrator = Client().active_stack.orchestrator

    pipeline_settings: Any = {}
    if "synchronous" in orchestrator.config.model_fields:
        # Make sure the orchestrator runs sync so we stream the logs
        key = settings_utils.get_stack_component_setting_key(orchestrator)
        pipeline_settings[key] = BaseSettings(synchronous=True)

    # 3. Run the single step pipeline and wait for it to finish.
    @pipeline(name=__step.name, enable_cache=False, settings=pipeline_settings)
    def single_step_pipeline() -> None:
        __step(**inputs)

    run = single_step_pipeline.with_options(unlisted=True)()
    run = wait_for_pipeline_run_to_finish(run.id)

    if run.status != ExecutionStatus.COMPLETED:
        # Format the message here: exception constructors do not apply
        # %-style formatting to extra positional arguments.
        raise RuntimeError(f"Failed to execute step {__step.name}.")

    # 4. Load output artifacts
    step_run = next(iter(run.steps.values()))
    outputs = [
        step_run.outputs[output_name].load()
        for output_name in step_run.config.outputs.keys()
    ]

    if len(outputs) == 0:
        return None
    elif len(outputs) == 1:
        return outputs[0]
    else:
        return tuple(outputs)