Skip to content

Orchestrators

zenml.orchestrators special

An orchestrator is a special kind of backend that manages the running of each step of the pipeline. Orchestrators administer the actual pipeline runs. You can think of it as the 'root' of any pipeline job that you run during your experimentation.

ZenML supports a local orchestrator out of the box, which allows you to run your pipelines in a local environment. We also support using Apache Airflow as the orchestrator that handles the steps of your pipeline.

base_orchestrator

BaseOrchestrator (StackComponent, ABC) pydantic-model

Base class for all ZenML orchestrators.

Source code in zenml/orchestrators/base_orchestrator.py
class BaseOrchestrator(StackComponent, ABC):
    """Abstract base that every ZenML orchestrator derives from.

    Subclasses must supply a ``flavor`` and implement ``run_pipeline``.
    The stack component ``type`` is fixed here to
    ``StackComponentType.ORCHESTRATOR``.
    """

    @property
    @abstractmethod
    def flavor(self) -> OrchestratorFlavor:
        """The orchestrator flavor; each concrete orchestrator defines it."""
        ...

    @property
    def type(self) -> StackComponentType:
        """The component type, always ``StackComponentType.ORCHESTRATOR``."""
        return StackComponentType.ORCHESTRATOR

    @abstractmethod
    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Run ``pipeline`` on ``stack``.

        Args:
            pipeline: The pipeline to run.
            stack: The stack on which the pipeline is run.
            runtime_configuration: Runtime configuration of the pipeline run.

        Returns:
            Whatever the concrete orchestrator returns; the base signature
            only promises ``Any``.
        """
        ...
flavor: OrchestratorFlavor property readonly

The orchestrator flavor.

type: StackComponentType property readonly

The component type.

run_pipeline(self, pipeline, stack, runtime_configuration)

Runs a pipeline.

Parameters:

Name Type Description Default
pipeline BasePipeline

The pipeline to run.

required
stack Stack

The stack on which the pipeline is run.

required
runtime_configuration RuntimeConfiguration

Runtime configuration of the pipeline run.

required
Source code in zenml/orchestrators/base_orchestrator.py
@abstractmethod
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Run ``pipeline`` on ``stack``.

    Args:
        pipeline: The pipeline to run.
        stack: The stack on which the pipeline is run.
        runtime_configuration: Runtime configuration of the pipeline run.

    Returns:
        Whatever the concrete orchestrator returns; the base signature only
        promises ``Any``.
    """
    ...

context_utils

add_context_to_node(pipeline_node, type_, name, properties)

Add a new context to a TFX protobuf pipeline node.

Parameters:

Name Type Description Default
pipeline_node pipeline_pb2.PipelineNode

A tfx protobuf pipeline node

required
type_ str

The type name for the context to be added

required
name str

Unique key for the context

required
properties Dict[str, str]

Dictionary of string-valued properties to attach to the context

required
Source code in zenml/orchestrators/context_utils.py
def add_context_to_node(
    pipeline_node: "pipeline_pb2.PipelineNode",  # type: ignore[valid-type]
    type_: str,
    name: str,
    properties: Dict[str, str],
) -> None:
    """
    Append a new context to a TFX protobuf pipeline node.

    Args:
        pipeline_node: The TFX protobuf pipeline node receiving the context.
        type_: Type name written to the new context.
        name: Unique key written as the new context's name.
        properties: Property values, each stored as a protobuf string field
            under its key on the new context.
    """
    new_context: "pipeline_pb2.ContextSpec" = (  # type: ignore[valid-type]
        pipeline_node.contexts.contexts.add()  # type: ignore[attr-defined]
    )
    new_context.type.name = type_  # type: ignore[attr-defined]
    new_context.name.field_value.string_value = name  # type: ignore[attr-defined]

    # Every property is written as a string value, whatever its key.
    context_properties = new_context.properties  # type:ignore[attr-defined]
    for prop_key in properties:
        target = context_properties[prop_key]
        target.field_value.string_value = properties[prop_key]

add_runtime_configuration_to_node(pipeline_node, runtime_config)

Add the runtime configuration of a pipeline run to a protobuf pipeline node.

Parameters:

Name Type Description Default
pipeline_node pipeline_pb2.PipelineNode

a tfx protobuf pipeline node

required
runtime_config RuntimeConfiguration

a ZenML RuntimeConfiguration

required
Source code in zenml/orchestrators/context_utils.py
def add_runtime_configuration_to_node(
    pipeline_node: "pipeline_pb2.PipelineNode",  # type: ignore[valid-type]
    runtime_config: RuntimeConfiguration,
) -> None:
    """
    Add the runtime configuration of a pipeline run to a protobuf pipeline node.

    Every value in ``runtime_config`` that is a pydantic ``BaseModel`` is
    serialized and attached to ``pipeline_node`` as its own context. The
    context type is the lower-cased ``__repr_name__()`` of the model. Values
    that are not pydantic models are ignored.

    Args:
        pipeline_node: A TFX protobuf pipeline node.
        runtime_config: A ZenML RuntimeConfiguration. Its
            ``ignore_unserializable_fields`` entry (default ``False``) decides
            whether unserializable fields are skipped or raise.

    Raises:
        TypeError: If a field of a pydantic object cannot be serialized and
            ``ignore_unserializable_fields`` is not set.
    """
    import hashlib

    skip_errors: bool = runtime_config.get(
        "ignore_unserializable_fields", False
    )

    def _name(obj: "BaseModel") -> str:
        """Compute a content-derived context name for a pydantic BaseModel.

        Builtin ``hash()`` of a ``str`` is salted per interpreter process, so
        it would give an identical configuration a different name on every
        run. A SHA-256 digest of the key-sorted JSON is stable across
        processes.
        """
        try:
            serialized = obj.json(sort_keys=True)
        except TypeError as e:
            class_name = obj.__class__.__name__
            logger.info(
                "Cannot convert %s to json, generating uuid instead. Error: %s",
                class_name,
                e,
            )
            return f"{class_name}_{uuid.uuid1()}"
        return hashlib.sha256(serialized.encode("utf-8")).hexdigest()

    # iterate over all attributes of runtime context, serializing all pydantic
    # objects to node context.
    for key, obj in runtime_config.items():
        if not isinstance(obj, BaseModel):
            continue
        logger.debug("Adding %s to context", key)
        add_context_to_node(
            pipeline_node,
            type_=obj.__repr_name__().lower(),
            name=_name(obj),
            properties=serialize_pydantic_object(
                obj, skip_errors=skip_errors
            ),
        )

serialize_pydantic_object(obj, *, skip_errors=False)

Convert a pydantic object to a dict mapping each field name to the JSON string of its value.

Source code in zenml/orchestrators/context_utils.py
def serialize_pydantic_object(
    obj: BaseModel, *, skip_errors: bool = False
) -> Dict[str, str]:
    """Convert a pydantic object to a dict of strings.

    Each top-level entry of ``obj.dict()`` is JSON-encoded on its own. Values
    the standard JSON encoder cannot handle are passed to the model's
    ``__json_encoder__`` first.

    Args:
        obj: The pydantic object to serialize.
        skip_errors: If True, fields that cannot be serialized are logged and
            left out of the result instead of raising.

    Returns:
        A mapping from field name to the JSON string of that field's value.

    Raises:
        TypeError: If a field cannot be serialized and ``skip_errors`` is
            False.
    """

    class PydanticEncoder(json.JSONEncoder):
        def default(self, o: Any) -> Any:
            # Try the model's own encoder first. Fall back to the stock
            # JSONEncoder.default, which raises TypeError for unsupported
            # types.
            try:
                return cast(Callable[[Any], str], obj.__json_encoder__)(o)
            except TypeError:
                return super().default(o)

    def _inner_generator(
        dictionary: Dict[str, Any]
    ) -> Iterator[Tuple[str, str]]:
        """Itemwise serialize each element in a dictionary."""
        for key, item in dictionary.items():
            # Keep only the serialization inside the try. The yield stays
            # outside it so nothing raised at the yield point is mistaken
            # for a serialization failure.
            try:
                serialized = json.dumps(item, cls=PydanticEncoder)
            except TypeError as e:
                if not skip_errors:
                    raise TypeError(
                        f"Invalid type {type(item)} for key {key} can not be "
                        "serialized."
                    ) from e
                logger.info(
                    "Skipping adding field '%s' to metadata context as "
                    "it cannot be serialized due to %s.",
                    key,
                    e,
                )
                continue
            yield key, serialized

    return dict(_inner_generator(obj.dict()))

local special

local_orchestrator

LocalOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines locally.

Source code in zenml/orchestrators/local/local_orchestrator.py
class LocalOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines locally."""

    supports_local_execution = True
    supports_remote_execution = False

    @property
    def flavor(self) -> OrchestratorFlavor:
        """The orchestrator flavor."""
        return OrchestratorFlavor.LOCAL

    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Runs a pipeline locally.

        The ZenML pipeline is converted to a TFX pipeline and compiled to its
        protobuf form. Each compiled node is then launched in turn via
        ``execute_step``.

        Args:
            pipeline: The pipeline to run.
            stack: The stack on which the pipeline is run. Its metadata store
                provides the MLMD connection config.
            runtime_configuration: Runtime configuration of the pipeline run.
                If ``None``, a default ``RuntimeConfiguration()`` is used.
        """
        import hashlib

        tfx_pipeline: TfxPipeline = create_tfx_pipeline(pipeline, stack=stack)

        if runtime_configuration is None:
            runtime_configuration = RuntimeConfiguration()

        if runtime_configuration.schedule:
            logger.warning(
                "Local Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run directly"
            )

        for component in tfx_pipeline.components:
            if isinstance(component, base_component.BaseComponent):
                component._resolve_pip_dependencies(
                    tfx_pipeline.pipeline_info.pipeline_root
                )

        pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)  # type: ignore[valid-type]

        # Substitute the runtime parameter to be a concrete run_id
        runtime_parameter_utils.substitute_runtime_parameter(
            pb2_pipeline,
            {
                PIPELINE_RUN_ID_PARAMETER_NAME: runtime_configuration.run_name,
            },
        )

        deployment_config = runner_utils.extract_local_deployment_config(
            pb2_pipeline
        )
        # Use the stack this run was requested on, the same one
        # `create_tfx_pipeline` took its metadata config from, rather than
        # whichever stack is currently active in the repository.
        connection_config = stack.metadata_store.get_tfx_metadata_config()

        logger.debug(f"Using deployment config:\n {deployment_config}")
        logger.debug(f"Using connection config:\n {connection_config}")

        # The stack context is the same for every node, so build it once.
        # SHA-256 is used because builtin `hash()` of a str is salted per
        # process and would name the same stack differently on each run.
        stack_properties = stack.dict()
        stack_context_name = hashlib.sha256(
            json.dumps(stack_properties, sort_keys=True).encode("utf-8")
        ).hexdigest()

        p_info = pb2_pipeline.pipeline_info  # type:ignore[attr-defined]
        r_spec = pb2_pipeline.runtime_spec  # type:ignore[attr-defined]

        # Run each component. Note that the pipeline.components list is in
        # topological order.
        for node in pb2_pipeline.nodes:  # type: ignore[attr-defined]
            pipeline_node: PipelineNode = node.pipeline_node  # type: ignore[valid-type]

            # Attach the stack context to this node.
            context_utils.add_context_to_node(
                pipeline_node,
                type_=MetadataContextTypes.STACK.value,
                name=stack_context_name,
                properties=stack_properties,
            )

            # Add all pydantic objects from runtime_configuration to the context
            context_utils.add_runtime_configuration_to_node(
                pipeline_node, runtime_configuration
            )

            node_id = pipeline_node.node_info.id  # type:ignore[attr-defined]
            executor_spec = runner_utils.extract_executor_spec(
                deployment_config, node_id
            )
            custom_driver_spec = runner_utils.extract_custom_driver_spec(
                deployment_config, node_id
            )

            component_launcher = launcher.Launcher(
                pipeline_node=pipeline_node,
                mlmd_connection=metadata.Metadata(connection_config),
                pipeline_info=p_info,
                pipeline_runtime_spec=r_spec,
                executor_spec=executor_spec,
                custom_driver_spec=custom_driver_spec,
            )
            execute_step(component_launcher)
flavor: OrchestratorFlavor property readonly

The orchestrator flavor.

run_pipeline(self, pipeline, stack, runtime_configuration)

Runs a pipeline locally.

Source code in zenml/orchestrators/local/local_orchestrator.py
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Runs a pipeline locally.

    The ZenML pipeline is converted to a TFX pipeline and compiled to its
    protobuf form. Each compiled node is then launched in turn via
    ``execute_step``.

    Args:
        pipeline: The pipeline to run.
        stack: The stack on which the pipeline is run. Its metadata store
            provides the MLMD connection config.
        runtime_configuration: Runtime configuration of the pipeline run.
            If ``None``, a default ``RuntimeConfiguration()`` is used.
    """
    import hashlib

    tfx_pipeline: TfxPipeline = create_tfx_pipeline(pipeline, stack=stack)

    if runtime_configuration is None:
        runtime_configuration = RuntimeConfiguration()

    if runtime_configuration.schedule:
        logger.warning(
            "Local Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run directly"
        )

    for component in tfx_pipeline.components:
        if isinstance(component, base_component.BaseComponent):
            component._resolve_pip_dependencies(
                tfx_pipeline.pipeline_info.pipeline_root
            )

    pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)  # type: ignore[valid-type]

    # Substitute the runtime parameter to be a concrete run_id
    runtime_parameter_utils.substitute_runtime_parameter(
        pb2_pipeline,
        {
            PIPELINE_RUN_ID_PARAMETER_NAME: runtime_configuration.run_name,
        },
    )

    deployment_config = runner_utils.extract_local_deployment_config(
        pb2_pipeline
    )
    # Use the stack this run was requested on, the same one
    # `create_tfx_pipeline` took its metadata config from, rather than
    # whichever stack is currently active in the repository.
    connection_config = stack.metadata_store.get_tfx_metadata_config()

    logger.debug(f"Using deployment config:\n {deployment_config}")
    logger.debug(f"Using connection config:\n {connection_config}")

    # The stack context is the same for every node, so build it once.
    # SHA-256 is used because builtin `hash()` of a str is salted per
    # process and would name the same stack differently on each run.
    stack_properties = stack.dict()
    stack_context_name = hashlib.sha256(
        json.dumps(stack_properties, sort_keys=True).encode("utf-8")
    ).hexdigest()

    p_info = pb2_pipeline.pipeline_info  # type:ignore[attr-defined]
    r_spec = pb2_pipeline.runtime_spec  # type:ignore[attr-defined]

    # Run each component. Note that the pipeline.components list is in
    # topological order.
    for node in pb2_pipeline.nodes:  # type: ignore[attr-defined]
        pipeline_node: PipelineNode = node.pipeline_node  # type: ignore[valid-type]

        # Attach the stack context to this node.
        context_utils.add_context_to_node(
            pipeline_node,
            type_=MetadataContextTypes.STACK.value,
            name=stack_context_name,
            properties=stack_properties,
        )

        # Add all pydantic objects from runtime_configuration to the context
        context_utils.add_runtime_configuration_to_node(
            pipeline_node, runtime_configuration
        )

        node_id = pipeline_node.node_info.id  # type:ignore[attr-defined]
        executor_spec = runner_utils.extract_executor_spec(
            deployment_config, node_id
        )
        custom_driver_spec = runner_utils.extract_custom_driver_spec(
            deployment_config, node_id
        )

        component_launcher = launcher.Launcher(
            pipeline_node=pipeline_node,
            mlmd_connection=metadata.Metadata(connection_config),
            pipeline_info=p_info,
            pipeline_runtime_spec=r_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
        )
        execute_step(component_launcher)

utils

create_tfx_pipeline(zenml_pipeline, stack)

Creates a tfx pipeline from a ZenML pipeline.

Source code in zenml/orchestrators/utils.py
def create_tfx_pipeline(
    zenml_pipeline: "BasePipeline", stack: "Stack"
) -> tfx_pipeline.Pipeline:
    """Build a TFX pipeline out of a ZenML pipeline.

    The pipeline's steps are connected first. The TFX component of every
    step is then collected.

    Args:
        zenml_pipeline: The ZenML pipeline to convert. Its name and
            ``enable_cache`` setting are carried over.
        stack: The stack whose artifact store path becomes the pipeline root
            and whose metadata store supplies the metadata connection config.

    Returns:
        The assembled TFX pipeline.
    """
    # Connect the inputs/outputs of all steps in the pipeline.
    zenml_pipeline.connect(**zenml_pipeline.steps)

    components = []
    for step in zenml_pipeline.steps.values():
        components.append(step.component)

    store_for_artifacts = stack.artifact_store
    store_for_metadata = stack.metadata_store

    return tfx_pipeline.Pipeline(
        pipeline_name=zenml_pipeline.name,
        components=components,  # type: ignore[arg-type]
        pipeline_root=store_for_artifacts.path,
        metadata_connection_config=store_for_metadata.get_tfx_metadata_config(),
        enable_cache=zenml_pipeline.enable_cache,
    )

execute_step(tfx_launcher)

Executes a tfx component.

Parameters:

Name Type Description Default
tfx_launcher Launcher

A tfx launcher to execute the component.

required

Returns:

Type Description
Optional[tfx.orchestration.portable.data_types.ExecutionInfo]

Optional execution info returned by the launcher.

Source code in zenml/orchestrators/utils.py
def execute_step(
    tfx_launcher: launcher.Launcher,
) -> Optional[data_types.ExecutionInfo]:
    """Executes a tfx component.

    Args:
        tfx_launcher: A tfx launcher to execute the component.

    Returns:
        Optional execution info returned by the launcher.

    Raises:
        DuplicateRunNameError: If the launcher's RuntimeError message says the
            execution has already succeeded. That message is how a pipeline
            run with the same name is detected here.
        RuntimeError: Any other RuntimeError raised by the launcher is
            re-raised unchanged.
    """
    step_name = tfx_launcher._pipeline_node.node_info.id  # type: ignore[attr-defined] # noqa
    # Monotonic clock: wall-clock time can be adjusted while a step runs,
    # which would make the reported duration wrong or even negative.
    start_time = time.monotonic()
    logger.info("Step `%s` has started.", step_name)
    try:
        execution_info = tfx_launcher.launch()
    except RuntimeError as e:
        if "execution has already succeeded" in str(e):
            # Hacky workaround to catch the error that a pipeline run with
            # this name already exists. Raise an error with a more descriptive
            # message instead, keeping the original as the cause.
            raise DuplicateRunNameError() from e
        raise

    run_duration = time.monotonic() - start_time
    logger.info(
        "Step `%s` has finished in %s.",
        step_name,
        string_utils.get_human_readable_time(run_duration),
    )
    return execution_info