Orchestrators
zenml.orchestrators
special
Orchestrator
An orchestrator is a special kind of backend that manages the running of each step of the pipeline. Orchestrators administer the actual pipeline runs. You can think of it as the 'root' of any pipeline job that you run during your experimentation.
ZenML supports a local orchestrator out of the box which allows you to run your pipelines in a local environment. We also support using Apache Airflow as the orchestrator to handle the steps of your pipeline.
base_orchestrator
BaseOrchestrator (StackComponent, ABC)
pydantic-model
Base class for all ZenML orchestrators.
Source code in zenml/orchestrators/base_orchestrator.py
class BaseOrchestrator(StackComponent, ABC):
    """Abstract base that every ZenML orchestrator derives from.

    Concrete orchestrators must override `run_pipeline`.
    """

    # Class Configuration: all orchestrators share the orchestrator
    # component type.
    TYPE: ClassVar[StackComponentType] = StackComponentType.ORCHESTRATOR

    @abstractmethod
    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Execute `pipeline` on `stack`.

        Args:
            pipeline: Pipeline that should be run.
            stack: Stack on which the pipeline run takes place.
            runtime_configuration: Configuration that applies to this run.
        """
run_pipeline(self, pipeline, stack, runtime_configuration)
Runs a pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
BasePipeline |
The pipeline to run. |
required |
stack |
Stack |
The stack on which the pipeline is run. |
required |
runtime_configuration |
RuntimeConfiguration |
Runtime configuration of the pipeline run. |
required |
Source code in zenml/orchestrators/base_orchestrator.py
@abstractmethod
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Execute `pipeline` on `stack`.

    Subclasses are required to override this method.

    Args:
        pipeline: Pipeline that should be run.
        stack: Stack on which the pipeline run takes place.
        runtime_configuration: Configuration that applies to this run.
    """
context_utils
add_context_to_node(pipeline_node, type_, name, properties)
Add a new context to a TFX protobuf pipeline node.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_node |
pipeline_pb2.PipelineNode |
A tfx protobuf pipeline node |
required |
type_ |
str |
The type name for the context to be added |
required |
name |
str |
Unique key for the context |
required |
properties |
Dict[str, str] |
dictionary of strings as properties of the context |
required |
Source code in zenml/orchestrators/context_utils.py
def add_context_to_node(
    pipeline_node: "pipeline_pb2.PipelineNode",
    type_: str,
    name: str,
    properties: Dict[str, str],
) -> None:
    """
    Attach a freshly created context to a TFX protobuf pipeline node.

    Args:
        pipeline_node: The tfx protobuf pipeline node receiving the context.
        type_: Type name given to the new context.
        name: Unique key identifying the context.
        properties: String-valued properties stored on the context.
    """
    new_context: "pipeline_pb2.ContextSpec" = (
        pipeline_node.contexts.contexts.add()
    )
    new_context.type.name = type_
    new_context.name.field_value.string_value = name
    # Indexing the protobuf map creates the entry, which is then filled in.
    for prop_name in properties:
        new_context.properties[prop_name].field_value.string_value = (
            properties[prop_name]
        )
add_runtime_configuration_to_node(pipeline_node, runtime_config)
Add the runtime configuration of a pipeline run to a protobuf pipeline node.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_node |
pipeline_pb2.PipelineNode |
a tfx protobuf pipeline node |
required |
runtime_config |
RuntimeConfiguration |
a ZenML RuntimeConfiguration |
required |
Source code in zenml/orchestrators/context_utils.py
def add_runtime_configuration_to_node(
    pipeline_node: "pipeline_pb2.PipelineNode",
    runtime_config: RuntimeConfiguration,
) -> None:
    """
    Add the runtime configuration of a pipeline run to a protobuf pipeline node.

    Every pydantic `BaseModel` value found in `runtime_config` is added to the
    node as its own context, with the model's fields serialized to strings.

    Args:
        pipeline_node: a tfx protobuf pipeline node
        runtime_config: a ZenML RuntimeConfiguration. Its
            `ignore_unserializable_fields` entry (default `False`) controls
            whether fields that cannot be serialized are skipped instead of
            raising.
    """
    skip_errors: bool = runtime_config.get(
        "ignore_unserializable_fields", False
    )

    # Determine the name of the context
    def _name(obj: "BaseModel") -> str:
        """Compute a unique context name for a pydantic BaseModel."""
        try:
            # NOTE(review): built-in `hash()` of a str is salted per process
            # (PYTHONHASHSEED), so this name is stable within one interpreter
            # session but not across sessions — confirm that is intended.
            return str(hash(obj.json(sort_keys=True)))
        except TypeError as e:
            class_name = obj.__class__.__name__
            # Use the module logger (as below) rather than the root logger.
            logger.info(
                "Cannot convert %s to json, generating uuid instead. Error: %s",
                class_name,
                e,
            )
            return f"{class_name}_{uuid.uuid1()}"

    # iterate over all attributes of runtime context, serializing all pydantic
    # objects to node context.
    for key, obj in runtime_config.items():
        if isinstance(obj, BaseModel):
            logger.debug("Adding %s to context", key)
            add_context_to_node(
                pipeline_node,
                type_=obj.__repr_name__().lower(),
                name=_name(obj),
                properties=serialize_pydantic_object(
                    obj, skip_errors=skip_errors
                ),
            )
serialize_pydantic_object(obj, *, skip_errors=False)
Convert a pydantic object to a dict of strings
Source code in zenml/orchestrators/context_utils.py
def serialize_pydantic_object(
    obj: BaseModel, *, skip_errors: bool = False
) -> Dict[str, str]:
    """Convert a pydantic object to a dict of strings.

    Each top-level entry of `obj.dict()` is JSON-encoded on its own. Values
    the standard JSON encoder cannot handle are passed to the model's
    `__json_encoder__`.

    Args:
        obj: The pydantic object to serialize.
        skip_errors: If True, entries that cannot be serialized are logged
            and left out of the result instead of raising.

    Returns:
        A mapping from each field name to the JSON string of its value.

    Raises:
        TypeError: If an entry cannot be serialized and `skip_errors` is
            False.
    """

    class PydanticEncoder(json.JSONEncoder):
        def default(self, o: Any) -> Any:
            # Try the model's own encoder first; the base implementation
            # raises TypeError for anything still unsupported.
            try:
                return cast(Callable[[Any], str], obj.__json_encoder__)(o)
            except TypeError:
                return super().default(o)

    def _inner_generator(
        dictionary: Dict[str, Any]
    ) -> Iterator[Tuple[str, str]]:
        """Itemwise serialize each element in a dictionary."""
        for key, item in dictionary.items():
            try:
                yield key, json.dumps(item, cls=PydanticEncoder)
            except TypeError as e:
                if skip_errors:
                    # Module logger, consistent with the rest of this module.
                    logger.info(
                        "Skipping adding field '%s' to metadata context as "
                        "it cannot be serialized due to %s.",
                        key,
                        e,
                    )
                else:
                    raise TypeError(
                        f"Invalid type {type(item)} for key {key} can not be "
                        "serialized."
                    ) from e

    return dict(_inner_generator(obj.dict()))
local
special
local_orchestrator
LocalOrchestrator (BaseOrchestrator)
pydantic-model
Orchestrator responsible for running pipelines locally.
Source code in zenml/orchestrators/local/local_orchestrator.py
class LocalOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines locally."""

    # Class Configuration
    FLAVOR: ClassVar[str] = "local"

    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Runs a pipeline locally.

        The ZenML pipeline is converted into a TFX pipeline and compiled.
        Each compiled node is then launched one after another in this process.

        Args:
            pipeline: The pipeline to run.
            stack: The stack on which the pipeline is run.
            runtime_configuration: Runtime configuration of the pipeline run.
                A default `RuntimeConfiguration` is used when this is `None`.

        Raises:
            TypeError: If the TFX pipeline root is not a plain string.
        """
        tfx_pipeline: TfxPipeline = create_tfx_pipeline(pipeline, stack=stack)

        if runtime_configuration is None:
            runtime_configuration = RuntimeConfiguration()

        if runtime_configuration.schedule:
            logger.warning(
                "Local Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run directly"
            )

        pipeline_root = tfx_pipeline.pipeline_info.pipeline_root
        if not isinstance(pipeline_root, str):
            raise TypeError(
                "TFX Pipeline root may not be a Placeholder, "
                "but must be a specific string."
            )

        for component in tfx_pipeline.components:
            if isinstance(component, base_component.BaseComponent):
                component._resolve_pip_dependencies(pipeline_root)

        pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)

        # Substitute the runtime parameter to be a concrete run_id
        runtime_parameter_utils.substitute_runtime_parameter(
            pb2_pipeline,
            {
                PIPELINE_RUN_ID_PARAMETER_NAME: runtime_configuration.run_name,
            },
        )

        deployment_config = runner_utils.extract_local_deployment_config(
            pb2_pipeline
        )
        # Connect to the metadata store of the stack this pipeline runs on.
        # This is the same store `create_tfx_pipeline` configured, not
        # whichever stack happens to be active in the repository.
        connection_config = stack.metadata_store.get_tfx_metadata_config()

        logger.debug(f"Using deployment config:\n {deployment_config}")
        logger.debug(f"Using connection config:\n {connection_config}")

        # These values are the same for every node, so compute them once.
        stack_properties = stack.dict()
        stack_context_name = str(
            hash(json.dumps(stack_properties, sort_keys=True))
        )
        requirements = " ".join(sorted(pipeline.requirements))
        requirements_context_name = str(hash(requirements))
        steps = list(pipeline.steps.values())
        p_info = pb2_pipeline.pipeline_info
        r_spec = pb2_pipeline.runtime_spec

        # Run each node; this relies on `pb2_pipeline.nodes` being in
        # topological order.
        for node in pb2_pipeline.nodes:
            pipeline_node: PipelineNode = node.pipeline_node

            # fill out that context
            context_utils.add_context_to_node(
                pipeline_node,
                type_=MetadataContextTypes.STACK.value,
                name=stack_context_name,
                properties=stack_properties,
            )

            # Add all pydantic objects from runtime_configuration to the context
            context_utils.add_runtime_configuration_to_node(
                pipeline_node, runtime_configuration
            )

            # Add pipeline requirements as a context
            context_utils.add_context_to_node(
                pipeline_node,
                type_=MetadataContextTypes.PIPELINE_REQUIREMENTS.value,
                name=requirements_context_name,
                properties={"pipeline_requirements": requirements},
            )

            node_id = pipeline_node.node_info.id
            executor_spec = runner_utils.extract_executor_spec(
                deployment_config, node_id
            )
            custom_driver_spec = runner_utils.extract_custom_driver_spec(
                deployment_config, node_id
            )

            # set custom executor operator to allow custom execution logic for
            # each step
            step = get_step_for_node(pipeline_node, steps=steps)
            custom_executor_operators = {
                executable_spec_pb2.PythonClassExecutableSpec: step.executor_operator
            }

            component_launcher = launcher.Launcher(
                pipeline_node=pipeline_node,
                mlmd_connection=metadata.Metadata(connection_config),
                pipeline_info=p_info,
                pipeline_runtime_spec=r_spec,
                executor_spec=executor_spec,
                custom_driver_spec=custom_driver_spec,
                custom_executor_operators=custom_executor_operators,
            )
            execute_step(component_launcher)
run_pipeline(self, pipeline, stack, runtime_configuration)
Runs a pipeline locally
Source code in zenml/orchestrators/local/local_orchestrator.py
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Runs a pipeline locally.

    The ZenML pipeline is converted into a TFX pipeline and compiled. Each
    compiled node is then launched one after another in this process.

    Args:
        pipeline: The pipeline to run.
        stack: The stack on which the pipeline is run.
        runtime_configuration: Runtime configuration of the pipeline run.
            A default `RuntimeConfiguration` is used when this is `None`.

    Raises:
        TypeError: If the TFX pipeline root is not a plain string.
    """
    tfx_pipeline: TfxPipeline = create_tfx_pipeline(pipeline, stack=stack)

    if runtime_configuration is None:
        runtime_configuration = RuntimeConfiguration()

    if runtime_configuration.schedule:
        logger.warning(
            "Local Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run directly"
        )

    pipeline_root = tfx_pipeline.pipeline_info.pipeline_root
    if not isinstance(pipeline_root, str):
        raise TypeError(
            "TFX Pipeline root may not be a Placeholder, "
            "but must be a specific string."
        )

    for component in tfx_pipeline.components:
        if isinstance(component, base_component.BaseComponent):
            component._resolve_pip_dependencies(pipeline_root)

    pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)

    # Substitute the runtime parameter to be a concrete run_id
    runtime_parameter_utils.substitute_runtime_parameter(
        pb2_pipeline,
        {
            PIPELINE_RUN_ID_PARAMETER_NAME: runtime_configuration.run_name,
        },
    )

    deployment_config = runner_utils.extract_local_deployment_config(
        pb2_pipeline
    )
    # Connect to the metadata store of the stack this pipeline runs on. This
    # is the same store `create_tfx_pipeline` configured, not whichever stack
    # happens to be active in the repository.
    connection_config = stack.metadata_store.get_tfx_metadata_config()

    logger.debug(f"Using deployment config:\n {deployment_config}")
    logger.debug(f"Using connection config:\n {connection_config}")

    # These values are the same for every node, so compute them once.
    stack_properties = stack.dict()
    stack_context_name = str(hash(json.dumps(stack_properties, sort_keys=True)))
    requirements = " ".join(sorted(pipeline.requirements))
    requirements_context_name = str(hash(requirements))
    steps = list(pipeline.steps.values())
    p_info = pb2_pipeline.pipeline_info
    r_spec = pb2_pipeline.runtime_spec

    # Run each node; this relies on `pb2_pipeline.nodes` being in topological
    # order.
    for node in pb2_pipeline.nodes:
        pipeline_node: PipelineNode = node.pipeline_node

        # fill out that context
        context_utils.add_context_to_node(
            pipeline_node,
            type_=MetadataContextTypes.STACK.value,
            name=stack_context_name,
            properties=stack_properties,
        )

        # Add all pydantic objects from runtime_configuration to the context
        context_utils.add_runtime_configuration_to_node(
            pipeline_node, runtime_configuration
        )

        # Add pipeline requirements as a context
        context_utils.add_context_to_node(
            pipeline_node,
            type_=MetadataContextTypes.PIPELINE_REQUIREMENTS.value,
            name=requirements_context_name,
            properties={"pipeline_requirements": requirements},
        )

        node_id = pipeline_node.node_info.id
        executor_spec = runner_utils.extract_executor_spec(
            deployment_config, node_id
        )
        custom_driver_spec = runner_utils.extract_custom_driver_spec(
            deployment_config, node_id
        )

        # set custom executor operator to allow custom execution logic for
        # each step
        step = get_step_for_node(pipeline_node, steps=steps)
        custom_executor_operators = {
            executable_spec_pb2.PythonClassExecutableSpec: step.executor_operator
        }

        component_launcher = launcher.Launcher(
            pipeline_node=pipeline_node,
            mlmd_connection=metadata.Metadata(connection_config),
            pipeline_info=p_info,
            pipeline_runtime_spec=r_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
            custom_executor_operators=custom_executor_operators,
        )
        execute_step(component_launcher)
utils
create_tfx_pipeline(zenml_pipeline, stack)
Creates a tfx pipeline from a ZenML pipeline.
Source code in zenml/orchestrators/utils.py
def create_tfx_pipeline(
    zenml_pipeline: "BasePipeline", stack: "Stack"
) -> tfx_pipeline.Pipeline:
    """Build the TFX pipeline that corresponds to a ZenML pipeline.

    Args:
        zenml_pipeline: The ZenML pipeline to convert. Its steps are connected
            before their components are collected.
        stack: Stack providing the artifact store (pipeline root) and the
            metadata store (metadata connection config).

    Returns:
        The resulting TFX pipeline.
    """
    # Steps must be wired together before their components are gathered.
    zenml_pipeline.connect(**zenml_pipeline.steps)

    components = []
    for zenml_step in zenml_pipeline.steps.values():
        components.append(zenml_step.component)

    store_for_artifacts = stack.artifact_store
    store_for_metadata = stack.metadata_store

    pipeline_kwargs = dict(
        pipeline_name=zenml_pipeline.name,
        components=components,
        pipeline_root=store_for_artifacts.path,
        metadata_connection_config=store_for_metadata.get_tfx_metadata_config(),
        enable_cache=zenml_pipeline.enable_cache,
    )
    return tfx_pipeline.Pipeline(**pipeline_kwargs)  # type: ignore[arg-type]
execute_step(tfx_launcher)
Executes a tfx component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tfx_launcher |
Launcher |
A tfx launcher to execute the component. |
required |
Returns:
Type | Description |
---|---|
Optional[tfx.orchestration.portable.data_types.ExecutionInfo] |
Optional execution info returned by the launcher. |
Source code in zenml/orchestrators/utils.py
def execute_step(
    tfx_launcher: launcher.Launcher,
) -> Optional[data_types.ExecutionInfo]:
    """Executes a tfx component.

    Args:
        tfx_launcher: A tfx launcher to execute the component.

    Returns:
        Optional execution info returned by the launcher.

    Raises:
        DuplicateRunNameError: If the launcher reports that this execution
            has already succeeded, i.e. a run with this name already exists.
    """
    step_name_param = (
        INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME
    )
    pipeline_step_name = tfx_launcher._pipeline_node.node_info.id
    # perf_counter is monotonic, so the duration is immune to clock changes.
    start_time = time.perf_counter()
    logger.info(f"Step `{pipeline_step_name}` has started.")

    # Only the launch itself is guarded: the duplicate-run translation below
    # must not swallow RuntimeErrors raised by the cache-status lookup.
    try:
        execution_info = tfx_launcher.launch()
    except RuntimeError as e:
        if "execution has already succeeded" in str(e):
            # Hacky workaround to catch the error that a pipeline run with
            # this name already exists. Raise an error with a more descriptive
            # message instead.
            raise DuplicateRunNameError() from e
        raise

    if execution_info and get_cache_status(execution_info):
        if execution_info.exec_properties:
            step_name = json.loads(
                execution_info.exec_properties[step_name_param]
            )
            logger.info(
                f"Using cached version of `{pipeline_step_name}` "
                f"[`{step_name}`].",
            )
        else:
            logger.error(
                f"No execution properties found for step "
                f"`{pipeline_step_name}`."
            )

    run_duration = time.perf_counter() - start_time
    logger.info(
        f"Step `{pipeline_step_name}` has finished in "
        f"{string_utils.get_human_readable_time(run_duration)}."
    )
    return execution_info
get_cache_status(execution_info)
Returns the caching status of a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
execution_info |
ExecutionInfo |
The execution info of a `tfx` step. |
required |
Exceptions:
Type | Description |
---|---|
AttributeError |
If the execution info is `None` (note: the listed implementation actually logs a warning and returns `False` in that case). |
KeyError |
If no pipeline info is found in the `execution_info`. |
Returns:
Type | Description |
---|---|
bool |
The caching status of a `tfx` step as a boolean value. |
Source code in zenml/orchestrators/utils.py
def get_cache_status(
    execution_info: Optional[data_types.ExecutionInfo],
) -> bool:
    """Returns the caching status of a step.

    Args:
        execution_info: The execution info of a `tfx` step. `None` is
            accepted and treated as "not cached".

    Raises:
        RuntimeError: If no active stack is configured for the repository.
        KeyError: If no pipeline info is found in the `execution_info`.

    Returns:
        The caching status of a `tfx` step as a boolean value. `False` is
        returned when `execution_info` is `None` or when the pipeline is not
        found in the metadata store.
    """
    if execution_info is None:
        logger.warning("No execution info found when checking cache status.")
        return False

    repository = Repository()
    # TODO [ENG-706]: Get the current running stack instead of just the active
    # stack
    active_stack = repository.active_stack
    if not active_stack:
        raise RuntimeError(
            "No active stack is configured for the repository. Run "
            "`zenml stack set STACK_NAME` to update the active stack."
        )
    metadata_store = active_stack.metadata_store

    step_name_param = (
        INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME
    )
    step_name = json.loads(execution_info.exec_properties[step_name_param])

    if not execution_info.pipeline_info:
        raise KeyError(f"No pipeline info found for step `{step_name}`.")
    pipeline_name = execution_info.pipeline_info.id
    pipeline_run_name = cast(str, execution_info.pipeline_run_id)

    pipeline = metadata_store.get_pipeline(pipeline_name)
    if pipeline is None:
        logger.error(f"Pipeline {pipeline_name} not found in Metadata Store.")
        return False

    return pipeline.get_run(pipeline_run_name).get_step(step_name).is_cached
get_step_for_node(node, steps)
Finds the matching step for a tfx pipeline node.
Source code in zenml/orchestrators/utils.py
def get_step_for_node(node: PipelineNode, steps: List[BaseStep]) -> BaseStep:
    """Finds the matching step for a tfx pipeline node.

    Args:
        node: The tfx pipeline node; its `node_info.id` is matched against
            step names.
        steps: Candidate ZenML steps.

    Returns:
        The first step whose `name` equals the node id.

    Raises:
        RuntimeError: If no step with a matching name exists.
    """
    step_name = node.node_info.id
    # `next` with a default avoids raising from inside `except StopIteration`,
    # which would attach an irrelevant chained traceback to the error.
    match = next((step for step in steps if step.name == step_name), None)
    if match is None:
        raise RuntimeError(f"Unable to find step with name '{step_name}'.")
    return match