Step Operators
zenml.step_operators
special
Step operators allow you to run steps on custom infrastructure.
While an orchestrator defines how and where your entire pipeline runs, a step operator defines how and where an individual step runs. This is useful in a variety of scenarios — for example, when one step within a pipeline (such as a trainer step) should run in a separate environment equipped with a GPU.
base_step_operator
Base class for ZenML step operators.
BaseStepOperator (StackComponent, ABC)
Base class for all ZenML step operators.
Source code in zenml/step_operators/base_step_operator.py
class BaseStepOperator(StackComponent, ABC):
    """Base class for all ZenML step operators.

    Whereas an orchestrator runs a whole pipeline, a step operator runs a
    single step. Concrete step operators must implement `launch`.
    """

    @property
    def config(self) -> BaseStepOperatorConfig:
        """The configuration of this step operator.

        Returns:
            The component config, narrowed to `BaseStepOperatorConfig`.
        """
        # `_config` is typed generically on the base component; narrow it
        # for callers of this subclass.
        step_operator_config = cast(BaseStepOperatorConfig, self._config)
        return step_operator_config

    @abstractmethod
    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
    ) -> None:
        """Launch the job that executes a single step.

        Implementations must start a **synchronous** job that runs
        `entrypoint_command`, i.e. this method should return only once
        that job has finished.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
        """
config: BaseStepOperatorConfig
property
readonly
Returns the config of the step operator.
Returns:
| Type | Description |
|---|---|
| BaseStepOperatorConfig | The config of the step operator. |
launch(self, info, entrypoint_command)
Abstract method to execute a step.
Subclasses must implement this method and launch a **synchronous** job that executes the `entrypoint_command`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| info | StepRunInfo | Information about the step run. | required |
| entrypoint_command | List[str] | Command that executes the step. | required |
Source code in zenml/step_operators/base_step_operator.py
@abstractmethod
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
) -> None:
    """Launch the job that executes a single step.

    Implementations must start a **synchronous** job that runs
    `entrypoint_command`, i.e. this method should return only once
    that job has finished.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
    """
BaseStepOperatorConfig (StackComponentConfig)
pydantic-model
Base config for step operators.
Source code in zenml/step_operators/base_step_operator.py
class BaseStepOperatorConfig(StackComponentConfig):
    """Base config for step operators."""

    @root_validator(pre=True)
    def _deprecations(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Strip deprecated fields from the incoming config values.

        The deprecated `base_image` field is always removed from `values`
        (in place). If it held a truthy value, a warning is logged telling
        the user to configure the image via `DockerSettings` instead.

        Args:
            values: The raw config values to validate.

        Returns:
            The same values, with deprecated fields removed.
        """
        # `pop` with a default is a no-op for a missing key, so a separate
        # membership test is unnecessary.
        deprecated_image = values.pop("base_image", None)
        if deprecated_image:
            logger.warning(
                "The 'base_image' field has been deprecated. To use a "
                "custom base container image with your "
                "step operators, please use the DockerSettings in your "
                "pipeline (see https://docs.zenml.io/advanced-guide/pipelines/containerization)."
            )
        return values
BaseStepOperatorFlavor (Flavor)
Base class for all ZenML step operator flavors.
Source code in zenml/step_operators/base_step_operator.py
class BaseStepOperatorFlavor(Flavor):
    """Base class for all ZenML step operator flavors.

    Concrete flavors must provide `implementation_class`.
    """

    @property
    def type(self) -> StackComponentType:
        """The stack component type that this flavor belongs to.

        Returns:
            `StackComponentType.STEP_OPERATOR`.
        """
        return StackComponentType.STEP_OPERATOR

    @property
    def config_class(self) -> Type[BaseStepOperatorConfig]:
        """The config class used by this flavor.

        Returns:
            `BaseStepOperatorConfig`.
        """
        return BaseStepOperatorConfig

    @property
    @abstractmethod
    def implementation_class(self) -> Type[BaseStepOperator]:
        """The step operator class implementing this flavor.

        Returns:
            The implementation class for this flavor.
        """
config_class: Type[zenml.step_operators.base_step_operator.BaseStepOperatorConfig]
property
readonly
Returns the config class for this flavor.
Returns:
| Type | Description |
|---|---|
| Type[zenml.step_operators.base_step_operator.BaseStepOperatorConfig] | The config class for this flavor. |
implementation_class: Type[zenml.step_operators.base_step_operator.BaseStepOperator]
property
readonly
Returns the implementation class for this flavor.
Returns:
| Type | Description |
|---|---|
| Type[zenml.step_operators.base_step_operator.BaseStepOperator] | The implementation class for this flavor. |
type: StackComponentType
property
readonly
Returns the flavor type.
Returns:
| Type | Description |
|---|---|
| StackComponentType | The type of the flavor. |
step_executor_operator
Custom StepExecutorOperator which can be passed to the step operator.
StepExecutorOperator (BaseExecutorOperator)
StepExecutorOperator extends TFX's BaseExecutorOperator.
This class can be passed as a custom executor operator during a pipeline run which will then be used to call the step's configured step operator to launch it in some environment.
Source code in zenml/step_operators/step_executor_operator.py
class StepExecutorOperator(BaseExecutorOperator):
    """Executor operator that delegates a step to the stack's step operator.

    Extends TFX's `BaseExecutorOperator`. When passed as a custom executor
    operator for a pipeline run, it hands each step to the step operator
    configured for it, which launches the step in some environment.
    """

    SUPPORTED_EXECUTOR_SPEC_TYPE = [
        executable_spec_pb2.PythonClassExecutableSpec
    ]
    SUPPORTED_PLATFORM_CONFIG_TYPE: List[Any] = []

    @staticmethod
    def _get_step_operator(
        stack: "Stack", step_operator_name: str
    ) -> "BaseStepOperator":
        """Return the stack's step operator, checking it matches by name.

        Args:
            stack: Stack on which the step is being executed.
            step_operator_name: Name of the step operator the step was
                configured to use.

        Returns:
            The step operator to run the step with.

        Raises:
            RuntimeError: If the stack has no step operator, or if its step
                operator is not named `step_operator_name`.
        """
        operator = stack.step_operator
        # Both checks are defensive: the stack is validated before the
        # pipeline runs, so neither error is expected in practice.
        if not operator:
            raise RuntimeError(
                f"No step operator specified for active stack '{stack.name}'."
            )
        if step_operator_name != operator.name:
            raise RuntimeError(
                f"No step operator named '{step_operator_name}' in active "
                f"stack '{stack.name}'."
            )
        return operator

    @staticmethod
    def _get_step_name_in_pipeline(
        execution_info: data_types.ExecutionInfo,
    ) -> str:
        """Read the name of the step inside its pipeline.

        The name is stored JSON-encoded in the execution properties under
        the internal pipeline-parameter-name key.

        Args:
            execution_info: The step execution info.

        Returns:
            The name of the step in the pipeline.
        """
        name_key = (
            INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME
        )
        encoded_name = execution_info.exec_properties[name_key]
        return cast(str, json.loads(encoded_name))

    def run_executor(
        self,
        execution_info: data_types.ExecutionInfo,
    ) -> execution_result_pb2.ExecutorOutput:
        """Run a step through the active stack's step operator.

        The execution info is serialized to a file in its `tmp_dir`. A step
        operator entrypoint command pointing at that file is built and
        handed to the step operator's `launch`. The executor output is then
        read back.

        Args:
            execution_info: Necessary information to run the executor, as
                provided by the launcher.

        Returns:
            The executor output read from
            `execution_info.execution_output_uri`.
        """
        # These fields are expected to always be set; the asserts only
        # narrow their Optional types for mypy.
        assert execution_info.pipeline_node
        assert execution_info.pipeline_info
        assert execution_info.pipeline_run_id
        assert execution_info.tmp_dir
        assert execution_info.execution_output_uri

        node = execution_info.pipeline_node
        step = proto_utils.get_step(pipeline_node=node)
        pipeline_config = proto_utils.get_pipeline_config(pipeline_node=node)

        assert step.config.step_operator
        operator = self._get_step_operator(
            stack=Client().active_stack,
            step_operator_name=step.config.step_operator,
        )

        # Persist the execution info in a temporary directory inside the
        # artifact store so the step operator entrypoint can load it.
        info_path = os.path.join(
            execution_info.tmp_dir, "zenml_execution_info.pb"
        )
        _write_execution_info(execution_info, path=info_path)

        step_name = self._get_step_name_in_pipeline(execution_info)
        entrypoint_config = StepOperatorEntrypointConfiguration
        base_command = entrypoint_config.get_entrypoint_command()
        step_arguments = entrypoint_config.get_entrypoint_arguments(
            step_name=step_name,
            execution_info_path=info_path,
        )
        entrypoint_command = base_command + step_arguments

        logger.info(
            "Using step operator `%s` to run step `%s`.",
            operator.name,
            step_name,
        )
        operator.launch(
            info=StepRunInfo(
                config=step.config,
                pipeline=pipeline_config,
                run_name=execution_info.pipeline_run_id,
            ),
            entrypoint_command=entrypoint_command,
        )
        # `launch` runs a synchronous job, so the output exists once it
        # returns.
        return _read_executor_output(execution_info.execution_output_uri)
run_executor(self, execution_info)
Invokes the executor with inputs provided by the Launcher.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| execution_info | ExecutionInfo | Necessary information to run the executor. | required |
Returns:
| Type | Description |
|---|---|
| ExecutorOutput | The executor output. |
Source code in zenml/step_operators/step_executor_operator.py
def run_executor(
    self,
    execution_info: data_types.ExecutionInfo,
) -> execution_result_pb2.ExecutorOutput:
    """Run a step through the active stack's step operator.

    The execution info is serialized to a file in its `tmp_dir`. A step
    operator entrypoint command pointing at that file is built and handed
    to the step operator's `launch`. The executor output is then read back.

    Args:
        execution_info: Necessary information to run the executor, as
            provided by the launcher.

    Returns:
        The executor output read from
        `execution_info.execution_output_uri`.
    """
    # These fields are expected to always be set; the asserts only narrow
    # their Optional types for mypy.
    assert execution_info.pipeline_node
    assert execution_info.pipeline_info
    assert execution_info.pipeline_run_id
    assert execution_info.tmp_dir
    assert execution_info.execution_output_uri

    node = execution_info.pipeline_node
    step = proto_utils.get_step(pipeline_node=node)
    pipeline_config = proto_utils.get_pipeline_config(pipeline_node=node)

    assert step.config.step_operator
    operator = self._get_step_operator(
        stack=Client().active_stack,
        step_operator_name=step.config.step_operator,
    )

    # Persist the execution info in a temporary directory inside the
    # artifact store so the step operator entrypoint can load it.
    info_path = os.path.join(
        execution_info.tmp_dir, "zenml_execution_info.pb"
    )
    _write_execution_info(execution_info, path=info_path)

    step_name = self._get_step_name_in_pipeline(execution_info)
    entrypoint_config = StepOperatorEntrypointConfiguration
    base_command = entrypoint_config.get_entrypoint_command()
    step_arguments = entrypoint_config.get_entrypoint_arguments(
        step_name=step_name,
        execution_info_path=info_path,
    )
    entrypoint_command = base_command + step_arguments

    logger.info(
        "Using step operator `%s` to run step `%s`.",
        operator.name,
        step_name,
    )
    operator.launch(
        info=StepRunInfo(
            config=step.config,
            pipeline=pipeline_config,
            run_name=execution_info.pipeline_run_id,
        ),
        entrypoint_command=entrypoint_command,
    )
    # `launch` runs a synchronous job, so the output exists once it returns.
    return _read_executor_output(execution_info.execution_output_uri)
step_operator_entrypoint_configuration
Abstract base class for entrypoint configurations that run a single step.
StepOperatorEntrypointConfiguration (StepEntrypointConfiguration)
Base class for step operator entrypoint configurations.
Source code in zenml/step_operators/step_operator_entrypoint_configuration.py
class StepOperatorEntrypointConfiguration(StepEntrypointConfiguration):
    """Base class for step operator entrypoint configurations.

    Extends the step entrypoint configuration with an extra option carrying
    the path to a serialized execution info file. That file is loaded again
    to run the step's executor.
    """

    @classmethod
    def get_entrypoint_options(cls) -> Set[str]:
        """Gets all options required for running with this configuration.

        Returns:
            The superclass options as well as an option for the path to the
            execution info.
        """
        return super().get_entrypoint_options() | {EXECUTION_INFO_PATH_OPTION}

    @classmethod
    def get_entrypoint_arguments(
        cls,
        **kwargs: Any,
    ) -> List[str]:
        """Gets all arguments that the entrypoint command should be called with.

        Args:
            **kwargs: Kwargs, must include the execution info path.

        Returns:
            The superclass arguments as well as arguments for the path to the
            execution info.
        """
        return super().get_entrypoint_arguments(**kwargs) + [
            f"--{EXECUTION_INFO_PATH_OPTION}",
            kwargs[EXECUTION_INFO_PATH_OPTION],
        ]

    def _run_step(
        self,
        step: "Step",
        deployment: "PipelineDeployment",
    ) -> Optional[data_types.ExecutionInfo]:
        """Runs a single step from its serialized execution info.

        The method loads the execution info from the path given in the
        entrypoint arguments and builds the step's executor. It then runs
        the executor between the stack's `prepare_step_run` and
        `cleanup_step_run` hooks. If the executor raises, cleanup is called
        with `step_failed=True` and the error is re-raised.

        Args:
            step: The step to run.
            deployment: The deployment configuration.

        Raises:
            RuntimeError: If the step executor class does not exist.

        Returns:
            Step execution info.
        """
        # Make sure the artifact store is loaded before we load the execution
        # info
        stack = Client().active_stack
        execution_info_path = self.entrypoint_args[EXECUTION_INFO_PATH_OPTION]
        execution_info = self._load_execution_info(execution_info_path)
        executor_class = step_utils.get_executor_class(step.config.name)
        if not executor_class:
            raise RuntimeError(
                f"Unable to find executor class for step {step.config.name}."
            )
        executor = self._configure_executor(
            executor_class=executor_class, execution_info=execution_info
        )
        # NOTE(review): presumably ensures the artifact classes referenced by
        # the step config are importable in this environment — confirm.
        stack.orchestrator._ensure_artifact_classes_loaded(step.config)
        step_run_info = StepRunInfo(
            config=step.config,
            pipeline=deployment.pipeline,
            run_name=execution_info.pipeline_run_id,
        )
        stack.prepare_step_run(info=step_run_info)
        step_failed = False
        try:
            run_with_executor(execution_info=execution_info, executor=executor)
        except BaseException:
            # Mark the failure so cleanup can react to it, then propagate.
            # Swallowing the error here would report a failed step as a
            # success to whoever invoked this entrypoint.
            step_failed = True
            raise
        finally:
            stack.cleanup_step_run(info=step_run_info, step_failed=step_failed)
        return execution_info

    @staticmethod
    def _load_execution_info(execution_info_path: str) -> ExecutionInfo:
        """Loads the execution info from the given path.

        The file is read as a serialized `ExecutionInvocation` proto via
        `fileio`.

        Args:
            execution_info_path: Path to the execution info file.

        Returns:
            Execution info.
        """
        with fileio.open(execution_info_path, "rb") as f:
            execution_info_proto = ExecutionInvocation.FromString(f.read())
        return ExecutionInfo.from_proto(execution_info_proto)

    @staticmethod
    def _configure_executor(
        executor_class: Type[BaseExecutor], execution_info: ExecutionInfo
    ) -> BaseExecutor:
        """Creates and configures an executor instance.

        Args:
            executor_class: The class of the executor instance.
            execution_info: Execution info for the executor.

        Returns:
            A configured executor instance.
        """
        context = BaseExecutor.Context(
            tmp_dir=execution_info.tmp_dir,
            unique_id=str(execution_info.execution_id),
            executor_output_uri=execution_info.execution_output_uri,
            stateful_working_dir=execution_info.stateful_working_dir,
            pipeline_node=execution_info.pipeline_node,
            pipeline_info=execution_info.pipeline_info,
            pipeline_run_id=execution_info.pipeline_run_id,
        )
        return executor_class(context=context)
get_entrypoint_arguments(**kwargs)
classmethod
Gets all arguments that the entrypoint command should be called with.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| **kwargs | Any | Kwargs, must include the execution info path. | {} |
Returns:
| Type | Description |
|---|---|
| List[str] | The superclass arguments as well as arguments for the path to the execution info. |
Source code in zenml/step_operators/step_operator_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(
    cls,
    **kwargs: Any,
) -> List[str]:
    """Build the argument list for the step operator entrypoint command.

    Args:
        **kwargs: Keyword arguments forwarded to the superclass; must also
            contain the execution info path.

    Returns:
        The superclass arguments followed by the execution info path option
        and its value.
    """
    inherited_arguments = super().get_entrypoint_arguments(**kwargs)
    execution_info_arguments = [
        f"--{EXECUTION_INFO_PATH_OPTION}",
        kwargs[EXECUTION_INFO_PATH_OPTION],
    ]
    return inherited_arguments + execution_info_arguments
get_entrypoint_options()
classmethod
Gets all options required for running with this configuration.
Returns:
| Type | Description |
|---|---|
| Set[str] | The superclass options as well as an option for the path to the execution info. |
Source code in zenml/step_operators/step_operator_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
    """Collect the options this entrypoint configuration requires.

    Returns:
        The superclass options plus the execution info path option.
    """
    inherited_options = super().get_entrypoint_options()
    return inherited_options | {EXECUTION_INFO_PATH_OPTION}