Orchestrators
zenml.orchestrators
Initialization for ZenML orchestrators.
An orchestrator is a special kind of backend that manages the running of each step of the pipeline. Orchestrators administer the actual pipeline runs. You can think of it as the 'root' of any pipeline job that you run during your experimentation.
ZenML supports a local orchestrator out of the box which allows you to run your pipelines in a local environment. We also support using Apache Airflow as the orchestrator to handle the steps of your pipeline.
base_orchestrator
Base orchestrator class.
BaseOrchestrator (StackComponent, ABC)
Base class for all orchestrators.
In order to implement an orchestrator you will need to subclass from this class.
How it works:
The run(...) method is the entrypoint that is executed when the pipeline's run method is called within the user code (pipeline_instance.run(...)).
This method does some internal preparation and then calls the prepare_or_run_pipeline(...) method. BaseOrchestrator subclasses must implement this method and either run the pipeline steps directly or deploy the pipeline to some remote infrastructure.
Source code in zenml/orchestrators/base_orchestrator.py
class BaseOrchestrator(StackComponent, ABC):
    """Base class for all orchestrators.

    In order to implement an orchestrator you will need to subclass from this
    class.

    How it works:
    -------------
    The `run(...)` method is the entrypoint that is executed when the
    pipeline's run method is called within the user code
    (`pipeline_instance.run(...)`).

    This method will do some internal preparation and then call the
    `prepare_or_run_pipeline(...)` method. BaseOrchestrator subclasses must
    implement this method and either run the pipeline steps directly or deploy
    the pipeline to some remote infrastructure.
    """

    _active_deployment: Optional["PipelineDeploymentResponse"] = None

    @property
    def config(self) -> BaseOrchestratorConfig:
        """Returns the `BaseOrchestratorConfig` config.

        Returns:
            The configuration.
        """
        return cast(BaseOrchestratorConfig, self._config)

    @abstractmethod
    def get_orchestrator_run_id(self) -> str:
        """Returns the run id of the active orchestrator run.

        Important: This needs to be a unique ID and return the same value for
        all steps of a pipeline run.

        Returns:
            The orchestrator run id.
        """

    @abstractmethod
    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
        environment: Dict[str, str],
    ) -> Any:
        """The method needs to be implemented by the respective orchestrator.

        Depending on the type of orchestrator you'll have to perform slightly
        different operations.

        Simple Case:
        ------------
        The Steps are run directly from within the same environment in which
        the orchestrator code is executed. In this case you will need to
        deal with implementation-specific runtime configurations (like the
        schedule) and then iterate through the steps and finally call
        `self.run_step(...)` to execute each step.

        Advanced Case:
        --------------
        Most orchestrators will not run the steps directly. Instead, they
        build some intermediate representation of the pipeline that is then
        used to create and run the pipeline and its steps on the target
        environment. For such orchestrators this method will have to build
        this representation and deploy it.

        Regardless of the implementation details, the orchestrator will need
        to run each step in the target environment. For this the
        `self.run_step(...)` method should be used.

        The easiest way to make this work is by using an entrypoint
        configuration to run single steps (`zenml.entrypoints.step_entrypoint_configuration.StepEntrypointConfiguration`)
        or entire pipelines (`zenml.entrypoints.pipeline_entrypoint_configuration.PipelineEntrypointConfiguration`).

        Args:
            deployment: The pipeline deployment to prepare or run.
            stack: The stack the pipeline will run on.
            environment: Environment variables to set in the orchestration
                environment. These don't need to be set if running locally.

        Returns:
            The optional return value from this method will be returned by the
            `pipeline_instance.run()` call when someone is running a pipeline.
        """

    def run(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
    ) -> Any:
        """Runs a pipeline on a stack.

        Args:
            deployment: The pipeline deployment.
            stack: The stack on which to run the pipeline.

        Returns:
            Orchestrator-specific return value.
        """
        self._prepare_run(deployment=deployment)

        environment = get_config_environment_vars(deployment=deployment)

        try:
            result = self.prepare_or_run_pipeline(
                deployment=deployment, stack=stack, environment=environment
            )
        finally:
            self._cleanup_run()

        return result

    def run_step(self, step: "Step") -> None:
        """Runs the given step.

        Args:
            step: The step to run.
        """
        assert self._active_deployment
        launcher = StepLauncher(
            deployment=self._active_deployment,
            step=step,
            orchestrator_run_id=self.get_orchestrator_run_id(),
        )
        launcher.launch()

    @staticmethod
    def requires_resources_in_orchestration_environment(
        step: "Step",
    ) -> bool:
        """Checks if the orchestrator should run this step on special resources.

        Args:
            step: The step that will be checked.

        Returns:
            True if the step requires special resources in the orchestration
            environment, False otherwise.
        """
        # If the step requires custom resources and doesn't run with a step
        # operator, it would need these requirements in the orchestrator
        # environment
        if step.config.step_operator:
            return False

        return not step.config.resource_settings.empty

    def _prepare_run(self, deployment: "PipelineDeploymentResponse") -> None:
        """Prepares a run.

        Args:
            deployment: The deployment to prepare.
        """
        self._active_deployment = deployment

    def _cleanup_run(self) -> None:
        """Cleans up the active run."""
        self._active_deployment = None
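The control flow under "How it works" (run() prepares the run, delegates to prepare_or_run_pipeline(), and always cleans up) can be shown with a stripped-down sketch that does not depend on ZenML. The class names ToyOrchestrator and InProcessOrchestrator, the plain dict used as a deployment, and the placeholder environment are hypothetical stand-ins. The real base class works with PipelineDeploymentResponse and Stack objects and launches steps through a StepLauncher.

```python
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class ToyOrchestrator(ABC):
    """Hypothetical stand-in for BaseOrchestrator (no ZenML imports)."""

    def __init__(self) -> None:
        self._active_deployment: Optional[Dict[str, Any]] = None

    @abstractmethod
    def get_orchestrator_run_id(self) -> str:
        """Must return the same ID for every step of one run."""

    @abstractmethod
    def prepare_or_run_pipeline(
        self, deployment: Dict[str, Any], environment: Dict[str, str]
    ) -> Any:
        """Run the steps directly or deploy them to other infrastructure."""

    def run(self, deployment: Dict[str, Any]) -> Any:
        # Internal preparation: remember the deployment for run_step().
        self._active_deployment = deployment
        # Hypothetical placeholder for get_config_environment_vars(...).
        environment = {"EXAMPLE_SETTING": "1"}
        try:
            return self.prepare_or_run_pipeline(deployment, environment)
        finally:
            # Clear the active deployment even if the run raises.
            self._active_deployment = None

    def run_step(self, step: str) -> str:
        assert self._active_deployment is not None
        # BaseOrchestrator hands the step to a StepLauncher at this point;
        # the sketch just returns a label tagged with the run ID.
        return f"{self.get_orchestrator_run_id()}:{step}"


class InProcessOrchestrator(ToyOrchestrator):
    """Simple case: execute every step directly in this process."""

    def __init__(self) -> None:
        super().__init__()
        self._run_id = ""

    def get_orchestrator_run_id(self) -> str:
        return self._run_id

    def prepare_or_run_pipeline(
        self, deployment: Dict[str, Any], environment: Dict[str, str]
    ) -> List[str]:
        # One fresh ID per run, shared by all steps of that run.
        self._run_id = str(uuid.uuid4())
        return [self.run_step(step) for step in deployment["steps"]]


orchestrator = InProcessOrchestrator()
results = orchestrator.run({"steps": ["load", "train"]})
```

The try/finally in run() mirrors the listing above. The return value of prepare_or_run_pipeline() reaches the caller, and the active deployment is reset whether or not the run succeeds.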
config: BaseOrchestratorConfig
property
readonly
Returns the BaseOrchestratorConfig config.
Returns:
Type | Description |
---|---|
BaseOrchestratorConfig | The configuration. |
get_orchestrator_run_id(self)
Returns the run ID of the active orchestrator run.
Important: This must be a unique ID, and it must be the same for all steps of a pipeline run.
Returns:
Type | Description |
---|---|
str | The orchestrator run ID. |
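The steps of a remote run usually execute in separate processes or containers. An ID generated independently inside each step would therefore not meet the "same value for all steps" requirement. The following is a minimal sketch of one common approach. It assumes the orchestration backend injects a per-run identifier into every step's environment (the variable name EXAMPLE_ORCHESTRATOR_RUN_ID is hypothetical) and falls back to generating the ID once for a single local process.

```python
import os
import uuid
from typing import Optional

# Hypothetical variable name: a real orchestrator would read whatever
# run identifier its backend exposes to every step of a run.
RUN_ID_ENV_VAR = "EXAMPLE_ORCHESTRATOR_RUN_ID"


class RunIdProvider:
    """Returns one run ID that all steps of a run agree on."""

    def __init__(self) -> None:
        self._local_run_id: Optional[str] = None

    def get_orchestrator_run_id(self) -> str:
        # Remote case: the backend injects the same value into the
        # environment of every step, so separate processes agree.
        env_value = os.environ.get(RUN_ID_ENV_VAR)
        if env_value:
            return env_value
        # Local case: generate the ID once and reuse it afterwards.
        if self._local_run_id is None:
            self._local_run_id = str(uuid.uuid4())
        return self._local_run_id


provider = RunIdProvider()
first_id = provider.get_orchestrator_run_id()
second_id = provider.get_orchestrator_run_id()
```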
prepare_or_run_pipeline(self, deployment, stack, environment)
This method must be implemented by each orchestrator. Depending on the type of orchestrator, you'll have to perform slightly different operations.
Simple Case:
The steps are run directly from within the same environment in which the orchestrator code is executed. In this case you need to deal with implementation-specific runtime configurations (like the schedule), then iterate through the steps and call self.run_step(...) to execute each step.
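In the simple case, "iterate through the steps" has to respect the dependencies between steps. The following is a minimal sketch using the standard library's graphlib.TopologicalSorter (Python 3.9+). The upstream mapping and the run_step callable are hypothetical stand-ins for the deployment's step configurations and self.run_step(...).

```python
from graphlib import TopologicalSorter
from typing import Callable, Dict, List


def run_steps_in_order(
    upstream_steps: Dict[str, List[str]],
    run_step: Callable[[str], None],
) -> List[str]:
    """Run each step only after all of its upstream steps have run."""
    # TopologicalSorter expects a mapping of node -> predecessors, which is
    # exactly an "upstream steps" adjacency list. It raises CycleError if
    # the graph is not a DAG.
    order = list(TopologicalSorter(upstream_steps).static_order())
    for step_name in order:
        run_step(step_name)
    return order


executed: List[str] = []
order = run_steps_in_order(
    {"load": [], "train": ["load"], "evaluate": ["load", "train"]},
    run_step=executed.append,
)
```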
Advanced Case:
Most orchestrators will not run the steps directly. Instead, they build some intermediate representation of the pipeline that is then used to create and run the pipeline and its steps on the target environment. For such orchestrators this method will have to build this representation and deploy it.
Regardless of the implementation details, the orchestrator needs to run each step in the target environment. For this, the self.run_step(...) method should be used.
The easiest way to make this work is to use an entrypoint configuration to run single steps (zenml.entrypoints.step_entrypoint_configuration.StepEntrypointConfiguration) or entire pipelines (zenml.entrypoints.pipeline_entrypoint_configuration.PipelineEntrypointConfiguration).
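In the advanced case, the "intermediate representation" is whatever format the target backend consumes. The following sketch assumes a made-up JSON task format and turns the step graph into one task per step. Each task would invoke a single-step entrypoint inside the target environment. The command ["python", "-m", "my_step_entrypoint", ...] and the image name are hypothetical placeholders; ZenML's entrypoint configuration classes are not used here.

```python
import json
from typing import Any, Dict, List


def build_task_spec(
    pipeline_name: str,
    upstream_steps: Dict[str, List[str]],
    image: str,
) -> str:
    """Build a JSON task graph that a (hypothetical) backend could deploy."""
    tasks: List[Dict[str, Any]] = []
    for step_name, upstream in upstream_steps.items():
        tasks.append(
            {
                "name": step_name,
                "image": image,
                # Placeholder command: in ZenML this role is played by an
                # entrypoint configuration that runs exactly one step.
                "command": [
                    "python", "-m", "my_step_entrypoint",
                    "--step_name", step_name,
                ],
                # The backend, not this process, enforces step ordering.
                "depends_on": sorted(upstream),
            }
        )
    return json.dumps({"pipeline": pipeline_name, "tasks": tasks}, indent=2)


spec = json.loads(
    build_task_spec(
        "training",
        {"load": [], "train": ["load"]},
        image="registry.example.com/training:latest",
    )
)
```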
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeploymentResponse | The pipeline deployment to prepare or run. | required |
stack | Stack | The stack the pipeline will run on. | required |
environment | Dict[str, str] | Environment variables to set in the orchestration environment. These don't need to be set if running locally. | required |
Returns:
Type | Description |
---|---|
Any | The optional return value from this method will be returned by the pipeline_instance.run() call when someone is running a pipeline. |
requires_resources_in_orchestration_environment(step)
staticmethod
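Checks if the orchestrator should run this step on special resources. This is the case if the step has non-empty resource settings and is not executed with a step operator.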
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step | Step | The step that will be checked. | required |
Returns:
Type | Description |
---|---|
bool | True if the step requires special resources in the orchestration environment, False otherwise. |
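The decision rule can be reproduced with plain dataclasses. ResourceSettingsSketch and StepConfigSketch are hypothetical simplifications of a step's configuration, and the definition of "empty" used here (no field set) is an assumption made for illustration.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ResourceSettingsSketch:
    """Hypothetical, simplified stand-in for a step's resource settings."""

    cpu_count: Optional[float] = None
    gpu_count: Optional[int] = None
    memory: Optional[str] = None

    @property
    def empty(self) -> bool:
        # Empty means that no resource requirement was specified at all.
        return all(
            value is None
            for value in (self.cpu_count, self.gpu_count, self.memory)
        )


@dataclass
class StepConfigSketch:
    """Hypothetical stand-in for the relevant parts of a step config."""

    step_operator: Optional[str] = None
    resource_settings: ResourceSettingsSketch = field(
        default_factory=ResourceSettingsSketch
    )


def requires_resources_in_orchestration_environment(
    config: StepConfigSketch,
) -> bool:
    # A step that runs with a step operator is executed outside the
    # orchestration environment, so its resource needs don't apply here.
    if config.step_operator:
        return False
    return not config.resource_settings.empty


gpu_step = StepConfigSketch(
    resource_settings=ResourceSettingsSketch(gpu_count=1)
)
offloaded_step = StepConfigSketch(
    step_operator="my_step_operator",
    resource_settings=ResourceSettingsSketch(gpu_count=1),
)
plain_step = StepConfigSketch()
```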
run(self, deployment, stack)
Runs a pipeline on a stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeploymentResponse | The pipeline deployment. | required |
stack | Stack | The stack on which to run the pipeline. | required |
Returns:
Type | Description |
---|---|
Any | Orchestrator-specific return value. |
run_step(self, step)
Runs the given step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step | Step | The step to run. | required |
BaseOrchestratorConfig (StackComponentConfig)
Base orchestrator config.
Source code in zenml/orchestrators/base_orchestrator.py
class BaseOrchestratorConfig(StackComponentConfig):
    """Base orchestrator config."""

    @model_validator(mode="before")
    @classmethod
    @before_validator_handler
    def _deprecations(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate and/or remove deprecated fields.

        Args:
            data: The values to validate.

        Returns:
            The validated values.
        """
        if "custom_docker_base_image_name" in data:
            image_name = data.pop("custom_docker_base_image_name", None)
            if image_name:
                logger.warning(
                    "The 'custom_docker_base_image_name' field has been "
                    "deprecated. To use a custom base container image with your "
                    "orchestrators, please use the DockerSettings in your "
                    "pipeline (see https://docs.zenml.io/how-to/customize-docker-builds)."
                )

        return data

    @property
    def is_synchronous(self) -> bool:
        """Whether the orchestrator runs synchronously or not.

        Returns:
            Whether the orchestrator runs synchronously or not.
        """
        return False

    @property
    def is_schedulable(self) -> bool:
        """Whether the orchestrator is schedulable or not.

        Returns:
            Whether the orchestrator is schedulable or not.
        """
        return False
is_schedulable: bool
property
readonly
Whether the orchestrator is schedulable or not.
Returns:
Type | Description |
---|---|
bool | Whether the orchestrator is schedulable or not. |
is_synchronous: bool
property
readonly
Whether the orchestrator runs synchronously or not.
Returns:
Type | Description |
---|---|
bool | Whether the orchestrator runs synchronously or not. |
BaseOrchestratorFlavor (Flavor)
Base orchestrator flavor class.
Source code in zenml/orchestrators/base_orchestrator.py
class BaseOrchestratorFlavor(Flavor):
    """Base orchestrator flavor class."""

    @property
    def type(self) -> StackComponentType:
        """Returns the flavor type.

        Returns:
            The flavor type.
        """
        return StackComponentType.ORCHESTRATOR

    @property
    def config_class(self) -> Type[BaseOrchestratorConfig]:
        """Config class for the base orchestrator flavor.

        Returns:
            The config class.
        """
        return BaseOrchestratorConfig

    @property
    @abstractmethod
    def implementation_class(self) -> Type["BaseOrchestrator"]:
        """Implementation class for this flavor.

        Returns:
            The implementation class.
        """
config_class: Type[zenml.orchestrators.base_orchestrator.BaseOrchestratorConfig]
property
readonly
Config class for the base orchestrator flavor.
Returns:
Type | Description |
---|---|
Type[zenml.orchestrators.base_orchestrator.BaseOrchestratorConfig] | The config class. |
implementation_class: Type[BaseOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[BaseOrchestrator] | The implementation class. |
type: StackComponentType
property
readonly
Returns the flavor type.
Returns:
Type | Description |
---|---|
StackComponentType | The flavor type. |
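A flavor ties a config class to an implementation class. Configs can opt into capabilities such as is_synchronous and is_schedulable, which default to False in BaseOrchestratorConfig. The sketch below mirrors that wiring using only the standard library. All class names are hypothetical, and ZenML's Flavor base class, StackComponentType, and pydantic-based configs are replaced by simple stand-ins.

```python
from abc import ABC, abstractmethod
from typing import Type


class OrchestratorConfigSketch:
    """Stand-in for BaseOrchestratorConfig: both capabilities default off."""

    @property
    def is_synchronous(self) -> bool:
        return False

    @property
    def is_schedulable(self) -> bool:
        return False


class OrchestratorSketch:
    """Stand-in for BaseOrchestrator; it only keeps its config here."""

    def __init__(self, config: OrchestratorConfigSketch) -> None:
        self.config = config


class OrchestratorFlavorSketch(ABC):
    """Stand-in for BaseOrchestratorFlavor."""

    @property
    def type(self) -> str:
        # BaseOrchestratorFlavor returns StackComponentType.ORCHESTRATOR.
        return "orchestrator"

    @property
    def config_class(self) -> Type[OrchestratorConfigSketch]:
        return OrchestratorConfigSketch

    @property
    @abstractmethod
    def implementation_class(self) -> Type[OrchestratorSketch]:
        """Every concrete flavor must name its orchestrator class."""


class BlockingConfig(OrchestratorConfigSketch):
    @property
    def is_synchronous(self) -> bool:
        # Hypothetical: this orchestrator's run call blocks until the end.
        return True


class BlockingOrchestrator(OrchestratorSketch):
    pass


class BlockingFlavor(OrchestratorFlavorSketch):
    @property
    def config_class(self) -> Type[OrchestratorConfigSketch]:
        return BlockingConfig

    @property
    def implementation_class(self) -> Type[OrchestratorSketch]:
        return BlockingOrchestrator


flavor = BlockingFlavor()
blocking_orchestrator = flavor.implementation_class(flavor.config_class())
```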
cache_utils
Utilities for caching.
generate_cache_key(step, input_artifact_ids, artifact_store, workspace_id)
Generates a cache key for a step run.
If two step runs have the same cache key, they are considered identical, so the later one can reuse the cached results of the earlier one.
The cache key is an MD5 hash of:
- the workspace ID,
- the artifact store ID and path, plus the artifact store's custom cache key if it defines one,
- the import path of the step source (not the full source code, so that unrelated code changes don't invalidate the cache),
- the parameters of the step,
- the names and IDs of the input artifacts of the step,
- the names of the output artifacts of the step and the import paths of their materializers,
- additional custom caching parameters of the step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step | Step | The step to generate the cache key for. | required |
input_artifact_ids | Dict[str, UUID] | The input artifact IDs for the step. | required |
artifact_store | BaseArtifactStore | The artifact store of the active stack. | required |
workspace_id | UUID | The ID of the active workspace. | required |
Returns:
Type | Description |
---|---|
str | A cache key. |
Source code in zenml/orchestrators/cache_utils.py
def generate_cache_key(
    step: "Step",
    input_artifact_ids: Dict[str, "UUID"],
    artifact_store: "BaseArtifactStore",
    workspace_id: "UUID",
) -> str:
    """Generates a cache key for a step run.

    If the cache key is the same for two step runs, we conclude that the step
    runs are identical and can be cached.

    The cache key is an MD5 hash of:
    - the workspace ID,
    - the artifact store ID and path,
    - the source code that defines the step,
    - the parameters of the step,
    - the names and IDs of the input artifacts of the step,
    - the names and source codes of the output artifacts of the step,
    - the source codes of the output materializers of the step,
    - additional custom caching parameters of the step.

    Args:
        step: The step to generate the cache key for.
        input_artifact_ids: The input artifact IDs for the step.
        artifact_store: The artifact store of the active stack.
        workspace_id: The ID of the active workspace.

    Returns:
        A cache key.
    """
    hash_ = hashlib.md5()  # nosec

    # Workspace ID
    hash_.update(workspace_id.bytes)

    # Artifact store ID and path
    hash_.update(artifact_store.id.bytes)
    hash_.update(artifact_store.path.encode())

    if artifact_store.custom_cache_key:
        hash_.update(artifact_store.custom_cache_key)

    # Step source. This currently only uses the string representation of the
    # source (e.g. my_module.step_class) instead of the full source to keep
    # the caching behavior of previous versions and to not invalidate caching
    # when committing some unrelated files
    hash_.update(step.spec.source.import_path.encode())

    # Step parameters
    for key, value in sorted(step.config.parameters.items()):
        hash_.update(key.encode())
        hash_.update(str(value).encode())

    # Input artifacts
    for name, artifact_version_id in input_artifact_ids.items():
        hash_.update(name.encode())
        hash_.update(artifact_version_id.bytes)

    # Output artifacts and materializers
    for name, output in step.config.outputs.items():
        hash_.update(name.encode())
        for source in output.materializer_source:
            hash_.update(source.import_path.encode())

    # Custom caching parameters
    for key, value in sorted(step.config.caching_parameters.items()):
        hash_.update(key.encode())
        hash_.update(str(value).encode())

    return hash_.hexdigest()
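The hashing scheme can be reproduced on plain values to see which changes produce a new key. This sketch keeps the same ingredients as generate_cache_key but takes them as simple arguments instead of ZenML's Step and BaseArtifactStore objects. All argument names and the example module paths are hypothetical. The sketch omits the optional artifact-store custom cache key. Unlike the listing above, it also sorts the input artifacts, so the key cannot depend on dictionary insertion order.

```python
import hashlib
from typing import Any, Dict, List
from uuid import UUID


def sketch_cache_key(
    workspace_id: UUID,
    artifact_store_id: UUID,
    artifact_store_path: str,
    step_import_path: str,
    parameters: Dict[str, Any],
    input_artifact_ids: Dict[str, UUID],
    output_materializers: Dict[str, List[str]],
    caching_parameters: Dict[str, Any],
) -> str:
    """Hash the same ingredients as generate_cache_key, from plain values."""
    hash_ = hashlib.md5()  # nosec: a cache key, not a security feature
    hash_.update(workspace_id.bytes)
    hash_.update(artifact_store_id.bytes)
    hash_.update(artifact_store_path.encode())
    hash_.update(step_import_path.encode())
    # Sorted iteration keeps the key independent of dict insertion order.
    for key, value in sorted(parameters.items()):
        hash_.update(key.encode())
        hash_.update(str(value).encode())
    for name, artifact_id in sorted(input_artifact_ids.items()):
        hash_.update(name.encode())
        hash_.update(artifact_id.bytes)
    for name, materializers in output_materializers.items():
        hash_.update(name.encode())
        for import_path in materializers:
            hash_.update(import_path.encode())
    for key, value in sorted(caching_parameters.items()):
        hash_.update(key.encode())
        hash_.update(str(value).encode())
    return hash_.hexdigest()


shared: Dict[str, Any] = dict(
    workspace_id=UUID(int=1),
    artifact_store_id=UUID(int=2),
    artifact_store_path="/tmp/artifacts",
    step_import_path="my_module.train",  # hypothetical step
    input_artifact_ids={"data": UUID(int=3)},
    output_materializers={"model": ["my_module.ModelMaterializer"]},
    caching_parameters={},
)
key_a = sketch_cache_key(parameters={"lr": 0.1}, **shared)
key_b = sketch_cache_key(parameters={"lr": 0.1}, **shared)
key_c = sketch_cache_key(parameters={"lr": 0.2}, **shared)
```

Identical inputs give identical keys, and changing a single parameter value changes the key. This is why a parameter change forces a step to re-run.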
get_cached_step_run(cache_key)
If a given step can be cached, get the corresponding existing step run.
A step run can be cached if there is an existing step run in the same workspace which has the same cache key and was successfully executed. If several such step runs exist, the most recently created one is returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cache_key | str | The cache key of the step. | required |
Returns:
Type | Description |
---|---|
Optional[StepRunResponse] | The existing step run if the step can be cached, otherwise None. |
Source code in zenml/orchestrators/cache_utils.py
def get_cached_step_run(cache_key: str) -> Optional["StepRunResponse"]:
    """If a given step can be cached, get the corresponding existing step run.

    A step run can be cached if there is an existing step run in the same
    workspace which has the same cache key and was successfully executed.

    Args:
        cache_key: The cache key of the step.

    Returns:
        The existing step run if the step can be cached, otherwise None.
    """
    client = Client()

    cache_candidates = client.list_run_steps(
        workspace_id=client.active_workspace.id,
        cache_key=cache_key,
        status=ExecutionStatus.COMPLETED,
        sort_by=f"{SorterOps.DESCENDING}:created",
        size=1,
    ).items

    if cache_candidates:
        return cache_candidates[0]
    return None
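The lookup semantics (same workspace, same cache key, completed status, newest first, at most one result) can be mimicked against an in-memory list. StepRunRecord and the sample records are hypothetical. ZenML itself performs this query through Client().list_run_steps(...), as in the listing above.

```python
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional


@dataclass
class StepRunRecord:
    """Hypothetical, minimal stand-in for a stored step run."""

    name: str
    workspace_id: str
    cache_key: str
    status: str
    created: datetime


def find_cached_step_run(
    records: List[StepRunRecord], workspace_id: str, cache_key: str
) -> Optional[StepRunRecord]:
    # Same filters as the list_run_steps(...) call: workspace, cache key
    # and successful completion.
    candidates = [
        record
        for record in records
        if record.workspace_id == workspace_id
        and record.cache_key == cache_key
        and record.status == "completed"
    ]
    # Newest first; the descending sort on "created" with size=1 keeps
    # only the most recent match.
    candidates.sort(key=lambda record: record.created, reverse=True)
    return candidates[0] if candidates else None


records = [
    StepRunRecord("old", "ws", "abc", "completed", datetime(2024, 1, 1)),
    StepRunRecord("new", "ws", "abc", "completed", datetime(2024, 2, 1)),
    StepRunRecord("crashed", "ws", "abc", "failed", datetime(2024, 3, 1)),
    StepRunRecord(
        "elsewhere", "other", "abc", "completed", datetime(2024, 4, 1)
    ),
]
hit = find_cached_step_run(records, workspace_id="ws", cache_key="abc")
miss = find_cached_step_run(records, workspace_id="ws", cache_key="xyz")
```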
containerized_orchestrator
Containerized orchestrator class.
ContainerizedOrchestrator (BaseOrchestrator, ABC)
Base class for containerized orchestrators.
Source code in zenml/orchestrators/containerized_orchestrator.py
class ContainerizedOrchestrator(BaseOrchestrator, ABC):
    """Base class for containerized orchestrators."""

    @staticmethod
    def get_image(
        deployment: "PipelineDeploymentResponse",
        step_name: Optional[str] = None,
    ) -> str:
        """Gets the Docker image for the pipeline/a step.

        Args:
            deployment: The deployment from which to get the image.
            step_name: Pipeline step name for which to get the image. If not
                given the generic pipeline image will be returned.

        Raises:
            RuntimeError: If the deployment does not have an associated build.

        Returns:
            The image name or digest.
        """
        if not deployment.build:
            raise RuntimeError(
                f"Missing build for deployment {deployment.id}. This is "
                "probably because the build was manually deleted."
            )

        return deployment.build.get_image(
            component_key=ORCHESTRATOR_DOCKER_IMAGE_KEY, step=step_name
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBase"
    ) -> List["BuildConfiguration"]:
        """Gets the Docker builds required for the component.

        Args:
            deployment: The pipeline deployment for which to get the builds.

        Returns:
            The required Docker builds.
        """
        pipeline_settings = deployment.pipeline_configuration.docker_settings

        included_pipeline_build = False
        builds = []

        for name, step in deployment.step_configurations.items():
            step_settings = step.config.docker_settings

            if step_settings != pipeline_settings:
                build = BuildConfiguration(
                    key=ORCHESTRATOR_DOCKER_IMAGE_KEY,
                    settings=step_settings,
                    step_name=name,
                )
                builds.append(build)
            elif not included_pipeline_build:
                pipeline_build = BuildConfiguration(
                    key=ORCHESTRATOR_DOCKER_IMAGE_KEY,
                    settings=pipeline_settings,
                )
                builds.append(pipeline_build)
                included_pipeline_build = True

        return builds
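The build-selection rule in get_docker_builds gives one build to each step whose Docker settings differ from the pipeline-level settings, plus at most one shared pipeline build. This rule, together with the per-step image lookup in get_image, can be sketched with plain dictionaries. Settings are represented as hypothetical dicts and the image names are placeholders. The fallback from a step-specific image to the pipeline image is an assumption about how the build object resolves images. ZenML itself uses DockerSettings, BuildConfiguration and the deployment's build object.

```python
from typing import Any, Dict, List, Optional, Tuple

Settings = Dict[str, Any]


def plan_builds(
    pipeline_settings: Settings, step_settings: Dict[str, Settings]
) -> List[Tuple[Optional[str], Settings]]:
    """Return (step_name, settings) pairs; None marks the shared build."""
    builds: List[Tuple[Optional[str], Settings]] = []
    included_pipeline_build = False
    for name, settings in step_settings.items():
        if settings != pipeline_settings:
            # Steps with their own Docker settings get a dedicated build.
            builds.append((name, settings))
        elif not included_pipeline_build:
            # All other steps share a single pipeline-wide build.
            builds.append((None, pipeline_settings))
            included_pipeline_build = True
    return builds


def lookup_image(
    images: Dict[Optional[str], str], step_name: Optional[str] = None
) -> str:
    """Prefer a step-specific image, else fall back to the pipeline image."""
    if step_name is not None and step_name in images:
        return images[step_name]
    return images[None]


pipeline_settings = {"requirements": ["numpy"]}
builds = plan_builds(
    pipeline_settings,
    {
        "load": pipeline_settings,
        "train": {"requirements": ["numpy", "torch"]},
        "evaluate": pipeline_settings,
    },
)
images = {
    None: "registry.example.com/pipeline:1",
    "train": "registry.example.com/train:1",
}
```

Because the shared build is only added when at least one step actually uses the pipeline-level settings, a pipeline in which every step customizes its settings produces no unused pipeline image.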
get_docker_builds(self, deployment)
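Gets the Docker builds required for the component. Each step whose Docker settings differ from the pipeline-level settings gets its own build. All remaining steps share a single pipeline build, which is only included if at least one step uses the pipeline-level settings.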
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeploymentBase | The pipeline deployment for which to get the builds. | required |
Returns:
Type | Description |
---|---|
List[BuildConfiguration] | The required Docker builds. |
get_image(deployment, step_name=None)
staticmethod
Gets the Docker image for the pipeline/a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeploymentResponse | The deployment from which to get the image. | required |
step_name | Optional[str] | Pipeline step name for which to get the image. If not given, the generic pipeline image will be returned. | None |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the deployment does not have an associated build. |
Returns:
Type | Description |
---|---|
str | The image name or digest. |
dag_runner
DAG (Directed Acyclic Graph) Runners.
NodeStatus (Enum)
Status of the execution of a node.
Source code in zenml/orchestrators/dag_runner.py
class NodeStatus(Enum):
    """Status of the execution of a node."""

    WAITING = "Waiting"
    RUNNING = "Running"
    COMPLETED = "Completed"
ThreadedDagRunner
Multi-threaded DAG Runner.
This class expects a DAG of strings in adjacency list representation, as well as a custom run_fn as input, then calls run_fn(node) for each string node in the DAG.
Steps that can be executed in parallel will be started in separate threads.
Source code in zenml/orchestrators/dag_runner.py
class ThreadedDagRunner:
    """Multi-threaded DAG Runner.

    This class expects a DAG of strings in adjacency list representation, as
    well as a custom `run_fn` as input, then calls `run_fn(node)` for each
    string node in the DAG.

    Steps that can be executed in parallel will be started in separate threads.
    """

    def __init__(
        self,
        dag: Dict[str, List[str]],
        run_fn: Callable[[str], Any],
        parallel_node_startup_waiting_period: float = 0.0,
    ) -> None:
        """Define attributes and initialize all nodes in waiting state.

        Args:
            dag: Adjacency list representation of a DAG.
                E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
                `dag={1: [], 2: [1], 3: [1], 4: [2, 3]}`
            run_fn: A function `run_fn(node)` that runs a single node
            parallel_node_startup_waiting_period: Delay in seconds to wait in
                between starting parallel nodes.
        """
        self.parallel_node_startup_waiting_period = (
            parallel_node_startup_waiting_period
        )
        self.dag = dag
        self.reversed_dag = reverse_dag(dag)
        self.run_fn = run_fn
        self.nodes = dag.keys()
        self.node_states = {node: NodeStatus.WAITING for node in self.nodes}
        self._lock = threading.Lock()

    def _can_run(self, node: str) -> bool:
        """Determine whether a node is ready to be run.

        This is the case if the node has not run yet and all of its upstream
        nodes have already completed.

        Args:
            node: The node.

        Returns:
            True if the node can run else False.
        """
        # Check that node has not run yet.
        if not self.node_states[node] == NodeStatus.WAITING:
            return False

        # Check that all upstream nodes of this node have already completed.
        for upstream_node in self.dag[node]:
            if not self.node_states[upstream_node] == NodeStatus.COMPLETED:
                return False

        return True

    def _run_node(self, node: str) -> None:
        """Run a single node.

        Calls the user-defined run_fn, then calls `self._finish_node`.

        Args:
            node: The node.
        """
        self.run_fn(node)
        self._finish_node(node)

    def _run_node_in_thread(self, node: str) -> threading.Thread:
        """Run a single node in a separate thread.

        First updates the node status to running.
        Then calls self._run_node() in a new thread and returns the thread.

        Args:
            node: The node.

        Returns:
            The thread in which the node was run.
        """
        # Update node status to running.
        assert self.node_states[node] == NodeStatus.WAITING
        with self._lock:
            self.node_states[node] = NodeStatus.RUNNING

        # Run node in new thread.
        thread = threading.Thread(target=self._run_node, args=(node,))
        thread.start()
        return thread

    def _finish_node(self, node: str) -> None:
        """Finish a node run.

        First updates the node status to completed.
        Then starts all other nodes that can now be run and waits for them.

        Args:
            node: The node.
        """
        # Update node status to completed.
        assert self.node_states[node] == NodeStatus.RUNNING
        with self._lock:
            self.node_states[node] = NodeStatus.COMPLETED

        # Run downstream nodes.
        threads: List[threading.Thread] = []
        for downstream_node in self.reversed_dag[node]:
            if self._can_run(downstream_node):
                if threads and self.parallel_node_startup_waiting_period > 0:
                    time.sleep(self.parallel_node_startup_waiting_period)

                thread = self._run_node_in_thread(downstream_node)
                threads.append(thread)

        # Wait for all downstream nodes to complete.
        for thread in threads:
            thread.join()

    def run(self) -> None:
        """Call `self.run_fn` on all nodes in `self.dag`.

        The order of execution is determined using topological sort.
        Each node is run in a separate thread to enable parallelism.
        """
        # Run all nodes that can be started immediately.
        # These will, in turn, start other nodes once all of their respective
        # upstream nodes have completed.
        threads: List[threading.Thread] = []
        for node in self.nodes:
            if self._can_run(node):
                if threads and self.parallel_node_startup_waiting_period > 0:
                    time.sleep(self.parallel_node_startup_waiting_period)

                thread = self._run_node_in_thread(node)
                threads.append(thread)

        # Wait till all nodes have completed.
        for thread in threads:
            thread.join()

        # Make sure all nodes were run, otherwise print a warning.
        for node in self.nodes:
            if self.node_states[node] == NodeStatus.WAITING:
                upstream_nodes = self.dag[node]
                logger.warning(
                    f"Node `{node}` was never run, because it was still"
                    f" waiting for the following nodes: `{upstream_nodes}`."
                )
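The scheduling idea behind ThreadedDagRunner starts every node whose upstream nodes are done and lets each finishing node start its newly unblocked downstream nodes. It can be sketched in a self-contained form with hypothetical names. The sketch assumes that every node, including root nodes, is a key of dag (roots map to an empty list) and that run_fn does not raise. As a design choice, it performs the readiness check and the "running" transition under one lock. This guarantees that a node with several upstream nodes is started exactly once, even if those upstream nodes finish at the same moment.

```python
import threading
from typing import Callable, Dict, List


class MiniDagRunner:
    """Hypothetical threaded runner; dag maps node -> upstream nodes."""

    def __init__(
        self, dag: Dict[str, List[str]], run_fn: Callable[[str], None]
    ) -> None:
        self.dag = dag
        self.run_fn = run_fn
        # Downstream adjacency list, built once and never mutated later.
        self.downstream: Dict[str, List[str]] = {node: [] for node in dag}
        for node, upstream_nodes in dag.items():
            for upstream_node in upstream_nodes:
                self.downstream[upstream_node].append(node)
        self.state: Dict[str, str] = {node: "waiting" for node in dag}
        self._lock = threading.Lock()

    def _claim_if_ready(self, node: str) -> bool:
        # Check readiness and mark the node as running in one atomic step,
        # so a node is claimed by at most one finishing upstream thread.
        with self._lock:
            if self.state[node] != "waiting":
                return False
            if any(self.state[up] != "completed" for up in self.dag[node]):
                return False
            self.state[node] = "running"
            return True

    def _run_node(self, node: str) -> None:
        self.run_fn(node)
        with self._lock:
            self.state[node] = "completed"
        self._start_ready(self.downstream[node])

    def _start_ready(self, candidates: List[str]) -> None:
        threads: List[threading.Thread] = []
        for node in candidates:
            if self._claim_if_ready(node):
                thread = threading.Thread(target=self._run_node, args=(node,))
                thread.start()
                threads.append(thread)
        # Each thread returns only after the nodes it started have finished,
        # so joining here transitively waits for everything downstream.
        for thread in threads:
            thread.join()

    def run(self) -> None:
        self._start_ready(list(self.dag))


executed: List[str] = []
executed_lock = threading.Lock()


def record(node: str) -> None:
    with executed_lock:
        executed.append(node)


runner = MiniDagRunner(
    {"1": [], "2": ["1"], "3": ["1"], "4": ["2", "3"]}, record
)
runner.run()
```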
__init__(self, dag, run_fn, parallel_node_startup_waiting_period=0.0)
Define attributes and initialize all nodes in waiting state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dag | Dict[str, List[str]] | Adjacency list representation of a DAG. E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as dag={1: [], 2: [1], 3: [1], 4: [2, 3]}. Every node, including nodes without upstream nodes, must appear as a key. | required |
run_fn | Callable[[str], Any] | A function run_fn(node) that runs a single node. | required |
parallel_node_startup_waiting_period | float | Delay in seconds to wait in between starting parallel nodes. | 0.0 |
run(self)
Call self.run_fn on all nodes in self.dag.
Nodes are started as soon as all of their upstream nodes have completed, so the execution order is a topological order of the DAG. Each node is run in a separate thread to enable parallelism.
reverse_dag(dag)
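Reverse a DAG, i.e. map every node to the list of its downstream nodes. Nodes without downstream nodes map to an empty list.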
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dag | Dict[str, List[str]] | Adjacency list representation of a DAG. | required |
Returns:
Type | Description |
---|---|
Dict[str, List[str]] | Adjacency list representation of the reversed DAG. |
Source code in zenml/orchestrators/dag_runner.py
def reverse_dag(dag: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Reverse a DAG.
    Args:
        dag: Adjacency list representation of a DAG.
    Returns:
        Adjacency list representation of the reversed DAG.
    """
    reversed_dag = defaultdict(list)
    # Reverse all edges in the graph.
    for node, upstream_nodes in dag.items():
        for upstream_node in upstream_nodes:
            reversed_dag[upstream_node].append(node)
    # Add nodes without incoming edges back in.
    for node in dag:
        if node not in reversed_dag:
            reversed_dag[node] = []
    return reversed_dag
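As a worked example, this standalone copy of the `reverse_dag` logic is applied to the four-node DAG with edges (1->2), (1->3), (2->4) and (3->4). Node names are strings to match the `Dict[str, List[str]]` annotation, and node `"1"` is listed with an empty upstream list because it has no upstream nodes.

```python
from collections import defaultdict
from typing import Dict, List


def reverse_dag(dag: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Standalone copy of the logic above: map each node to its downstream nodes."""
    reversed_dag: Dict[str, List[str]] = defaultdict(list)
    # Reverse every upstream -> node edge.
    for node, upstream_nodes in dag.items():
        for upstream_node in upstream_nodes:
            reversed_dag[upstream_node].append(node)
    # Nodes that are upstream of nothing still get an empty entry.
    for node in dag:
        if node not in reversed_dag:
            reversed_dag[node] = []
    return reversed_dag


# Edges (1->2), (1->3), (2->4), (3->4), written as node -> upstream nodes.
dag = {"1": [], "2": ["1"], "3": ["1"], "4": ["2", "3"]}
print(dict(reverse_dag(dag)))
# {'1': ['2', '3'], '2': ['4'], '3': ['4'], '4': []}
```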
input_utils
Utilities for inputs.
resolve_step_inputs(step, run_id)
Resolves inputs for the current step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step | Step | The step for which to resolve the inputs. | required |
run_id | UUID | The ID of the current pipeline run. | required |
Exceptions:
Type | Description |
---|---|
InputResolutionError | If input resolution fails because of a missing step or output. |
ValueError | If an object from a model version passed into a step cannot be resolved at runtime because it is missing. |
Returns:
Type | Description |
---|---|
Tuple[Dict[str, ArtifactVersionResponse], List[uuid.UUID]] | The input artifact versions, keyed by input name, and the IDs of the parent steps of the current step. |
Source code in zenml/orchestrators/input_utils.py
def resolve_step_inputs(
    step: "Step",
    run_id: UUID,
) -> Tuple[Dict[str, "ArtifactVersionResponse"], List[UUID]]:
    """Resolves inputs for the current step.
    Args:
        step: The step for which to resolve the inputs.
        run_id: The ID of the current pipeline run.
    Raises:
        InputResolutionError: If input resolving failed due to a missing
            step or output.
        ValueError: If object from model version passed into a step cannot be
            resolved in runtime due to missing object.
    Returns:
        The IDs of the input artifact versions and the IDs of parent steps of
            the current step.
    """
    from zenml.models import ArtifactVersionResponse, RunMetadataResponse
    current_run_steps = {
        run_step.name: run_step
        for run_step in pagination_utils.depaginate(
            Client().list_run_steps, pipeline_run_id=run_id
        )
    }
    input_artifacts: Dict[str, "ArtifactVersionResponse"] = {}
    for name, input_ in step.spec.inputs.items():
        try:
            step_run = current_run_steps[input_.step_name]
        except KeyError:
            raise InputResolutionError(
                f"No step `{input_.step_name}` found in current run."
            )
        try:
            artifact = step_run.outputs[input_.output_name]
        except KeyError:
            raise InputResolutionError(
                f"No output `{input_.output_name}` found for step "
                f"`{input_.step_name}`."
            )
        input_artifacts[name] = artifact
    for (
        name,
        external_artifact,
    ) in step.config.external_input_artifacts.items():
        artifact_version_id = external_artifact.get_artifact_version_id()
        input_artifacts[name] = Client().get_artifact_version(
            artifact_version_id
        )
    for name, config_ in step.config.model_artifacts_or_metadata.items():
        issue_found = False
        try:
            if config_.metadata_name is None and config_.artifact_name:
                if artifact_ := config_.model.get_artifact(
                    config_.artifact_name, config_.artifact_version
                ):
                    input_artifacts[name] = artifact_
                else:
                    issue_found = True
            elif config_.artifact_name is None and config_.metadata_name:
                # metadata values should go directly in parameters, as primitive types
                step.config.parameters[name] = config_.model.run_metadata[
                    config_.metadata_name
                ].value
            elif config_.metadata_name and config_.artifact_name:
                # metadata values should go directly in parameters, as primitive types
                if artifact_ := config_.model.get_artifact(
                    config_.artifact_name, config_.artifact_version
                ):
                    step.config.parameters[name] = artifact_.run_metadata[
                        config_.metadata_name
                    ].value
                else:
                    issue_found = True
            else:
                issue_found = True
        except KeyError:
            issue_found = True
        if issue_found:
            raise ValueError(
                "Cannot fetch requested information from model "
                f"`{config_.model.name}` version "
                f"`{config_.model.version}` given artifact "
                f"`{config_.artifact_name}`, artifact version "
                f"`{config_.artifact_version}`, and metadata "
                f"key `{config_.metadata_name}` passed into "
                f"the step `{step.config.name}`."
            )
    for name, cll_ in step.config.client_lazy_loaders.items():
        value_ = cll_.evaluate()
        if isinstance(value_, ArtifactVersionResponse):
            input_artifacts[name] = value_
        elif isinstance(value_, RunMetadataResponse):
            step.config.parameters[name] = value_.value
        else:
            step.config.parameters[name] = value_
    parent_step_ids = [
        current_run_steps[upstream_step].id
        for upstream_step in step.spec.upstream_steps
    ]
    return input_artifacts, parent_step_ids
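The core of the first loop in `resolve_step_inputs` (look up the upstream step run by name, then its output by name, and fail with a descriptive error if either is missing) can be illustrated with plain dictionaries. `InputSpec`, `resolve_inputs` and the local `InputResolutionError` class are hypothetical stand-ins; the sketch leaves out external artifacts, model artifacts or metadata, and client lazy loaders, which the ZenML function also handles.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


class InputResolutionError(Exception):
    """Hypothetical stand-in for ZenML's InputResolutionError."""


@dataclass
class InputSpec:
    """Hypothetical: which upstream step and output feed a step input."""

    step_name: str
    output_name: str


def resolve_inputs(
    inputs: Dict[str, InputSpec],
    run_outputs: Dict[str, Dict[str, str]],
    step_ids: Dict[str, str],
    upstream_steps: List[str],
) -> Tuple[Dict[str, str], List[str]]:
    """Resolve inputs against the outputs of steps in the current run.

    `run_outputs` maps step name -> {output name -> artifact ID} and
    `step_ids` maps step name -> step run ID.
    """
    resolved: Dict[str, str] = {}
    for name, spec in inputs.items():
        if spec.step_name not in run_outputs:
            raise InputResolutionError(
                f"No step `{spec.step_name}` found in current run."
            )
        outputs = run_outputs[spec.step_name]
        if spec.output_name not in outputs:
            raise InputResolutionError(
                f"No output `{spec.output_name}` found for step "
                f"`{spec.step_name}`."
            )
        resolved[name] = outputs[spec.output_name]
    # Parent step IDs come from the declared upstream steps.
    parent_ids = [step_ids[s] for s in upstream_steps]
    return resolved, parent_ids


run_outputs = {"loader": {"output": "artifact-123"}}
step_ids = {"loader": "step-run-1"}
resolved, parents = resolve_inputs(
    {"data": InputSpec("loader", "output")}, run_outputs, step_ids, ["loader"]
)
print(resolved, parents)  # {'data': 'artifact-123'} ['step-run-1']
```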
local
special
Initialization for the local orchestrator.
local_orchestrator
Implementation of the ZenML local orchestrator.
LocalOrchestrator (BaseOrchestrator)
Orchestrator responsible for running pipelines locally.
This orchestrator does not allow for concurrent execution of steps and also does not support running on a schedule.
Source code in zenml/orchestrators/local/local_orchestrator.py
class LocalOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines locally.
    This orchestrator does not allow for concurrent execution of steps and also
    does not support running on a schedule.
    """
    _orchestrator_run_id: Optional[str] = None
    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
        environment: Dict[str, str],
    ) -> Any:
        """Iterates through all steps and executes them sequentially.
        Args:
            deployment: The pipeline deployment to prepare or run.
            stack: The stack on which the pipeline is deployed.
            environment: Environment variables to set in the orchestration
                environment.
        """
        if deployment.schedule:
            logger.warning(
                "Local Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run immediately."
            )
        self._orchestrator_run_id = str(uuid4())
        start_time = time.time()
        # Run each step
        for step_name, step in deployment.step_configurations.items():
            if self.requires_resources_in_orchestration_environment(step):
                logger.warning(
                    "Specifying step resources is not supported for the local "
                    "orchestrator, ignoring resource configuration for "
                    "step %s.",
                    step_name,
                )
            self.run_step(
                step=step,
            )
        run_duration = time.time() - start_time
        logger.info(
            "Pipeline run has finished in `%s`.",
            string_utils.get_human_readable_time(run_duration),
        )
        self._orchestrator_run_id = None
    def get_orchestrator_run_id(self) -> str:
        """Returns the active orchestrator run id.
        Raises:
            RuntimeError: If no run id exists. This happens when this method
                gets called while the orchestrator is not running a pipeline.
        Returns:
            The orchestrator run id.
        """
        if not self._orchestrator_run_id:
            raise RuntimeError("No run id set.")
        return self._orchestrator_run_id
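The sequential execution and run-id lifecycle of `LocalOrchestrator` can be sketched as follows. `MiniLocalOrchestrator` is a hypothetical class that runs plain Python callables instead of ZenML steps; it keeps the same contract that `get_orchestrator_run_id` only succeeds while a run is in progress.

```python
import logging
import time
import uuid
from typing import Callable, Dict, List, Optional

logger = logging.getLogger(__name__)


class MiniLocalOrchestrator:
    """Hypothetical sketch of a sequential, in-process orchestrator."""

    def __init__(self) -> None:
        self._orchestrator_run_id: Optional[str] = None

    def get_orchestrator_run_id(self) -> str:
        if not self._orchestrator_run_id:
            raise RuntimeError("No run id set.")
        return self._orchestrator_run_id

    def run(
        self, steps: Dict[str, Callable[[], None]], schedule: Optional[str] = None
    ) -> None:
        if schedule:
            logger.warning("Schedules are not supported; running immediately.")
        self._orchestrator_run_id = str(uuid.uuid4())
        start_time = time.time()
        try:
            # Steps run one after another, in the dict's insertion order.
            for step_name, step_fn in steps.items():
                logger.info("Running step `%s`.", step_name)
                step_fn()
        finally:
            # Reset the run id so it is only available during a run.
            self._orchestrator_run_id = None
        logger.info("Run finished in %.2fs.", time.time() - start_time)


orchestrator = MiniLocalOrchestrator()
seen_ids: List[str] = []
orchestrator.run(
    {
        "first": lambda: seen_ids.append(orchestrator.get_orchestrator_run_id()),
        "second": lambda: seen_ids.append(orchestrator.get_orchestrator_run_id()),
    }
)
print(len(set(seen_ids)))  # 1: both steps saw the same run id
```

Resetting the run id in a `finally` block is a choice made for this sketch. In the listing above, `_orchestrator_run_id` is only reset after all steps have finished without raising.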
get_orchestrator_run_id(self)
Returns the active orchestrator run id.
Exceptions:
Type | Description |
---|---|
RuntimeError | If no run id exists. This happens when this method gets called while the orchestrator is not running a pipeline. |
Returns:
Type | Description |
---|---|
str | The orchestrator run id. |
prepare_or_run_pipeline(self, deployment, stack, environment)
Iterates through all steps and executes them sequentially.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeploymentResponse | The pipeline deployment to prepare or run. | required |
stack | Stack | The stack on which the pipeline is deployed. | required |
environment | Dict[str, str] | Environment variables to set in the orchestration environment. | required |
LocalOrchestratorConfig (BaseOrchestratorConfig)
Local orchestrator config.
Source code in zenml/orchestrators/local/local_orchestrator.py
class LocalOrchestratorConfig(BaseOrchestratorConfig):
    """Local orchestrator config."""
    @property
    def is_local(self) -> bool:
        """Checks if this stack component is running locally.
        Returns:
            True if this config is for a local component, False otherwise.
        """
        return True
    @property
    def is_synchronous(self) -> bool:
        """Whether the orchestrator runs synchronous or not.
        Returns:
            Whether the orchestrator runs synchronous or not.
        """
        return True
is_local: bool
property
readonly
Checks if this stack component is running locally.
Returns:
Type | Description |
---|---|
bool | True if this config is for a local component, False otherwise. |
is_synchronous: bool
property
readonly
Whether the orchestrator runs synchronously.
Returns:
Type | Description |
---|---|
bool | Whether the orchestrator runs synchronously. |
LocalOrchestratorFlavor (BaseOrchestratorFlavor)
Class for the `LocalOrchestratorFlavor`.
Source code in zenml/orchestrators/local/local_orchestrator.py
class LocalOrchestratorFlavor(BaseOrchestratorFlavor):
    """Class for the `LocalOrchestratorFlavor`."""
    @property
    def name(self) -> str:
        """The flavor name.
        Returns:
            The flavor name.
        """
        return "local"
    @property
    def docs_url(self) -> Optional[str]:
        """A URL to point at docs explaining this flavor.
        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()
    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A URL to point at SDK docs explaining this flavor.
        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()
    @property
    def logo_url(self) -> str:
        """A URL to represent the flavor in the dashboard.
        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/local.png"
    @property
    def config_class(self) -> Type[BaseOrchestratorConfig]:
        """Config class for the base orchestrator flavor.
        Returns:
            The config class.
        """
        return LocalOrchestratorConfig
    @property
    def implementation_class(self) -> Type[LocalOrchestrator]:
        """Implementation class for this flavor.
        Returns:
            The implementation class for this flavor.
        """
        return LocalOrchestrator
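A flavor ties a name to a config class and an implementation class. The sketch below shows that relationship with hypothetical `Mini*` classes and a hypothetical `create_orchestrator` helper; how ZenML itself registers flavors and constructs stack components is not described in this section and is not reproduced here.

```python
from typing import Dict, Type


class MiniOrchestratorConfig:
    """Hypothetical config class with a read-only property, like `is_local`."""

    @property
    def is_local(self) -> bool:
        return True


class MiniOrchestrator:
    """Hypothetical orchestrator implementation that receives its config."""

    def __init__(self, config: MiniOrchestratorConfig) -> None:
        self.config = config


class MiniOrchestratorFlavor:
    """Hypothetical flavor: a name plus the config and implementation classes."""

    @property
    def name(self) -> str:
        return "mini_local"

    @property
    def config_class(self) -> Type[MiniOrchestratorConfig]:
        return MiniOrchestratorConfig

    @property
    def implementation_class(self) -> Type[MiniOrchestrator]:
        return MiniOrchestrator


# Hypothetical lookup table from flavor name to flavor object.
FLAVORS: Dict[str, MiniOrchestratorFlavor] = {
    flavor.name: flavor for flavor in [MiniOrchestratorFlavor()]
}


def create_orchestrator(flavor_name: str) -> MiniOrchestrator:
    flavor = FLAVORS[flavor_name]
    # Build the config first, then hand it to the implementation class.
    return flavor.implementation_class(flavor.config_class())


orchestrator = create_orchestrator("mini_local")
print(type(orchestrator).__name__, orchestrator.config.is_local)  # MiniOrchestrator True
```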
config_class: Type[zenml.orchestrators.base_orchestrator.BaseOrchestratorConfig]
property
readonly
Config class for the base orchestrator flavor.
Returns:
Type | Description |
---|---|
Type[zenml.orchestrators.base_orchestrator.BaseOrchestratorConfig] | The config class. |
docs_url: Optional[str]
property
readonly
A URL to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] | A flavor docs URL. |
implementation_class: Type[zenml.orchestrators.local.local_orchestrator.LocalOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[zenml.orchestrators.local.local_orchestrator.LocalOrchestrator] | The implementation class for this flavor. |
logo_url: str
property
readonly
A URL to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str | The flavor logo. |
name: str
property
readonly
The flavor name.
Returns:
Type | Description |
---|---|
str | The flavor name. |
sdk_docs_url: Optional[str]
property
readonly
A URL to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] | A flavor SDK docs URL. |
local_docker
special
Initialization for the local Docker orchestrator.
local_docker_orchestrator
Implementation of the ZenML local Docker orchestrator.
LocalDockerOrchestrator (ContainerizedOrchestrator)
Orchestrator responsible for running pipelines locally using Docker.
This orchestrator does not allow for concurrent execution of steps and also does not support running on a schedule.
Source code in zenml/orchestrators/local_docker/local_docker_orchestrator.py
class LocalDockerOrchestrator(ContainerizedOrchestrator):
    """Orchestrator responsible for running pipelines locally using Docker.
    This orchestrator does not allow for concurrent execution of steps and also
    does not support running on a schedule.
    """
    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Local Docker orchestrator.
        Returns:
            The settings class.
        """
        return LocalDockerOrchestratorSettings
    @property
    def validator(self) -> Optional[StackValidator]:
        """Ensures there is an image builder in the stack.
        Returns:
            A `StackValidator` instance.
        """
        return StackValidator(
            required_components={StackComponentType.IMAGE_BUILDER}
        )
    def get_orchestrator_run_id(self) -> str:
        """Returns the active orchestrator run id.
        Raises:
            RuntimeError: If the environment variable specifying the run id
                is not set.
        Returns:
            The orchestrator run id.
        """
        try:
            return os.environ[ENV_ZENML_DOCKER_ORCHESTRATOR_RUN_ID]
        except KeyError:
            raise RuntimeError(
                "Unable to read run id from environment variable "
                f"{ENV_ZENML_DOCKER_ORCHESTRATOR_RUN_ID}."
            )
    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
        environment: Dict[str, str],
    ) -> Any:
        """Sequentially runs all pipeline steps in local Docker containers.
        Args:
            deployment: The pipeline deployment to prepare or run.
            stack: The stack the pipeline will run on.
            environment: Environment variables to set in the orchestration
                environment.
        Raises:
            RuntimeError: If a step fails.
        """
        if deployment.schedule:
            logger.warning(
                "Local Docker Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run immediately."
            )
        docker_client = docker_utils._try_get_docker_client_from_env()
        entrypoint = StepEntrypointConfiguration.get_entrypoint_command()
        # Add the local stores path as a volume mount
        stack.check_local_paths()
        local_stores_path = GlobalConfiguration().local_stores_path
        volumes = {
            local_stores_path: {
                "bind": local_stores_path,
                "mode": "rw",
            }
        }
        orchestrator_run_id = str(uuid4())
        environment[ENV_ZENML_DOCKER_ORCHESTRATOR_RUN_ID] = orchestrator_run_id
        environment[ENV_ZENML_LOCAL_STORES_PATH] = local_stores_path
        start_time = time.time()
        # Run each step
        for step_name, step in deployment.step_configurations.items():
            if self.requires_resources_in_orchestration_environment(step):
                logger.warning(
                    "Specifying step resources is not supported for the local "
                    "Docker orchestrator, ignoring resource configuration for "
                    "step %s.",
                    step_name,
                )
            arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
            settings = cast(
                LocalDockerOrchestratorSettings,
                self.get_settings(step),
            )
            image = self.get_image(deployment=deployment, step_name=step_name)
            user = None
            if sys.platform != "win32":
                user = os.getuid()
            logger.info("Running step `%s` in Docker:", step_name)
            run_args = copy.deepcopy(settings.run_args)
            docker_environment = run_args.pop("environment", {})
            docker_environment.update(environment)
            docker_volumes = run_args.pop("volumes", {})
            docker_volumes.update(volumes)
            extra_hosts = run_args.pop("extra_hosts", {})
            extra_hosts["host.docker.internal"] = "host-gateway"
            try:
                logs = docker_client.containers.run(
                    image=image,
                    entrypoint=entrypoint,
                    command=arguments,
                    user=user,
                    volumes=docker_volumes,
                    environment=docker_environment,
                    stream=True,
                    extra_hosts=extra_hosts,
                    **run_args,
                )
                for line in logs:
                    logger.info(line.strip().decode())
            except ContainerError as e:
                error_message = e.stderr.decode()
                raise RuntimeError(error_message)
        run_duration = time.time() - start_time
        logger.info(
            "Pipeline run has finished in `%s`.",
            string_utils.get_human_readable_time(run_duration),
        )
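Before each `containers.run` call, the orchestrator merges the user-provided `run_args` with its own environment variables, volume mounts and extra hosts. The hypothetical `build_run_kwargs` helper below reproduces only that merging logic so it can be run without Docker; the `mem_limit` value is just an illustrative user setting.

```python
import copy
from typing import Any, Dict


def build_run_kwargs(
    run_args: Dict[str, Any],
    environment: Dict[str, str],
    volumes: Dict[str, Dict[str, str]],
) -> Dict[str, Any]:
    """Merge user `run_args` with orchestrator-provided values (sketch)."""
    # Work on a copy so the caller's settings are never mutated.
    run_args = copy.deepcopy(run_args)
    docker_environment = run_args.pop("environment", {})
    docker_environment.update(environment)  # orchestrator values win
    docker_volumes = run_args.pop("volumes", {})
    docker_volumes.update(volumes)
    extra_hosts = run_args.pop("extra_hosts", {})
    extra_hosts["host.docker.internal"] = "host-gateway"
    return {
        "environment": docker_environment,
        "volumes": docker_volumes,
        "extra_hosts": extra_hosts,
        **run_args,  # remaining user arguments are passed through
    }


user_args = {"environment": {"A": "user", "B": "user"}, "mem_limit": "2g"}
kwargs = build_run_kwargs(
    user_args, {"B": "zenml"}, {"/stores": {"bind": "/stores", "mode": "rw"}}
)
print(kwargs["environment"])  # {'A': 'user', 'B': 'zenml'}
print(user_args["environment"])  # {'A': 'user', 'B': 'user'}
```

Because the orchestrator's values are applied with `update` after the user's values, they take precedence on key clashes, and the `deepcopy` keeps the settings object itself unchanged between steps.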
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Local Docker orchestrator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] | The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Ensures there is an image builder in the stack.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] | A `StackValidator` instance. |
get_orchestrator_run_id(self)
Returns the active orchestrator run id.
Exceptions:
Type | Description |
---|---|
RuntimeError | If the environment variable specifying the run id is not set. |
Returns:
Type | Description |
---|---|
str | The orchestrator run id. |
prepare_or_run_pipeline(self, deployment, stack, environment)
Sequentially runs all pipeline steps in local Docker containers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | PipelineDeploymentResponse | The pipeline deployment to prepare or run. | required |
stack | Stack | The stack the pipeline will run on. | required |
environment | Dict[str, str] | Environment variables to set in the orchestration environment. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If a step fails. |
LocalDockerOrchestratorConfig (BaseOrchestratorConfig, LocalDockerOrchestratorSettings)
Local Docker orchestrator config.
Source code in zenml/orchestrators/local_docker/local_docker_orchestrator.py
class LocalDockerOrchestratorConfig(
    BaseOrchestratorConfig, LocalDockerOrchestratorSettings
):
    """Local Docker orchestrator config."""
    @property
    def is_local(self) -> bool:
        """Checks if this stack component is running locally.
        Returns:
            True if this config is for a local component, False otherwise.
        """
        return True
    @property
    def is_synchronous(self) -> bool:
        """Whether the orchestrator runs synchronous or not.
        Returns:
            Whether the orchestrator runs synchronous or not.
        """
        return True
is_local: bool
property
readonly
Checks if this stack component is running locally.
Returns:
Type | Description |
---|---|
bool | True if this config is for a local component, False otherwise. |
is_synchronous: bool
property
readonly
Whether the orchestrator runs synchronously.
Returns:
Type | Description |
---|---|
bool | Whether the orchestrator runs synchronously. |
LocalDockerOrchestratorFlavor (BaseOrchestratorFlavor)
Flavor for the local Docker orchestrator.
Source code in zenml/orchestrators/local_docker/local_docker_orchestrator.py
class LocalDockerOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor for the local Docker orchestrator."""
    @property
    def name(self) -> str:
        """Name of the orchestrator flavor.
        Returns:
            Name of the orchestrator flavor.
        """
        return "local_docker"
    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.
        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()
    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.
        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()
    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.
        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/docker.png"
    @property
    def config_class(self) -> Type[BaseOrchestratorConfig]:
        """Config class for the base orchestrator flavor.
        Returns:
            The config class.
        """
        return LocalDockerOrchestratorConfig
    @property
    def implementation_class(self) -> Type["LocalDockerOrchestrator"]:
        """Implementation class for this flavor.
        Returns:
            Implementation class for this flavor.
        """
        return LocalDockerOrchestrator
config_class: Type[zenml.orchestrators.base_orchestrator.BaseOrchestratorConfig]
property
readonly
Config class for the base orchestrator flavor.
Returns:
Type | Description |
---|---|
Type[zenml.orchestrators.base_orchestrator.BaseOrchestratorConfig] | The config class. |
docs_url: Optional[str]
property
readonly
A URL to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] | A flavor docs URL. |
implementation_class: Type[LocalDockerOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[LocalDockerOrchestrator] | Implementation class for this flavor. |
logo_url: str
property
readonly
A URL to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str | The flavor logo. |
name: str
property
readonly
Name of the orchestrator flavor.
Returns:
Type | Description |
---|---|
str | Name of the orchestrator flavor. |
sdk_docs_url: Optional[str]
property
readonly
A URL to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] | A flavor SDK docs URL. |
LocalDockerOrchestratorSettings (BaseSettings)
Local Docker orchestrator settings.
Attributes:
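Name | Type | Description |
---|---|---|
run_args | Dict[str, Any] | Arguments to pass to the `docker run` call. (See https://docker-py.readthedocs.io/en/stable/containers.html for a list of what can be passed.) |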
Source code in zenml/orchestrators/local_docker/local_docker_orchestrator.py
class LocalDockerOrchestratorSettings(BaseSettings):
    """Local Docker orchestrator settings.
    Attributes:
        run_args: Arguments to pass to the `docker run` call. (See
            https://docker-py.readthedocs.io/en/stable/containers.html for a list
            of what can be passed.)
    """
    run_args: Dict[str, Any] = {}
output_utils
Utilities for outputs.
generate_artifact_uri(artifact_store, step_run, output_name)
Generates a URI for an output artifact.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_store | BaseArtifactStore | The artifact store on which the artifact will be stored. | required |
step_run | StepRunResponse | The step run that created the artifact. | required |
output_name | str | The name of the output in the step run for this artifact. | required |
Returns:
Type | Description |
---|---|
str | The URI of the output artifact. |
Source code in zenml/orchestrators/output_utils.py
def generate_artifact_uri(
    artifact_store: "BaseArtifactStore",
    step_run: "StepRunResponse",
    output_name: str,
) -> str:
    """Generates a URI for an output artifact.
    Args:
        artifact_store: The artifact store on which the artifact will be stored.
        step_run: The step run that created the artifact.
        output_name: The name of the output in the step run for this artifact.
    Returns:
        The URI of the output artifact.
    """
    for banned_character in ["<", ">", ":", '"', "/", "\\", "|", "?", "*"]:
        output_name = output_name.replace(banned_character, "_")
    return os.path.join(
        artifact_store.path,
        step_run.name,
        output_name,
        str(step_run.id),
        str(uuid4())[:8],  # add random subfolder to avoid collisions
    )
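The URI layout produced by `generate_artifact_uri` is `<artifact store path>/<step run name>/<sanitized output name>/<step run ID>/<8 random characters>`. The standalone `make_artifact_uri` function below is a hypothetical re-implementation that takes plain strings instead of a `BaseArtifactStore` and a `StepRunResponse`.

```python
import os
import uuid

# Characters replaced by "_" in the output name, as in the listing above.
BANNED_CHARACTERS = ["<", ">", ":", '"', "/", "\\", "|", "?", "*"]


def make_artifact_uri(
    store_path: str, step_name: str, step_run_id: str, output_name: str
) -> str:
    """Hypothetical sketch of the URI layout used by `generate_artifact_uri`."""
    for banned_character in BANNED_CHARACTERS:
        output_name = output_name.replace(banned_character, "_")
    return os.path.join(
        store_path,
        step_name,
        output_name,
        step_run_id,
        str(uuid.uuid4())[:8],  # random subfolder against collisions
    )


uri = make_artifact_uri("/tmp/store", "trainer", "1234", "model:v1/final")
print(uri)  # /tmp/store/trainer/model_v1_final/1234/ followed by 8 random hex characters (POSIX)
```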
prepare_output_artifact_uris(step_run, stack, step)
Prepares the output artifact URIs to run the current step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_run | StepRunResponse | The step run for which to prepare the artifact URIs. | required |
stack | Stack | The stack on which the pipeline is running. | required |
step | Step | The step configuration. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If an artifact URI already exists. |
Returns:
Type | Description |
---|---|
Dict[str, str] | A dictionary mapping output names to artifact URIs. |
Source code in zenml/orchestrators/output_utils.py
def prepare_output_artifact_uris(
    step_run: "StepRunResponse", stack: "Stack", step: "Step"
) -> Dict[str, str]:
    """Prepares the output artifact URIs to run the current step.
    Args:
        step_run: The step run for which to prepare the artifact URIs.
        stack: The stack on which the pipeline is running.
        step: The step configuration.
    Raises:
        RuntimeError: If an artifact URI already exists.
    Returns:
        A dictionary mapping output names to artifact URIs.
    """
    artifact_store = stack.artifact_store
    output_artifact_uris: Dict[str, str] = {}
    for output_name in step.config.outputs.keys():
        artifact_uri = generate_artifact_uri(
            artifact_store=stack.artifact_store,
            step_run=step_run,
            output_name=output_name,
        )
        if artifact_store.exists(artifact_uri):
            raise RuntimeError("Artifact already exists")
        artifact_store.makedirs(artifact_uri)
        output_artifact_uris[output_name] = artifact_uri
    return output_artifact_uris
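`prepare_output_artifact_uris` refuses to reuse an existing URI, and `remove_artifact_dirs` deletes directories recursively. The sketch below mimics both with the local filesystem standing in for the artifact store's `exists`, `makedirs`, `isdir` and `rmtree` methods; `prepare_uris`, `remove_dirs` and the fixed `run_suffix` are hypothetical. The fixed suffix is only there to force a collision, which the random subfolder in `generate_artifact_uri` makes unlikely in practice.

```python
import os
import shutil
import tempfile
from typing import Dict, Iterable, List


def prepare_uris(base_dir: str, output_names: List[str], run_suffix: str) -> Dict[str, str]:
    """Create one fresh directory per output; fail if it already exists."""
    uris: Dict[str, str] = {}
    for output_name in output_names:
        uri = os.path.join(base_dir, output_name, run_suffix)
        if os.path.exists(uri):
            raise RuntimeError("Artifact already exists")
        os.makedirs(uri)
        uris[output_name] = uri
    return uris


def remove_dirs(uris: Iterable[str]) -> None:
    """Recursively delete each directory that exists."""
    for uri in uris:
        if os.path.isdir(uri):
            shutil.rmtree(uri)


base = tempfile.mkdtemp()
uris = prepare_uris(base, ["model", "metrics"], "run-1")
collision_detected = False
try:
    prepare_uris(base, ["model"], "run-1")  # same URI as before
except RuntimeError:
    collision_detected = True
remove_dirs(uris.values())
shutil.rmtree(base)
print(collision_detected)  # True
```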
remove_artifact_dirs(artifact_uris)
Removes the artifact directories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact_uris | Sequence[str] | URIs of the artifacts to remove the directories for. | required |
Source code in zenml/orchestrators/output_utils.py
def remove_artifact_dirs(artifact_uris: Sequence[str]) -> None:
    """Removes the artifact directories.
    Args:
        artifact_uris: URIs of the artifacts to remove the directories for.
    """
    artifact_store = Client().active_stack.artifact_store
    for artifact_uri in artifact_uris:
        if artifact_store.isdir(artifact_uri):
            artifact_store.rmtree(artifact_uri)
publish_utils
Utilities to publish pipeline and step runs.
get_pipeline_run_status(step_statuses, num_steps)
Gets the pipeline run status for the given step statuses.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_statuses |
List[zenml.enums.ExecutionStatus] |
The status of steps in this run. |
required |
num_steps |
int |
The total amount of steps in this run. |
required |
Returns:
Type | Description |
---|---|
ExecutionStatus |
The run status. |
Source code in zenml/orchestrators/publish_utils.py
def get_pipeline_run_status(
    step_statuses: List[ExecutionStatus], num_steps: int
) -> ExecutionStatus:
    """Gets the pipeline run status for the given step statuses.
    Args:
        step_statuses: The status of steps in this run.
        num_steps: The total amount of steps in this run.
    Returns:
        The run status.
    """
    if ExecutionStatus.FAILED in step_statuses:
        return ExecutionStatus.FAILED
    if (
        ExecutionStatus.RUNNING in step_statuses
        or len(step_statuses) < num_steps
    ):
        return ExecutionStatus.RUNNING
    return ExecutionStatus.COMPLETED
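`get_pipeline_run_status` gives failures precedence over running steps, and treats a run as still running while fewer step statuses than steps have been recorded. The sketch below reproduces that decision order with a hypothetical `Status` enum standing in for `zenml.enums.ExecutionStatus`.

```python
from enum import Enum
from typing import List


class Status(Enum):
    """Hypothetical stand-in for zenml.enums.ExecutionStatus."""

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


def pipeline_status(step_statuses: List[Status], num_steps: int) -> Status:
    # Any failed step fails the whole run.
    if Status.FAILED in step_statuses:
        return Status.FAILED
    # Still running if a step is running or some steps have no status yet.
    if Status.RUNNING in step_statuses or len(step_statuses) < num_steps:
        return Status.RUNNING
    return Status.COMPLETED


print(pipeline_status([Status.COMPLETED, Status.FAILED], 3))  # Status.FAILED
print(pipeline_status([Status.COMPLETED], 3))  # Status.RUNNING
print(pipeline_status([Status.COMPLETED] * 3, 3))  # Status.COMPLETED
```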
publish_failed_pipeline_run(pipeline_run_id)
Publishes a failed pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run_id | UUID | The ID of the pipeline run to update. | required |
Returns:
Type | Description |
---|---|
PipelineRunResponse | The updated pipeline run. |
Source code in zenml/orchestrators/publish_utils.py
def publish_failed_pipeline_run(
    pipeline_run_id: "UUID",
) -> "PipelineRunResponse":
    """Publishes a failed pipeline run.
    Args:
        pipeline_run_id: The ID of the pipeline run to update.
    Returns:
        The updated pipeline run.
    """
    return Client().zen_store.update_run(
        run_id=pipeline_run_id,
        run_update=PipelineRunUpdate(
            status=ExecutionStatus.FAILED,
            end_time=datetime.utcnow(),
        ),
    )
publish_failed_step_run(step_run_id)
Publishes a failed step run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_run_id | UUID | The ID of the step run to update. | required |
Returns:
Type | Description |
---|---|
StepRunResponse | The updated step run. |
Source code in zenml/orchestrators/publish_utils.py
def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
    """Publishes a failed step run.
    Args:
        step_run_id: The ID of the step run to update.
    Returns:
        The updated step run.
    """
    return Client().zen_store.update_run_step(
        step_run_id=step_run_id,
        step_run_update=StepRunUpdate(
            status=ExecutionStatus.FAILED,
            end_time=datetime.utcnow(),
        ),
    )
publish_pipeline_run_metadata(pipeline_run_id, pipeline_run_metadata)
Publishes the given pipeline run metadata.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run_id | UUID | The ID of the pipeline run. | required |
pipeline_run_metadata | Dict[UUID, Dict[str, MetadataType]] | A dictionary mapping stack component IDs to the metadata they created. | required |
Source code in zenml/orchestrators/publish_utils.py
def publish_pipeline_run_metadata(
    pipeline_run_id: "UUID",
    pipeline_run_metadata: Dict["UUID", Dict[str, "MetadataType"]],
) -> None:
    """Publishes the given pipeline run metadata.
    Args:
        pipeline_run_id: The ID of the pipeline run.
        pipeline_run_metadata: A dictionary mapping stack component IDs to the
            metadata they created.
    """
    client = Client()
    for stack_component_id, metadata in pipeline_run_metadata.items():
        client.create_run_metadata(
            metadata=metadata,
            resource_id=pipeline_run_id,
            resource_type=MetadataResourceTypes.PIPELINE_RUN,
            stack_component_id=stack_component_id,
        )
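`publish_pipeline_run_metadata` makes one `create_run_metadata` call per stack component. Each component's metadata is therefore stored separately and stays attributable to the component that produced it. `publish_step_run_metadata` follows the same pattern, using a step run ID and the step-run resource type.

The sketch below shows that fan-out as plain data. `MetadataRecord`, `build_metadata_records`, and the `"pipeline_run"` resource-type string are hypothetical illustrations, not ZenML's models.

```python
from dataclasses import dataclass
from typing import Any, Dict, List
from uuid import UUID, uuid4


@dataclass(frozen=True)
class MetadataRecord:
    resource_id: UUID
    resource_type: str  # hypothetical stand-in for MetadataResourceTypes
    stack_component_id: UUID
    values: Dict[str, Any]


def build_metadata_records(
    resource_id: UUID,
    resource_type: str,
    metadata_by_component: Dict[UUID, Dict[str, Any]],
) -> List[MetadataRecord]:
    # One record per stack component, mirroring the loop in the source above.
    return [
        MetadataRecord(resource_id, resource_type, component_id, dict(metadata))
        for component_id, metadata in metadata_by_component.items()
    ]
```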
publish_step_run_metadata(step_run_id, step_run_metadata)
Publishes the given step run metadata.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_run_id | UUID | The ID of the step run. | required |
| step_run_metadata | Dict[UUID, Dict[str, MetadataType]] | A dictionary mapping stack component IDs to the metadata they created. | required |
Source code in zenml/orchestrators/publish_utils.py
def publish_step_run_metadata(
    step_run_id: "UUID",
    step_run_metadata: Dict["UUID", Dict[str, "MetadataType"]],
) -> None:
    """Publishes the given step run metadata.
    Args:
        step_run_id: The ID of the step run.
        step_run_metadata: A dictionary mapping stack component IDs to the
            metadata they created.
    """
    client = Client()
    for stack_component_id, metadata in step_run_metadata.items():
        client.create_run_metadata(
            metadata=metadata,
            resource_id=step_run_id,
            resource_type=MetadataResourceTypes.STEP_RUN,
            stack_component_id=stack_component_id,
        )
publish_successful_step_run(step_run_id, output_artifact_ids)
Publishes a successful step run.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_run_id | UUID | The ID of the step run to update. | required |
| output_artifact_ids | Dict[str, UUID] | The output artifact IDs for the step run. | required |
Returns:
| Type | Description |
|---|---|
| StepRunResponse | The updated step run. |
Source code in zenml/orchestrators/publish_utils.py
def publish_successful_step_run(
    step_run_id: "UUID", output_artifact_ids: Dict[str, "UUID"]
) -> "StepRunResponse":
    """Publishes a successful step run.
    Args:
        step_run_id: The ID of the step run to update.
        output_artifact_ids: The output artifact IDs for the step run.
    Returns:
        The updated step run.
    """
    return Client().zen_store.update_run_step(
        step_run_id=step_run_id,
        step_run_update=StepRunUpdate(
            status=ExecutionStatus.COMPLETED,
            end_time=datetime.utcnow(),
            outputs=output_artifact_ids,
        ),
    )
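Read together with `StepLauncher.launch`, these publishers determine the terminal status of a step run:

- A failure while preparing the step stores the run as `FAILED`.
- A cache hit stores it as `CACHED` without executing it.
- A successful execution is published through `publish_successful_step_run` as `COMPLETED`.
- Exhausting all attempts publishes it through `publish_failed_step_run` as `FAILED`.

The hypothetical `final_step_status` function below condenses that decision. It is a reading aid, not a ZenML API.

```python
from enum import Enum


class ExecutionStatus(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CACHED = "cached"


def final_step_status(
    prepare_failed: bool, cache_hit: bool, execution_succeeded: bool
) -> ExecutionStatus:
    """Hypothetical summary of how StepLauncher ends a step run."""
    if prepare_failed:
        # _prepare raised: the request is stored as FAILED and re-raised.
        return ExecutionStatus.FAILED
    if cache_hit:
        # _prepare found a cached run: outputs are reused, nothing executes.
        return ExecutionStatus.CACHED
    # Otherwise the step executes (possibly with retries) and is published.
    if execution_succeeded:
        return ExecutionStatus.COMPLETED
    return ExecutionStatus.FAILED
```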
step_launcher
Class to launch (run directly or using a step operator) steps.
StepLauncher
A class responsible for launching a step of a ZenML pipeline.
This class follows these steps to launch and publish a ZenML step:
1. Publish or reuse a PipelineRun
2. Resolve the input artifacts of the step
3. Generate a cache key for the step
4. Check if the step can be cached or not
5. Publish a new StepRun
6. If the step can't be cached, the step will be executed in one of these
two ways depending on its configuration:
   - Calling a step operator to run the step in a different environment
   - Calling a step runner to run the step in the current environment
7. Update the status of the previously published StepRun
8. Update the status of the PipelineRun
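Steps 2 to 4 can be illustrated with a small, self-contained sketch. In ZenML the key comes from `cache_utils.generate_cache_key`. Judging from the call in `StepLauncher._prepare`, it takes the step, the input artifact IDs, the artifact store, and the workspace ID.

Which fields it hashes, and how, is not shown on this page, so two choices below are assumptions:

- the payload fields that go into the key;
- the use of SHA-256 as the hash.

`generate_cache_key`, `decide_execution`, and the dictionary used as a cache are hypothetical.

```python
import hashlib
import json
from typing import Dict, Optional, Tuple


def generate_cache_key(
    step_source: str,
    parameters: Dict[str, object],
    input_artifact_ids: Dict[str, str],
    artifact_store_id: str,
    workspace_id: str,
) -> str:
    payload = {
        "workspace_id": workspace_id,
        "artifact_store_id": artifact_store_id,
        "step_source": step_source,
        "parameters": parameters,
        "inputs": input_artifact_ids,
    }
    # sort_keys makes the key independent of dict insertion order.
    blob = json.dumps(payload, sort_keys=True, default=str)
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()


def decide_execution(
    cache: Dict[str, Dict[str, str]], cache_key: str, cache_enabled: bool
) -> Tuple[bool, Optional[Dict[str, str]]]:
    """Return (execution_needed, cached_outputs), like StepLauncher._prepare."""
    if cache_enabled and cache_key in cache:
        return False, cache[cache_key]
    return True, None
```

Any change to an input artifact ID produces a different key, so the step is re-executed. A disabled cache always forces execution, even when a matching key exists.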
Source code in zenml/orchestrators/step_launcher.py
class StepLauncher:
    """A class responsible for launching a step of a ZenML pipeline.
    This class follows these steps to launch and publish a ZenML step:
    1. Publish or reuse a `PipelineRun`
    2. Resolve the input artifacts of the step
    3. Generate a cache key for the step
    4. Check if the step can be cached or not
    5. Publish a new `StepRun`
    6. If the step can't be cached, the step will be executed in one of these
       two ways depending on its configuration:
        - Calling a `step operator` to run the step in a different environment
        - Calling a `step runner` to run the step in the current environment
    7. Update the status of the previously published `StepRun`
    8. Update the status of the `PipelineRun`
    """
    def __init__(
        self,
        deployment: PipelineDeploymentResponse,
        step: Step,
        orchestrator_run_id: str,
    ):
        """Initializes the launcher.
        Args:
            deployment: The pipeline deployment.
            step: The step to launch.
            orchestrator_run_id: The orchestrator pipeline run id.
        Raises:
            RuntimeError: If the deployment has no associated stack.
        """
        self._deployment = deployment
        self._step = step
        self._orchestrator_run_id = orchestrator_run_id
        if not deployment.stack:
            raise RuntimeError(
                f"Missing stack for deployment {deployment.id}. This is "
                "probably because the stack was manually deleted."
            )
        self._stack = Stack.from_model(deployment.stack)
        self._step_name = step.spec.pipeline_parameter_name
    def launch(self) -> None:
        """Launches the step.
        Raises:
            BaseException: If the step failed to launch, run, or publish.
        """
        pipeline_run, run_was_created = self._create_or_reuse_run()
        # Enable or disable step logs storage
        if handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False):
            step_logging_enabled = False
        else:
            step_logging_enabled = orchestrator_utils.is_setting_enabled(
                is_enabled_on_step=self._step.config.enable_step_logs,
                is_enabled_on_pipeline=self._deployment.pipeline_configuration.enable_step_logs,
            )
        logs_context = nullcontext()
        logs_model = None
        if step_logging_enabled:
            # Configure the logs
            logs_uri = step_logging.prepare_logs_uri(
                self._stack.artifact_store,
                self._step.config.name,
            )
            logs_context = step_logging.StepLogsStorageContext(
                logs_uri=logs_uri
            )  # type: ignore[assignment]
            logs_model = LogsRequest(
                uri=logs_uri,
                artifact_store_id=self._stack.artifact_store.id,
            )
        try:
            with logs_context:
                if run_was_created:
                    pipeline_run_metadata = (
                        self._stack.get_pipeline_run_metadata(
                            run_id=pipeline_run.id
                        )
                    )
                    publish_utils.publish_pipeline_run_metadata(
                        pipeline_run_id=pipeline_run.id,
                        pipeline_run_metadata=pipeline_run_metadata,
                    )
                client = Client()
                (
                    docstring,
                    source_code,
                ) = self._get_step_docstring_and_source_code()
                code_hash = self._deployment.step_configurations[
                    self._step_name
                ].config.caching_parameters.get(STEP_SOURCE_PARAMETER_NAME)
                step_run = StepRunRequest(
                    name=self._step_name,
                    pipeline_run_id=pipeline_run.id,
                    deployment=self._deployment.id,
                    code_hash=code_hash,
                    status=ExecutionStatus.RUNNING,
                    docstring=docstring,
                    source_code=source_code,
                    start_time=datetime.utcnow(),
                    user=client.active_user.id,
                    workspace=client.active_workspace.id,
                    logs=logs_model,
                )
                try:
                    execution_needed, step_run = self._prepare(
                        step_run=step_run
                    )
                except:
                    logger.exception(
                        f"Failed preparing run step `{self._step_name}`."
                    )
                    step_run.status = ExecutionStatus.FAILED
                    step_run.end_time = datetime.utcnow()
                    raise
                finally:
                    step_run_response = Client().zen_store.create_run_step(
                        step_run
                    )
                # warm-up and register model version
                _step_run = None
                model = (
                    self._deployment.step_configurations[
                        step_run.name
                    ].config.model
                    or self._deployment.pipeline_configuration.model
                )
                if self._deployment.step_configurations[
                    step_run.name
                ].config.model:
                    _step_run = step_run_response
                if model:
                    prep_logs_to_show = (
                        model._prepare_model_version_before_step_launch(
                            pipeline_run=pipeline_run,
                            step_run=_step_run,
                            return_logs=True,
                        )
                    )
                    if prep_logs_to_show:
                        logger.info(prep_logs_to_show)
                logger.info(f"Step `{self._step_name}` has started.")
                if execution_needed:
                    retries = 0
                    last_retry = True
                    max_retries = (
                        step_run_response.config.retry.max_retries
                        if step_run_response.config.retry
                        else 1
                    )
                    delay = (
                        step_run_response.config.retry.delay
                        if step_run_response.config.retry
                        else 0
                    )
                    backoff = (
                        step_run_response.config.retry.backoff
                        if step_run_response.config.retry
                        else 1
                    )
                    while retries < max_retries:
                        last_retry = retries == max_retries - 1
                        try:
                            # Pass a callable that force-saves the buffered
                            # logs to file, so step operators can flush them
                            # before starting their external jobs.
                            if isinstance(
                                logs_context,
                                step_logging.StepLogsStorageContext,
                            ):
                                force_write_logs = partial(
                                    logs_context.storage.save_to_file,
                                    force=True,
                                )
                            else:
                                def _bypass() -> None:
                                    return None
                                force_write_logs = _bypass
                            self._run_step(
                                pipeline_run=pipeline_run,
                                step_run=step_run_response,
                                last_retry=last_retry,
                                force_write_logs=force_write_logs,
                            )
                            logger.info(
                                f"Step `{self._step_name}` completed successfully."
                            )
                            break
                        except BaseException as e:  # noqa: E722
                            retries += 1
                            if retries < max_retries:
                                logger.error(
                                    f"Failed to run step `{self._step_name}`. Retrying..."
                                )
                                logger.exception(e)
                                logger.info(
                                    f"Sleeping for {delay} seconds before retrying."
                                )
                                time.sleep(delay)
                                delay *= backoff
                            else:
                                logger.error(
                                    f"Failed to run step `{self._step_name}` after {max_retries} retries. Exiting."
                                )
                                logger.exception(e)
                                publish_utils.publish_failed_step_run(
                                    step_run_response.id
                                )
                                raise
                else:
                    orchestrator_utils._link_cached_artifacts_to_model(
                        model_from_context=model,
                        step_run=step_run,
                        step_source=self._step.spec.source,
                    )
                if model:
                    orchestrator_utils._link_pipeline_run_to_model_from_context(
                        pipeline_run_id=step_run.pipeline_run_id,
                        model=model,
                    )
        except:  # noqa: E722
            logger.error(f"Pipeline run `{pipeline_run.name}` failed.")
            publish_utils.publish_failed_pipeline_run(pipeline_run.id)
            raise
    def _get_step_docstring_and_source_code(self) -> Tuple[Optional[str], str]:
        """Gets the docstring and source code of the step.
        If either of the two is longer than `TEXT_FIELD_MAX_LENGTH`
        characters, it is truncated.
        Returns:
            The docstring and source code of the step.
        """
        from zenml.steps.base_step import BaseStep
        step_instance = BaseStep.load_from_source(self._step.spec.source)
        docstring = step_instance.docstring
        if docstring and len(docstring) > TEXT_FIELD_MAX_LENGTH:
            docstring = docstring[: (TEXT_FIELD_MAX_LENGTH - 3)] + "..."
        source_code = step_instance.source_code
        if source_code and len(source_code) > TEXT_FIELD_MAX_LENGTH:
            source_code = source_code[: (TEXT_FIELD_MAX_LENGTH - 3)] + "..."
        return docstring, source_code
    def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]:
        """Creates a pipeline run or reuses an existing one.
        Returns:
            The created or existing pipeline run, and a boolean that is
            `True` if a new run was created.
        """
        run_name = orchestrator_utils.get_run_name(
            run_name_template=self._deployment.run_name_template
        )
        logger.debug("Creating pipeline run %s", run_name)
        client = Client()
        pipeline_run = PipelineRunRequest(
            name=run_name,
            orchestrator_run_id=self._orchestrator_run_id,
            user=client.active_user.id,
            workspace=client.active_workspace.id,
            deployment=self._deployment.id,
            pipeline=(
                self._deployment.pipeline.id
                if self._deployment.pipeline
                else None
            ),
            status=ExecutionStatus.RUNNING,
            orchestrator_environment=get_run_environment_dict(),
            start_time=datetime.utcnow(),
        )
        return client.zen_store.get_or_create_run(pipeline_run)
    def _prepare(
        self,
        step_run: StepRunRequest,
    ) -> Tuple[bool, StepRunRequest]:
        """Prepares running the step.
        Args:
            step_run: The step run request to prepare.
        Returns:
            Tuple that specifies whether the step needs to be executed as
            well as the updated step run request.
        """
        input_artifacts, parent_step_ids = input_utils.resolve_step_inputs(
            step=self._step,
            run_id=step_run.pipeline_run_id,
        )
        input_artifact_ids = {
            input_name: artifact.id
            for input_name, artifact in input_artifacts.items()
        }
        cache_key = cache_utils.generate_cache_key(
            step=self._step,
            input_artifact_ids=input_artifact_ids,
            artifact_store=self._stack.artifact_store,
            workspace_id=Client().active_workspace.id,
        )
        step_run.inputs = input_artifact_ids
        step_run.parent_step_ids = parent_step_ids
        step_run.cache_key = cache_key
        cache_enabled = orchestrator_utils.is_setting_enabled(
            is_enabled_on_step=self._step.config.enable_cache,
            is_enabled_on_pipeline=self._deployment.pipeline_configuration.enable_cache,
        )
        step_cache = self._step.config.enable_cache
        if step_cache is not None:
            logger.info(
                f"Caching {'`enabled`' if step_cache else '`disabled`'} "
                f"explicitly for `{self._step_name}`."
            )
        execution_needed = True
        if cache_enabled:
            cached_step_run = cache_utils.get_cached_step_run(
                cache_key=cache_key
            )
            if cached_step_run:
                logger.info(f"Using cached version of `{self._step_name}`.")
                execution_needed = False
                cached_outputs = cached_step_run.outputs
                step_run.original_step_run_id = cached_step_run.id
                step_run.outputs = {
                    output_name: artifact.id
                    for output_name, artifact in cached_outputs.items()
                }
                step_run.status = ExecutionStatus.CACHED
                step_run.end_time = step_run.start_time
        return execution_needed, step_run
    def _run_step(
        self,
        pipeline_run: PipelineRunResponse,
        step_run: StepRunResponse,
        force_write_logs: Callable[..., Any],
        last_retry: bool = True,
    ) -> None:
        """Runs the current step.
        Args:
            pipeline_run: The model of the current pipeline run.
            step_run: The model of the current step run.
            force_write_logs: A callable that forces the buffered step logs
                to be written to storage (a no-op when step log storage is
                disabled).
            last_retry: Whether this is the last retry of the step.
        """
        # Prepare step run information.
        step_run_info = StepRunInfo(
            config=self._step.config,
            pipeline=self._deployment.pipeline_configuration,
            run_name=pipeline_run.name,
            pipeline_step_name=self._step_name,
            run_id=pipeline_run.id,
            step_run_id=step_run.id,
            force_write_logs=force_write_logs,
        )
        output_artifact_uris = output_utils.prepare_output_artifact_uris(
            step_run=step_run, stack=self._stack, step=self._step
        )
        # Run the step.
        start_time = time.time()
        try:
            if self._step.config.step_operator:
                self._run_step_with_step_operator(
                    step_operator_name=self._step.config.step_operator,
                    step_run_info=step_run_info,
                    last_retry=last_retry,
                )
            else:
                self._run_step_without_step_operator(
                    pipeline_run=pipeline_run,
                    step_run=step_run,
                    step_run_info=step_run_info,
                    input_artifacts=step_run.inputs,
                    output_artifact_uris=output_artifact_uris,
                    last_retry=last_retry,
                )
        except:  # noqa: E722
            output_utils.remove_artifact_dirs(
                artifact_uris=list(output_artifact_uris.values())
            )
            raise
        duration = time.time() - start_time
        logger.info(
            f"Step `{self._step_name}` has finished in "
            f"`{string_utils.get_human_readable_time(duration)}`."
        )
    def _run_step_with_step_operator(
        self,
        step_operator_name: str,
        step_run_info: StepRunInfo,
        last_retry: bool,
    ) -> None:
        """Runs the current step with a step operator.
        Args:
            step_operator_name: The name of the step operator to use.
            step_run_info: Additional information needed to run the step.
            last_retry: Whether this is the last retry of the step.
        """
        step_operator = _get_step_operator(
            stack=self._stack,
            step_operator_name=step_operator_name,
        )
        entrypoint_cfg_class = step_operator.entrypoint_config_class
        entrypoint_command = (
            entrypoint_cfg_class.get_entrypoint_command()
            + entrypoint_cfg_class.get_entrypoint_arguments(
                step_name=self._step_name,
                deployment_id=self._deployment.id,
                step_run_id=str(step_run_info.step_run_id),
            )
        )
        environment = orchestrator_utils.get_config_environment_vars(
            deployment=self._deployment
        )
        if last_retry:
            environment[ENV_ZENML_IGNORE_FAILURE_HOOK] = str(False)
        logger.info(
            "Using step operator `%s` to run step `%s`.",
            step_operator.name,
            self._step_name,
        )
        step_operator.launch(
            info=step_run_info,
            entrypoint_command=entrypoint_command,
            environment=environment,
        )
    def _run_step_without_step_operator(
        self,
        pipeline_run: PipelineRunResponse,
        step_run: StepRunResponse,
        step_run_info: StepRunInfo,
        input_artifacts: Dict[str, ArtifactVersionResponse],
        output_artifact_uris: Dict[str, str],
        last_retry: bool,
    ) -> None:
        """Runs the current step without a step operator.
        Args:
            pipeline_run: The model of the current pipeline run.
            step_run: The model of the current step run.
            step_run_info: Additional information needed to run the step.
            input_artifacts: The input artifact versions of the current step.
            output_artifact_uris: The output artifact URIs of the current step.
            last_retry: Whether this is the last retry of the step.
        """
        if last_retry:
            os.environ[ENV_ZENML_IGNORE_FAILURE_HOOK] = "false"
        runner = StepRunner(step=self._step, stack=self._stack)
        runner.run(
            pipeline_run=pipeline_run,
            step_run=step_run,
            input_artifacts=input_artifacts,
            output_artifact_uris=output_artifact_uris,
            step_run_info=step_run_info,
        )
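The retry logic in `StepLauncher.launch` works as follows:

- `max_retries` is the total number of attempts. It defaults to 1 when the step has no retry config, which means a single attempt.
- Between attempts it sleeps `delay` seconds, then multiplies `delay` by `backoff`.
- It passes a `last_retry` flag so that the final attempt can clear `ENV_ZENML_IGNORE_FAILURE_HOOK`.

The sketch below reproduces that loop in isolation. `run_with_retries` is a hypothetical name. It differs from the source in a few deliberate ways:

- It catches `Exception` rather than the source's `BaseException`, so that `KeyboardInterrupt` is not retried.
- It takes `sleep` as a parameter so the timing can be tested.
- It raises a `ValueError` when `max_retries` is below 1. The source would simply skip the loop in that case.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def run_with_retries(
    attempt: Callable[[bool], T],
    max_retries: int = 1,
    delay: float = 0,
    backoff: float = 1,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call ``attempt(last_retry)`` up to ``max_retries`` times in total."""
    retries = 0
    while retries < max_retries:
        last_retry = retries == max_retries - 1
        try:
            return attempt(last_retry)
        except Exception:
            retries += 1
            if retries >= max_retries:
                raise  # out of attempts: propagate the last error
            sleep(delay)
            delay *= backoff  # exponential backoff between attempts
    raise ValueError("max_retries must be at least 1")
```

With `max_retries=3, delay=1, backoff=2`, a step that fails twice sleeps for 1 and then 2 seconds before its third and final attempt.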
__init__(self, deployment, step, orchestrator_run_id)
Initializes the launcher.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| deployment | PipelineDeploymentResponse | The pipeline deployment. | required |
| step | Step | The step to launch. | required |
| orchestrator_run_id | str | The orchestrator pipeline run id. | required |
Exceptions:
| Type | Description |
|---|---|
| RuntimeError | If the deployment has no associated stack. |
launch(self)
Launches the step.
Exceptions:
| Type | Description |
|---|---|
| BaseException | If the step failed to launch, run, or publish. |
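`launch` decides whether to store step logs in two stages:

1. The `ENV_ZENML_DISABLE_STEP_LOGS_STORAGE` environment variable switches storage off globally.
2. Otherwise, `orchestrator_utils.is_setting_enabled` combines the step-level and pipeline-level `enable_step_logs` values.

The same helper resolves `enable_cache` and the artifact metadata and visualization settings.

The sketch below assumes two things: a value set on the step takes precedence over the pipeline, and the setting defaults to enabled when neither is set. That reading matches how the helper is used here but is not confirmed by this page. `step_logs_enabled` is a hypothetical wrapper.

```python
from typing import Optional


def is_setting_enabled(
    is_enabled_on_step: Optional[bool],
    is_enabled_on_pipeline: Optional[bool],
) -> bool:
    # Assumed precedence: explicit step value, then pipeline value, then True.
    if is_enabled_on_step is not None:
        return is_enabled_on_step
    if is_enabled_on_pipeline is not None:
        return is_enabled_on_pipeline
    return True


def step_logs_enabled(
    disabled_by_env: bool,
    on_step: Optional[bool],
    on_pipeline: Optional[bool],
) -> bool:
    # The environment switch wins over any step or pipeline configuration.
    if disabled_by_env:
        return False
    return is_setting_enabled(on_step, on_pipeline)
```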
step_runner
Class to run steps.
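Before calling the step entrypoint, `StepRunner._parse_inputs` (in the `StepRunner` source that follows) inspects the entrypoint's signature with `inspect.getfullargspec`. It then fills each argument from the first matching source, in this order:

1. the step context, for arguments annotated with `StepContext` (a deprecated pattern);
2. the resolved input artifacts;
3. the step's configured parameters.

Any argument it cannot fill raises `RuntimeError`.

The sketch below reproduces that order with stand-ins. `StepContextStub`, `parse_inputs`, and the already-loaded artifact values are hypothetical. The real method also normalizes annotations with `resolve_type_annotation` and loads artifacts through their materializer.

```python
import inspect
from typing import Any, Dict, List


class StepContextStub:
    """Hypothetical stand-in for zenml's StepContext."""


def parse_inputs(
    args: List[str],
    annotations: Dict[str, Any],
    input_artifacts: Dict[str, Any],
    parameters: Dict[str, Any],
    context: StepContextStub,
) -> Dict[str, Any]:
    # Slice instead of args.pop(0) so the caller's list is not mutated.
    if args and args[0] == "self":
        args = args[1:]
    function_params: Dict[str, Any] = {}
    for arg in args:
        arg_type = annotations.get(arg)
        if inspect.isclass(arg_type) and issubclass(arg_type, StepContextStub):
            function_params[arg] = context
        elif arg in input_artifacts:
            function_params[arg] = input_artifacts[arg]
        elif arg in parameters:
            function_params[arg] = parameters[arg]
        else:
            raise RuntimeError(
                f"Unable to find value for step function argument `{arg}`."
            )
    return function_params


def trainer(self, data: list, ctx: StepContextStub, epochs: int = 1) -> None:
    """Example entrypoint used to exercise parse_inputs."""
```

As in the source, only artifacts and configured parameters are consulted. A default value in the signature (`epochs: int = 1`) does not by itself satisfy an argument here.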
StepRunner
Class to run steps.
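Before any artifact is stored, `StepRunner._validate_outputs` (in the source below) turns the entrypoint's return value into a mapping from output name to value. It applies three shape rules:

- With no declared outputs, the step must return `None`.
- With a single declared output, that output receives the return value as-is, even if it is a tuple.
- With multiple outputs, the step must return a list or tuple of matching length.

The sketch below reproduces those shape checks only. It omits the per-output `isinstance` check against the resolved annotation. `StepInterfaceError` here is a local stand-in for ZenML's exception of the same name, and `validate_outputs` is hypothetical.

```python
from typing import Any, Dict, List


class StepInterfaceError(Exception):
    """Local stand-in for zenml's StepInterfaceError."""


def validate_outputs(
    return_values: Any, output_names: List[str], step_name: str
) -> Dict[str, Any]:
    if not output_names:
        # No declared outputs: the step must return None.
        if return_values is not None:
            raise StepInterfaceError(
                f"Step `{step_name}` declares no outputs but returned {return_values!r}."
            )
        return {}
    if len(output_names) == 1:
        # A single output takes the whole return value, tuples included.
        return_values = [return_values]
    if not isinstance(return_values, (list, tuple)):
        raise StepInterfaceError(
            f"Step `{step_name}` declares {len(output_names)} outputs but did "
            "not return a list or tuple."
        )
    if len(return_values) != len(output_names):
        raise StepInterfaceError(
            f"Step `{step_name}` declares {len(output_names)} outputs but "
            f"returned {len(return_values)}."
        )
    return dict(zip(output_names, return_values))
```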
Source code in zenml/orchestrators/step_runner.py
class StepRunner:
    """Class to run steps."""
    def __init__(self, step: "Step", stack: "Stack"):
        """Initializes the step runner.
        Args:
            step: The step to run.
            stack: The stack on which the step should run.
        """
        self._step = step
        self._stack = stack
    @property
    def configuration(self) -> StepConfiguration:
        """Configuration of the step to run.
        Returns:
            The step configuration.
        """
        return self._step.config
    def run(
        self,
        pipeline_run: "PipelineRunResponse",
        step_run: "StepRunResponse",
        input_artifacts: Dict[str, "ArtifactVersionResponse"],
        output_artifact_uris: Dict[str, str],
        step_run_info: StepRunInfo,
    ) -> None:
        """Runs the step.
        Args:
            pipeline_run: The model of the current pipeline run.
            step_run: The model of the current step run.
            input_artifacts: The input artifact versions of the step.
            output_artifact_uris: The URIs of the output artifacts of the step.
            step_run_info: The step run info.
        Raises:
            BaseException: A general exception if the step fails.
        """
        if handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False):
            step_logging_enabled = False
        else:
            enabled_on_step = step_run.config.enable_step_logs
            enabled_on_pipeline = pipeline_run.config.enable_step_logs
            step_logging_enabled = is_setting_enabled(
                is_enabled_on_step=enabled_on_step,
                is_enabled_on_pipeline=enabled_on_pipeline,
            )
        logs_context = nullcontext()
        if step_logging_enabled and not redirected.get():
            if step_run.logs:
                logs_context = StepLogsStorageContext(  # type: ignore[assignment]
                    logs_uri=step_run.logs.uri
                )
            else:
                logger.debug(
                    "There is no LogsResponseModel prepared for the step. The "
                    "step logging storage is disabled."
                )
        with logs_context:
            step_instance = self._load_step()
            output_materializers = self._load_output_materializers()
            spec = inspect.getfullargspec(
                inspect.unwrap(step_instance.entrypoint)
            )
            # (Deprecated) Wrap the execution of the step function in a step
            # environment that the step function code can access to retrieve
            # information about the pipeline runtime, such as the current step
            # name and the current pipeline run ID
            cache_enabled = is_setting_enabled(
                is_enabled_on_step=step_run_info.config.enable_cache,
                is_enabled_on_pipeline=step_run_info.pipeline.enable_cache,
            )
            output_annotations = parse_return_type_annotations(
                func=step_instance.entrypoint
            )
            with StepEnvironment(
                step_run_info=step_run_info,
                cache_enabled=cache_enabled,
            ):
                self._stack.prepare_step_run(info=step_run_info)
                # Initialize the step context singleton
                StepContext._clear()
                StepContext(
                    pipeline_run=pipeline_run,
                    step_run=step_run,
                    output_materializers=output_materializers,
                    output_artifact_uris=output_artifact_uris,
                    step_run_info=step_run_info,
                    cache_enabled=cache_enabled,
                    output_artifact_configs={
                        k: v.artifact_config
                        for k, v in output_annotations.items()
                    },
                )
                # Parse the inputs for the entrypoint function.
                function_params = self._parse_inputs(
                    args=spec.args,
                    annotations=spec.annotations,
                    input_artifacts=input_artifacts,
                )
                _link_pipeline_run_to_model_from_context(
                    pipeline_run_id=pipeline_run.id
                )
                step_failed = False
                try:
                    return_values = step_instance.call_entrypoint(
                        **function_params
                    )
                except BaseException as step_exception:  # noqa: E722
                    step_failed = True
                    if not handle_bool_env_var(
                        ENV_ZENML_IGNORE_FAILURE_HOOK, False
                    ):
                        if (
                            failure_hook_source
                            := self.configuration.failure_hook_source
                        ):
                            logger.info("Detected failure hook. Running...")
                            self.load_and_run_hook(
                                failure_hook_source,
                                step_exception=step_exception,
                            )
                    raise
                finally:
                    step_run_metadata = self._stack.get_step_run_metadata(
                        info=step_run_info,
                    )
                    publish_step_run_metadata(
                        step_run_id=step_run_info.step_run_id,
                        step_run_metadata=step_run_metadata,
                    )
                    self._stack.cleanup_step_run(
                        info=step_run_info, step_failed=step_failed
                    )
                    if not step_failed:
                        if (
                            success_hook_source
                            := self.configuration.success_hook_source
                        ):
                            logger.info("Detected success hook. Running...")
                            self.load_and_run_hook(
                                success_hook_source,
                                step_exception=None,
                            )
                        # Store and publish the output artifacts of the step function.
                        output_data = self._validate_outputs(
                            return_values, output_annotations
                        )
                        artifact_metadata_enabled = is_setting_enabled(
                            is_enabled_on_step=step_run_info.config.enable_artifact_metadata,
                            is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata,
                        )
                        artifact_visualization_enabled = is_setting_enabled(
                            is_enabled_on_step=step_run_info.config.enable_artifact_visualization,
                            is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization,
                        )
                        output_artifact_ids = self._store_output_artifacts(
                            output_data=output_data,
                            output_artifact_uris=output_artifact_uris,
                            output_materializers=output_materializers,
                            output_annotations=output_annotations,
                            artifact_metadata_enabled=artifact_metadata_enabled,
                            artifact_visualization_enabled=artifact_visualization_enabled,
                        )
                        link_step_artifacts_to_model(
                            artifact_version_ids=output_artifact_ids
                        )
                        _link_pipeline_run_to_model_from_artifacts(
                            pipeline_run_id=pipeline_run.id,
                            artifact_names=list(output_artifact_ids.keys()),
                            external_artifacts=list(
                                step_run.config.external_input_artifacts.values()
                            ),
                        )
                    StepContext._clear()  # Remove the step context singleton
            # Update the status and output artifacts of the step run.
            publish_successful_step_run(
                step_run_id=step_run_info.step_run_id,
                output_artifact_ids=output_artifact_ids,
            )
    def _load_step(self) -> "BaseStep":
        """Load the step instance.
        Returns:
            The step instance.
        """
        from zenml.steps import BaseStep
        step_instance = BaseStep.load_from_source(self._step.spec.source)
        step_instance = copy.deepcopy(step_instance)
        step_instance._configuration = self._step.config
        return step_instance
    def _load_output_materializers(
        self,
    ) -> Dict[str, Tuple[Type[BaseMaterializer], ...]]:
        """Loads the output materializers for the step.
        Returns:
            The step output materializers.
        """
        materializers = {}
        for name, output in self.configuration.outputs.items():
            output_materializers = []
            for source in output.materializer_source:
                materializer_class: Type[BaseMaterializer] = (
                    source_utils.load_and_validate_class(
                        source, expected_class=BaseMaterializer
                    )
                )
                output_materializers.append(materializer_class)
            materializers[name] = tuple(output_materializers)
        return materializers
    def _parse_inputs(
        self,
        args: List[str],
        annotations: Dict[str, Any],
        input_artifacts: Dict[str, "ArtifactVersionResponse"],
    ) -> Dict[str, Any]:
        """Parses the inputs for a step entrypoint function.
        Args:
            args: The arguments of the step entrypoint function.
            annotations: The annotations of the step entrypoint function.
            input_artifacts: The input artifact versions of the step.
        Returns:
            The parsed inputs for the step entrypoint function.
        Raises:
            RuntimeError: If a function argument value is missing.
        """
        function_params: Dict[str, Any] = {}
        if args and args[0] == "self":
            args.pop(0)
        for arg in args:
            arg_type = annotations.get(arg, None)
            arg_type = resolve_type_annotation(arg_type)
            if inspect.isclass(arg_type) and issubclass(arg_type, StepContext):
                step_name = self.configuration.name
                logger.warning(
                    "Passing a `StepContext` as an argument to a step function "
                    "is deprecated and will be removed in a future release. "
                    f"Please adjust your '{step_name}' step to instead import "
                    "the `StepContext` inside your step, as shown here: "
                    "https://docs.zenml.io/how-to/track-metrics-metadata/fetch-metadata-within-steps"
                )
                function_params[arg] = get_step_context()
            elif arg in input_artifacts:
                function_params[arg] = self._load_input_artifact(
                    input_artifacts[arg], arg_type
                )
            elif arg in self.configuration.parameters:
                function_params[arg] = self.configuration.parameters[arg]
            else:
                raise RuntimeError(
                    f"Unable to find value for step function argument `{arg}`."
                )
        return function_params
    def _parse_hook_inputs(
        self,
        args: List[str],
        annotations: Dict[str, Any],
        step_exception: Optional[BaseException],
    ) -> Dict[str, Any]:
        """Parses the inputs for a hook function.
        Args:
            args: The arguments of the hook function.
            annotations: The annotations of the hook function.
            step_exception: The exception of the original step.
        Returns:
            The parsed inputs for the hook function.
        Raises:
            TypeError: If hook function is passed a wrong parameter type.
        """
        from zenml.steps import BaseParameters
        function_params: Dict[str, Any] = {}
        if args and args[0] == "self":
            args.pop(0)
        for arg in args:
            arg_type = annotations.get(arg, None)
            arg_type = resolve_type_annotation(arg_type)
            # Parse the parameters
            if issubclass(arg_type, BaseParameters):
                step_params = arg_type.model_validate(
                    self.configuration.parameters[arg]
                )
                function_params[arg] = step_params
            # Parse the step context
            elif issubclass(arg_type, StepContext):
                step_name = self.configuration.name
                logger.warning(
                    "Passing a `StepContext` as an argument to a hook function "
                    "is deprecated and will be removed in a future release. "
                    f"Please adjust your '{step_name}' hook to instead import "
                    "the `StepContext` inside your hook, as shown here: "
                    "https://docs.zenml.io/how-to/track-metrics-metadata/fetch-metadata-within-steps"
                )
                function_params[arg] = get_step_context()
            elif issubclass(arg_type, BaseException):
                function_params[arg] = step_exception
            else:
                # It should not be of any other type
                raise TypeError(
                    "Hook functions can only take arguments of type "
                    f"`BaseParameters`, `StepContext`, or `BaseException`, not {arg_type}"
                )
        return function_params
    def _load_input_artifact(
        self, artifact: "ArtifactVersionResponse", data_type: Type[Any]
    ) -> Any:
        """Loads an input artifact.
        Args:
            artifact: The artifact to load.
            data_type: The data type of the artifact value.
        Returns:
            The artifact value.
        """
        # Skip materialization for `UnmaterializedArtifact`.
        if data_type == UnmaterializedArtifact:
            return UnmaterializedArtifact(
                **artifact.get_hydrated_version().model_dump()
            )
        if data_type is Any or is_union(get_origin(data_type)):
            # The entrypoint does not declare a single concrete type for this
            # input (`Any` or a `Union`), so use the data type of the stored
            # artifact instead.
            data_type = source_utils.load(artifact.data_type)
        from zenml.orchestrators.utils import (
            register_artifact_store_filesystem,
        )
        materializer_class: Type[BaseMaterializer] = (
            source_utils.load_and_validate_class(
                artifact.materializer, expected_class=BaseMaterializer
            )
        )
        with register_artifact_store_filesystem(
            artifact.artifact_store_id
        ) as target_artifact_store:
            materializer: BaseMaterializer = materializer_class(
                uri=artifact.uri, artifact_store=target_artifact_store
            )
            materializer.validate_type_compatibility(data_type)
            return materializer.load(data_type=data_type)
def _validate_outputs(
self,
return_values: Any,
output_annotations: Dict[str, OutputSignature],
) -> Dict[str, Any]:
"""Validates the step function outputs.
Args:
return_values: The return values of the step function.
output_annotations: The output annotations of the step function.
Returns:
The validated output, mapping output names to return values.
Raises:
StepInterfaceError: If the step function return values do not
match the output annotations.
"""
step_name = self._step.spec.pipeline_parameter_name
# if there are no outputs, the return value must be `None`.
if len(output_annotations) == 0:
if return_values is not None:
raise StepInterfaceError(
f"Wrong step function output type for step `{step_name}`: "
f"Expected no outputs but the function returned something: "
f"{return_values}."
)
return {}
# if there is only one output annotation (either directly specified
# or contained in an `Output` tuple) we treat the step function
# return value as the return for that output.
if len(output_annotations) == 1:
return_values = [return_values]
# if the user defined multiple outputs, the return value must be a list
# or tuple.
if not isinstance(return_values, (list, tuple)):
raise StepInterfaceError(
f"Wrong step function output type for step `{step_name}`: "
f"Expected multiple outputs ({output_annotations}) but "
f"the function did not return a list or tuple "
f"(actual return value: {return_values})."
)
# The number of actual outputs must match the number of
# expected outputs.
if len(output_annotations) != len(return_values):
raise StepInterfaceError(
f"Wrong number of step function outputs for step "
f"'{step_name}': Expected {len(output_annotations)} outputs "
f"but the function returned {len(return_values)} outputs "
f"(return values: {return_values})."
)
from zenml.steps.utils import get_args
validated_outputs: Dict[str, Any] = {}
for return_value, (output_name, output_annotation) in zip(
return_values, output_annotations.items()
):
output_type = output_annotation.resolved_annotation
if output_type is Any:
pass
else:
if is_union(get_origin(output_type)):
output_type = get_args(output_type)
if not isinstance(return_value, output_type):
raise StepInterfaceError(
f"Wrong type for output '{output_name}' of step "
f"'{step_name}' (expected type: {output_type}, "
f"actual type: {type(return_value)})."
)
validated_outputs[output_name] = return_value
return validated_outputs
def _store_output_artifacts(
self,
output_data: Dict[str, Any],
output_materializers: Dict[str, Tuple[Type[BaseMaterializer], ...]],
output_artifact_uris: Dict[str, str],
output_annotations: Dict[str, OutputSignature],
artifact_metadata_enabled: bool,
artifact_visualization_enabled: bool,
) -> Dict[str, UUID]:
"""Stores the output artifacts of the step.
Args:
output_data: The output data of the step function, mapping output
names to return values.
output_materializers: The output materializers of the step.
output_artifact_uris: The output artifact URIs of the step.
output_annotations: The output annotations of the step function.
artifact_metadata_enabled: Whether artifact metadata collection is
enabled.
artifact_visualization_enabled: Whether artifact visualization is
enabled.
Returns:
The IDs of the published output artifacts.
"""
step_context = get_step_context()
output_artifacts: Dict[str, UUID] = {}
for output_name, return_value in output_data.items():
data_type = type(return_value)
materializer_classes = output_materializers[output_name]
if materializer_classes:
materializer_class = materializer_utils.select_materializer(
data_type=data_type,
materializer_classes=materializer_classes,
)
else:
# If no materializer classes are stored in the IR, the output had
# no type annotation or was annotated as `Any`, so we try to find
# a materializer for it at runtime
from zenml.materializers.materializer_registry import (
materializer_registry,
)
default_materializer_source = self._step.config.outputs[
output_name
].default_materializer_source
if default_materializer_source:
default_materializer_class: Type[BaseMaterializer] = (
source_utils.load_and_validate_class(
default_materializer_source,
expected_class=BaseMaterializer,
)
)
materializer_registry.default_materializer = (
default_materializer_class
)
materializer_class = materializer_registry[data_type]
uri = output_artifact_uris[output_name]
artifact_config = output_annotations[output_name].artifact_config
if artifact_config is not None:
has_custom_name = bool(artifact_config.name)
version = artifact_config.version
else:
has_custom_name, version = False, None
# Override the artifact name if it is not a custom name.
if has_custom_name:
artifact_name = output_name
else:
if step_context.pipeline_run.pipeline:
pipeline_name = step_context.pipeline_run.pipeline.name
else:
pipeline_name = "unlisted"
step_name = step_context.step_run.name
artifact_name = f"{pipeline_name}::{step_name}::{output_name}"
# Get metadata that the user logged manually
user_metadata = step_context.get_output_metadata(output_name)
# Get full set of tags
tags = step_context.get_output_tags(output_name)
artifact = save_artifact(
name=artifact_name,
data=return_value,
materializer=materializer_class,
uri=uri,
extract_metadata=artifact_metadata_enabled,
include_visualizations=artifact_visualization_enabled,
has_custom_name=has_custom_name,
version=version,
tags=tags,
user_metadata=user_metadata,
manual_save=False,
)
output_artifacts[output_name] = artifact.id
return output_artifacts
def load_and_run_hook(
self,
hook_source: "Source",
step_exception: Optional[BaseException],
) -> None:
"""Loads hook source and runs the hook.
Args:
hook_source: The source of the hook function.
step_exception: The exception of the original step.
"""
try:
hook = source_utils.load(hook_source)
hook_spec = inspect.getfullargspec(inspect.unwrap(hook))
function_params = self._parse_hook_inputs(
args=hook_spec.args,
annotations=hook_spec.annotations,
step_exception=step_exception,
)
logger.debug(f"Running hook {hook} with params: {function_params}")
hook(**function_params)
except Exception as e:
logger.error(
f"Failed to load or run hook source '{hook_source}': "
f"{e}"
)
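The output-validation rules in `_validate_outputs` can be illustrated in isolation. The following is a simplified, standalone sketch, not ZenML's implementation: `validate_outputs` is a hypothetical helper, it takes a plain dict mapping output names to expected types (standing in for `OutputSignature.resolved_annotation`), and it raises `TypeError` instead of `StepInterfaceError`.

```python
from typing import Any, Dict, Union, get_args, get_origin


def validate_outputs(return_values: Any, expected: Dict[str, Any]) -> Dict[str, Any]:
    """Map step return values to output names using the checks described above.

    Simplified sketch: `expected` maps output names to plain types,
    `typing.Union[...]`, or `Any` instead of ZenML `OutputSignature` objects.
    """
    if not expected:
        # No declared outputs: the step must return None.
        if return_values is not None:
            raise TypeError(f"Expected no outputs, got {return_values!r}.")
        return {}
    if len(expected) == 1:
        # A single output is wrapped so the loop below treats it uniformly,
        # even if the value itself is a list or tuple.
        return_values = [return_values]
    if not isinstance(return_values, (list, tuple)):
        raise TypeError("Expected multiple outputs as a list or tuple.")
    if len(return_values) != len(expected):
        raise TypeError(
            f"Expected {len(expected)} outputs, got {len(return_values)}."
        )
    validated: Dict[str, Any] = {}
    for value, (name, output_type) in zip(return_values, expected.items()):
        if output_type is not Any:
            if get_origin(output_type) is Union:
                # isinstance() accepts a tuple of the Union's member types.
                output_type = get_args(output_type)
            if not isinstance(value, output_type):
                raise TypeError(
                    f"Output '{name}' expected {output_type}, got {type(value)}."
                )
        validated[name] = value
    return validated


print(validate_outputs((1, "a"), {"count": int, "label": Union[str, None]}))
# {'count': 1, 'label': 'a'}
```

Note how a single declared output receives the whole return value, so returning a tuple from a one-output step is not unpacked.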
configuration: StepConfiguration
property
readonly
Configuration of the step to run.
Returns:
Type | Description |
---|---|
StepConfiguration | The step configuration. |
__init__(self, step, stack)
special
Initializes the step runner.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step | Step | The step to run. | required |
stack | Stack | The stack on which the step should run. | required |
Source code in zenml/orchestrators/step_runner.py
def __init__(self, step: "Step", stack: "Stack"):
"""Initializes the step runner.
Args:
step: The step to run.
stack: The stack on which the step should run.
"""
self._step = step
self._stack = stack
load_and_run_hook(self, hook_source, step_exception)
Loads hook source and runs the hook.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hook_source | Source | The source of the hook function. | required |
step_exception | Optional[BaseException] | The exception of the original step. | required |
Source code in zenml/orchestrators/step_runner.py
def load_and_run_hook(
self,
hook_source: "Source",
step_exception: Optional[BaseException],
) -> None:
"""Loads hook source and runs the hook.
Args:
hook_source: The source of the hook function.
step_exception: The exception of the original step.
"""
try:
hook = source_utils.load(hook_source)
hook_spec = inspect.getfullargspec(inspect.unwrap(hook))
function_params = self._parse_hook_inputs(
args=hook_spec.args,
annotations=hook_spec.annotations,
step_exception=step_exception,
)
logger.debug(f"Running hook {hook} with params: {function_params}")
hook(**function_params)
except Exception as e:
logger.error(
f"Failed to load or run hook source '{hook_source}': "
f"{e}"
)
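`load_and_run_hook` relies on `_parse_hook_inputs` to decide what each hook argument receives, based on its type annotation. Below is a hedged, standard-library-only sketch of that injection pattern: arguments annotated with a `BaseException` subclass receive the step exception, and any other annotation is rejected. It leaves out the deprecated `StepContext` branch, and `build_hook_kwargs` is a hypothetical helper, not a ZenML API.

```python
import inspect
from typing import Any, Callable, Dict, Optional


def build_hook_kwargs(
    hook: Callable[..., Any], step_exception: Optional[BaseException]
) -> Dict[str, Any]:
    """Build keyword arguments for a hook from its parameter annotations."""
    spec = inspect.getfullargspec(inspect.unwrap(hook))
    kwargs: Dict[str, Any] = {}
    for arg in spec.args:
        arg_type = spec.annotations.get(arg)
        if isinstance(arg_type, type) and issubclass(arg_type, BaseException):
            # Exception-typed parameters receive the exception raised by the
            # step (None when the hook runs after a successful step).
            kwargs[arg] = step_exception
        else:
            raise TypeError(f"Unsupported hook argument type for '{arg}': {arg_type}")
    return kwargs


def on_failure(exc: BaseException) -> str:
    return f"step failed: {exc}"


hook_kwargs = build_hook_kwargs(on_failure, ValueError("boom"))
print(on_failure(**hook_kwargs))  # step failed: boom
```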
run(self, pipeline_run, step_run, input_artifacts, output_artifact_uris, step_run_info)
Runs the step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run | PipelineRunResponse | The model of the current pipeline run. | required |
step_run | StepRunResponse | The model of the current step run. | required |
input_artifacts | Dict[str, ArtifactVersionResponse] | The input artifact versions of the step. | required |
output_artifact_uris | Dict[str, str] | The URIs of the output artifacts of the step. | required |
step_run_info | StepRunInfo | The step run info. | required |
Exceptions:
Type | Description |
---|---|
BaseException | A general exception if the step fails. |
Source code in zenml/orchestrators/step_runner.py
def run(
self,
pipeline_run: "PipelineRunResponse",
step_run: "StepRunResponse",
input_artifacts: Dict[str, "ArtifactVersionResponse"],
output_artifact_uris: Dict[str, str],
step_run_info: StepRunInfo,
) -> None:
"""Runs the step.
Args:
pipeline_run: The model of the current pipeline run.
step_run: The model of the current step run.
input_artifacts: The input artifact versions of the step.
output_artifact_uris: The URIs of the output artifacts of the step.
step_run_info: The step run info.
Raises:
BaseException: A general exception if the step fails.
"""
if handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False):
step_logging_enabled = False
else:
enabled_on_step = step_run.config.enable_step_logs
enabled_on_pipeline = pipeline_run.config.enable_step_logs
step_logging_enabled = is_setting_enabled(
is_enabled_on_step=enabled_on_step,
is_enabled_on_pipeline=enabled_on_pipeline,
)
logs_context = nullcontext()
if step_logging_enabled and not redirected.get():
if step_run.logs:
logs_context = StepLogsStorageContext( # type: ignore[assignment]
logs_uri=step_run.logs.uri
)
else:
logger.debug(
"There is no LogsResponseModel prepared for the step. The "
"step logging storage is disabled."
)
with logs_context:
step_instance = self._load_step()
output_materializers = self._load_output_materializers()
spec = inspect.getfullargspec(
inspect.unwrap(step_instance.entrypoint)
)
# (Deprecated) Wrap the execution of the step function in a step
# environment that the step function code can access to retrieve
# information about the pipeline runtime, such as the current step
# name and the current pipeline run ID
cache_enabled = is_setting_enabled(
is_enabled_on_step=step_run_info.config.enable_cache,
is_enabled_on_pipeline=step_run_info.pipeline.enable_cache,
)
output_annotations = parse_return_type_annotations(
func=step_instance.entrypoint
)
with StepEnvironment(
step_run_info=step_run_info,
cache_enabled=cache_enabled,
):
self._stack.prepare_step_run(info=step_run_info)
# Initialize the step context singleton
StepContext._clear()
StepContext(
pipeline_run=pipeline_run,
step_run=step_run,
output_materializers=output_materializers,
output_artifact_uris=output_artifact_uris,
step_run_info=step_run_info,
cache_enabled=cache_enabled,
output_artifact_configs={
k: v.artifact_config
for k, v in output_annotations.items()
},
)
# Parse the inputs for the entrypoint function.
function_params = self._parse_inputs(
args=spec.args,
annotations=spec.annotations,
input_artifacts=input_artifacts,
)
_link_pipeline_run_to_model_from_context(
pipeline_run_id=pipeline_run.id
)
step_failed = False
try:
return_values = step_instance.call_entrypoint(
**function_params
)
except BaseException as step_exception: # noqa: E722
step_failed = True
if not handle_bool_env_var(
ENV_ZENML_IGNORE_FAILURE_HOOK, False
):
if (
failure_hook_source
:= self.configuration.failure_hook_source
):
logger.info("Detected failure hook. Running...")
self.load_and_run_hook(
failure_hook_source,
step_exception=step_exception,
)
raise
finally:
step_run_metadata = self._stack.get_step_run_metadata(
info=step_run_info,
)
publish_step_run_metadata(
step_run_id=step_run_info.step_run_id,
step_run_metadata=step_run_metadata,
)
self._stack.cleanup_step_run(
info=step_run_info, step_failed=step_failed
)
if not step_failed:
if (
success_hook_source
:= self.configuration.success_hook_source
):
logger.info("Detected success hook. Running...")
self.load_and_run_hook(
success_hook_source,
step_exception=None,
)
# Store and publish the output artifacts of the step function.
output_data = self._validate_outputs(
return_values, output_annotations
)
artifact_metadata_enabled = is_setting_enabled(
is_enabled_on_step=step_run_info.config.enable_artifact_metadata,
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata,
)
artifact_visualization_enabled = is_setting_enabled(
is_enabled_on_step=step_run_info.config.enable_artifact_visualization,
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization,
)
output_artifact_ids = self._store_output_artifacts(
output_data=output_data,
output_artifact_uris=output_artifact_uris,
output_materializers=output_materializers,
output_annotations=output_annotations,
artifact_metadata_enabled=artifact_metadata_enabled,
artifact_visualization_enabled=artifact_visualization_enabled,
)
link_step_artifacts_to_model(
artifact_version_ids=output_artifact_ids
)
_link_pipeline_run_to_model_from_artifacts(
pipeline_run_id=pipeline_run.id,
artifact_names=list(output_artifact_ids.keys()),
external_artifacts=list(
step_run.config.external_input_artifacts.values()
),
)
StepContext._clear() # Remove the step context singleton
# Update the status and output artifacts of the step run.
publish_successful_step_run(
step_run_id=step_run_info.step_run_id,
output_artifact_ids=output_artifact_ids,
)
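The hook and cleanup ordering in `run` reduces to a try/except/finally pattern: the failure hook runs before the exception is re-raised, stack cleanup always runs and is told whether the step failed, and the success hook runs after cleanup only when no exception occurred. The sketch below reproduces just that ordering with plain callables; `run_step` and its parameters are hypothetical stand-ins, and it omits logging, metadata publishing, artifact storage, and the `ENV_ZENML_IGNORE_FAILURE_HOOK` switch.

```python
from typing import Callable, List, Optional


def run_step(
    entrypoint: Callable[[], object],
    cleanup: Callable[[bool], None],
    on_failure: Optional[Callable[[BaseException], None]] = None,
    on_success: Optional[Callable[[], None]] = None,
) -> object:
    """Run a step with the hook/cleanup ordering described for `run`."""
    step_failed = False
    try:
        result = entrypoint()
    except BaseException as step_exception:
        step_failed = True
        if on_failure is not None:
            # The failure hook sees the exception before it propagates.
            on_failure(step_exception)
        raise
    finally:
        # Cleanup runs on success and failure and is told which one happened.
        cleanup(step_failed)
    if on_success is not None:
        on_success()
    return result


events: List[str] = []
run_step(
    entrypoint=lambda: events.append("step") or 42,
    cleanup=lambda failed: events.append(f"cleanup(failed={failed})"),
    on_success=lambda: events.append("success hook"),
)
print(events)  # ['step', 'cleanup(failed=False)', 'success hook']
```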
topsort
Utilities for topological sort.
Implementation heavily inspired by TFX: https://github.com/tensorflow/tfx/blob/master/tfx/utils/topsort.py
topsorted_layers(nodes, get_node_id_fn, get_parent_nodes, get_child_nodes)
Sorts the DAG of nodes in topological order.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
nodes | Sequence[~NodeT] | A sequence of nodes. | required |
get_node_id_fn | Callable[[~NodeT], str] | Callable that returns a unique text identifier for a node. | required |
get_parent_nodes | Callable[[~NodeT], List[~NodeT]] | Callable that returns a list of parent nodes for a node. If a parent node's id is not found in the list of node ids, that parent node will be omitted. | required |
get_child_nodes | Callable[[~NodeT], List[~NodeT]] | Callable that returns a list of child nodes for a node. If a child node's id is not found in the list of node ids, that child node will be omitted. | required |
Returns:
Type | Description |
---|---|
List[List[~NodeT]] | A list of topologically ordered node layers. Each layer of nodes is sorted by its node id given by `get_node_id_fn`. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the input nodes don't form a DAG. |
ValueError | If the nodes are not unique. |
Source code in zenml/orchestrators/topsort.py
def topsorted_layers(
nodes: Sequence[NodeT],
get_node_id_fn: Callable[[NodeT], str],
get_parent_nodes: Callable[[NodeT], List[NodeT]],
get_child_nodes: Callable[[NodeT], List[NodeT]],
) -> List[List[NodeT]]:
"""Sorts the DAG of nodes in topological order.
Args:
nodes: A sequence of nodes.
get_node_id_fn: Callable that returns a unique text identifier for a node.
get_parent_nodes: Callable that returns a list of parent nodes for a node.
If a parent node's id is not found in the list of node ids, that parent
node will be omitted.
get_child_nodes: Callable that returns a list of child nodes for a node.
If a child node's id is not found in the list of node ids, that child
node will be omitted.
Returns:
A list of topologically ordered node layers. Each layer of nodes is sorted
by its node id given by `get_node_id_fn`.
Raises:
RuntimeError: If the input nodes don't form a DAG.
ValueError: If the nodes are not unique.
"""
# Make sure the nodes are unique.
node_ids = set(get_node_id_fn(n) for n in nodes)
if len(node_ids) != len(nodes):
raise ValueError("Nodes must have unique ids.")
# The outputs of get_(parent|child)_nodes should always be deduplicated,
# and references to unknown nodes should be removed.
def _apply_and_clean(
func: Callable[[NodeT], List[NodeT]], func_name: str, node: NodeT
) -> List[NodeT]:
seen_inner_node_ids = set()
result = []
for inner_node in func(node):
inner_node_id = get_node_id_fn(inner_node)
if inner_node_id in seen_inner_node_ids:
logger.warning(
"Duplicate node_id %s found when calling %s on node %s. "
"This entry will be ignored.",
inner_node_id,
func_name,
node,
)
elif inner_node_id not in node_ids:
logger.warning(
"node_id %s found when calling %s on node %s, but this node_id is "
"not found in the set of input nodes %s. This entry will be "
"ignored.",
inner_node_id,
func_name,
node,
node_ids,
)
else:
seen_inner_node_ids.add(inner_node_id)
result.append(inner_node)
return result
def get_clean_parent_nodes(node: NodeT) -> List[NodeT]:
return _apply_and_clean(get_parent_nodes, "get_parent_nodes", node)
def get_clean_child_nodes(node: NodeT) -> List[NodeT]:
return _apply_and_clean(get_child_nodes, "get_child_nodes", node)
# The first layer contains nodes with no incoming edges.
layer = [node for node in nodes if not get_clean_parent_nodes(node)]
visited_node_ids = set()
layers = []
while layer:
layer = sorted(layer, key=get_node_id_fn)
layers.append(layer)
next_layer = []
for node in layer:
visited_node_ids.add(get_node_id_fn(node))
for child_node in get_clean_child_nodes(node):
# Include the child node if all its parents are visited. If the child
# node is part of a cycle, it will never be included since it will have
# at least one unvisited parent node which is also part of the cycle.
parent_node_ids = set(
get_node_id_fn(p)
for p in get_clean_parent_nodes(child_node)
)
if parent_node_ids.issubset(visited_node_ids):
next_layer.append(child_node)
layer = next_layer
num_output_nodes = sum(len(layer) for layer in layers)
# Nodes in cycles are not included in layers; raise an error if this happens.
if num_output_nodes < len(nodes):
raise RuntimeError("Cannot sort graph because it contains a cycle.")
# This should never happen; raise an error if this occurs.
if num_output_nodes > len(nodes):
raise RuntimeError("Unknown error occurred while sorting DAG.")
return layers
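The layering loop in `topsorted_layers` behaves like a layer-by-layer form of Kahn's algorithm, using a "all parents visited" check instead of in-degree counters: a node joins the next layer once every parent has been visited, and nodes that are never reached indicate a cycle. The standalone sketch below applies the same idea to a plain adjacency mapping of string ids. This is a simplification for illustration: the real function works on arbitrary node objects through callbacks and drops unknown ids, whereas this sketch assumes every child id is also a key of the mapping.

```python
from typing import Dict, List, Set


def topsorted_layers(children: Dict[str, List[str]]) -> List[List[str]]:
    """Group node ids into topologically ordered layers (simplified sketch).

    `children` maps every node id to the ids of its child nodes; every child
    id is assumed to also be a key of `children`.
    """
    parents: Dict[str, Set[str]] = {node: set() for node in children}
    for node, kids in children.items():
        for kid in kids:
            parents[kid].add(node)

    # The first layer holds the nodes without incoming edges.
    layer = [node for node in children if not parents[node]]
    visited: Set[str] = set()
    layers: List[List[str]] = []
    while layer:
        layer = sorted(layer)  # each layer is ordered by node id
        layers.append(layer)
        next_layer: List[str] = []
        for node in layer:
            visited.add(node)
            # dict.fromkeys drops duplicate child ids while keeping their order.
            for kid in dict.fromkeys(children[node]):
                # A child is ready once all its parents are visited; nodes on
                # a cycle always keep an unvisited parent and are never added.
                if parents[kid] <= visited:
                    next_layer.append(kid)
        layer = next_layer

    if sum(len(group) for group in layers) < len(children):
        raise RuntimeError("Cannot sort graph because it contains a cycle.")
    return layers


pipeline = {"load": ["clean", "stats"], "clean": ["train"], "stats": ["train"], "train": []}
print(topsorted_layers(pipeline))  # [['load'], ['clean', 'stats'], ['train']]
```

Steps within one layer have no dependencies on each other, which is what lets an orchestrator schedule them in parallel.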
utils
Utility functions for the orchestrator.
register_artifact_store_filesystem
Context manager for the artifact_store/filesystem_registry dependency.
Although it is rare, an artifact sometimes has to be loaded from an artifact store that is different from the active artifact store.
In such cases, the context manager instantiates the target artifact store as a Python object, which registers the matching filesystem in the filesystem registry.
The catch is that the filesystem registry is keyed by scheme (such as "s3://" or "gs://"). If two artifact stores support the same set of schemes, registering the target store can overwrite the filesystem (and the authentication) that belongs to the active artifact store. That is why the context manager registers the active artifact store's filesystem again on exit, restoring the correct filesystem.
Source code in zenml/orchestrators/utils.py
class register_artifact_store_filesystem:
"""Context manager for the artifact_store/filesystem_registry dependency.
Even though it is rare, sometimes we bump into cases where we are trying to
load artifacts that belong to an artifact store which is different from
the active artifact store.
In cases like this, we will try to instantiate the target artifact store
by creating the corresponding artifact store Python object, which ends up
registering the right filesystem in the filesystem registry.
The problem is, the keys in the filesystem registry are schemes (such as
"s3://" or "gs://"). If we have two artifact stores with the same set of
supported schemes, we might end up overwriting the filesystem that belongs
to the active artifact store (and its authentication). That's why we have
to register the active artifact store's filesystem again on exit, so the
correct filesystem will be restored.
"""
def __init__(self, target_artifact_store_id: Optional[UUID]) -> None:
"""Initialization of the context manager.
Args:
target_artifact_store_id: the ID of the artifact store to load.
"""
self.target_artifact_store_id = target_artifact_store_id
def __enter__(self) -> "BaseArtifactStore":
"""Entering the context manager.
It creates an instance of the target artifact store to register the
correct filesystem in the registry.
Returns:
The target artifact store object.
Raises:
RuntimeError: If the target artifact store cannot be fetched, or
cannot be initialized due to missing dependencies.
"""
try:
if self.target_artifact_store_id is not None:
if (
Client().active_stack.artifact_store.id
!= self.target_artifact_store_id
):
get_logger(__name__).debug(
f"Trying to use the artifact store with ID: "
f"'{self.target_artifact_store_id}' "
f"which is currently not the active artifact store."
)
artifact_store_model_response = Client().get_stack_component(
component_type=StackComponentType.ARTIFACT_STORE,
name_id_or_prefix=self.target_artifact_store_id,
)
return cast(
"BaseArtifactStore",
StackComponent.from_model(artifact_store_model_response),
)
else:
return Client().active_stack.artifact_store
except KeyError:
raise RuntimeError(
"Unable to fetch the artifact store with id: "
f"'{self.target_artifact_store_id}'. Check whether the "
"artifact store still exists and you have the right "
"permissions to access it."
)
except ImportError:
raise RuntimeError(
"Unable to load the implementation of the artifact store with "
f"id: '{self.target_artifact_store_id}'. Please make sure that "
"the environment that you are loading this artifact from "
"has the right dependencies."
)
def __exit__(
self,
exc_type: Optional[Any],
exc_value: Optional[Any],
traceback: Optional[Any],
) -> None:
"""Set it back to the original state.
Args:
exc_type: The class of the exception
exc_value: The instance of the exception
traceback: The traceback of the exception
"""
if ENV_ZENML_SERVER not in os.environ:
# As we exit the handler, we have to re-register the filesystem
# that belongs to the active artifact store as it may have been
# overwritten.
Client().active_stack.artifact_store._register()
__enter__(self)
special
Entering the context manager.
It creates an instance of the target artifact store to register the correct filesystem in the registry.
Returns:
Type | Description |
---|---|
BaseArtifactStore | The target artifact store object. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the target artifact store cannot be fetched, or cannot be initialized due to missing dependencies. |
Source code in zenml/orchestrators/utils.py
def __enter__(self) -> "BaseArtifactStore":
"""Entering the context manager.
It creates an instance of the target artifact store to register the
correct filesystem in the registry.
Returns:
The target artifact store object.
Raises:
RuntimeError: If the target artifact store cannot be fetched, or
cannot be initialized due to missing dependencies.
"""
try:
if self.target_artifact_store_id is not None:
if (
Client().active_stack.artifact_store.id
!= self.target_artifact_store_id
):
get_logger(__name__).debug(
f"Trying to use the artifact store with ID: "
f"'{self.target_artifact_store_id}' "
f"which is currently not the active artifact store."
)
artifact_store_model_response = Client().get_stack_component(
component_type=StackComponentType.ARTIFACT_STORE,
name_id_or_prefix=self.target_artifact_store_id,
)
return cast(
"BaseArtifactStore",
StackComponent.from_model(artifact_store_model_response),
)
else:
return Client().active_stack.artifact_store
except KeyError:
raise RuntimeError(
"Unable to fetch the artifact store with id: "
f"'{self.target_artifact_store_id}'. Check whether the "
"artifact store still exists and you have the right "
"permissions to access it."
)
except ImportError:
raise RuntimeError(
"Unable to load the implementation of the artifact store with "
f"id: '{self.target_artifact_store_id}'. Please make sure that "
"the environment that you are loading this artifact from "
"has the right dependencies."
)
__exit__(self, exc_type, exc_value, traceback)
special
Set it back to the original state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
exc_type | Optional[Any] | The class of the exception | required |
exc_value | Optional[Any] | The instance of the exception | required |
traceback | Optional[Any] | The traceback of the exception | required |
Source code in zenml/orchestrators/utils.py
def __exit__(
self,
exc_type: Optional[Any],
exc_value: Optional[Any],
traceback: Optional[Any],
) -> None:
"""Set it back to the original state.
Args:
exc_type: The class of the exception
exc_value: The instance of the exception
traceback: The traceback of the exception
"""
if ENV_ZENML_SERVER not in os.environ:
# As we exit the handler, we have to re-register the filesystem
# that belongs to the active artifact store as it may have been
# overwritten.
Client().active_stack.artifact_store._register()
__init__(self, target_artifact_store_id)
special
Initialization of the context manager.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
target_artifact_store_id | Optional[uuid.UUID] | The ID of the artifact store to load. | required |
Source code in zenml/orchestrators/utils.py
def __init__(self, target_artifact_store_id: Optional[UUID]) -> None:
"""Initialization of the context manager.
Args:
target_artifact_store_id: the ID of the artifact store to load.
"""
self.target_artifact_store_id = target_artifact_store_id
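The save/restore pattern behind `register_artifact_store_filesystem` can be illustrated with a toy registry keyed by URI scheme. In this hedged sketch, `FILESYSTEMS`, `FakeArtifactStore`, and `use_artifact_store` are hypothetical stand-ins, not ZenML APIs: instantiating the target store overwrites the shared `"s3://"` entry, and leaving the context re-registers the active store so its filesystem is restored. Unlike the real class, the sketch always instantiates the target store and has no server-side (`ENV_ZENML_SERVER`) exception.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, Tuple

# Hypothetical global registry: URI scheme -> name of the store that owns it.
FILESYSTEMS: Dict[str, str] = {}


class FakeArtifactStore:
    """Toy artifact store; creating one registers its filesystem."""

    def __init__(self, name: str, schemes: Tuple[str, ...]) -> None:
        self.name = name
        self.schemes = schemes
        self._register()

    def _register(self) -> None:
        for scheme in self.schemes:
            # Keys are schemes, so a later store silently replaces an earlier one.
            FILESYSTEMS[scheme] = self.name


@contextmanager
def use_artifact_store(
    active: FakeArtifactStore, target_name: str, schemes: Tuple[str, ...]
) -> Iterator[FakeArtifactStore]:
    """Instantiate a target store, then restore the active store's filesystem."""
    try:
        yield FakeArtifactStore(target_name, schemes)
    finally:
        # Mirrors `__exit__` above: re-register the active store's filesystem.
        active._register()


active = FakeArtifactStore("team-a-store", ("s3://",))
with use_artifact_store(active, "team-b-store", ("s3://",)) as target:
    print(target.name, FILESYSTEMS["s3://"])  # team-b-store team-b-store
print(FILESYSTEMS["s3://"])  # team-a-store
```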
get_config_environment_vars(deployment=None)
Gets environment variables to set for mirroring the active config.
If a pipeline deployment is given, the environment variables will be set to include a newly generated API token valid for the duration of the pipeline run instead of the API token from the global config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment | Optional[PipelineDeploymentResponse] | Optional deployment to use for the environment variables. | None |
Returns:
Type | Description |
---|---|
Dict[str, str] | Environment variable dict. |
Source code in zenml/orchestrators/utils.py
def get_config_environment_vars(
deployment: Optional["PipelineDeploymentResponse"] = None,
) -> Dict[str, str]:
"""Gets environment variables to set for mirroring the active config.
If a pipeline deployment is given, the environment variables will be set to
include a newly generated API token valid for the duration of the pipeline
run instead of the API token from the global config.
Args:
deployment: Optional deployment to use for the environment variables.
Returns:
Environment variable dict.
"""
from zenml.zen_stores.rest_zen_store import RestZenStore
global_config = GlobalConfiguration()
environment_vars = global_config.get_config_environment_vars()
if deployment and global_config.store_configuration.type == StoreType.REST:
# When connected to a ZenML server, if a pipeline deployment is
# supplied, we need to fetch an API token that will be valid for the
# duration of the pipeline run.
assert isinstance(global_config.zen_store, RestZenStore)
pipeline_id: Optional[UUID] = None
if deployment.pipeline:
pipeline_id = deployment.pipeline.id
schedule_id: Optional[UUID] = None
expires_minutes: Optional[int] = PIPELINE_API_TOKEN_EXPIRES_MINUTES
if deployment.schedule:
schedule_id = deployment.schedule.id
# If a schedule is given, this is a long running pipeline that
# should not have an API token that expires.
expires_minutes = None
api_token = global_config.zen_store.get_api_token(
pipeline_id=pipeline_id,
schedule_id=schedule_id,
expires_minutes=expires_minutes,
)
environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = api_token
# Make sure to use the correct active stack/workspace which might come
# from a .zen repository and not the global config
environment_vars[ENV_ZENML_ACTIVE_STACK_ID] = str(
Client().active_stack_model.id
)
environment_vars[ENV_ZENML_ACTIVE_WORKSPACE_ID] = str(
Client().active_workspace.id
)
return environment_vars
get_orchestrator_run_name(pipeline_name)
Gets an orchestrator run name.
This run name is distinct from the ZenML run name; it is intended for display in the orchestrator UI.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name | str | Name of the pipeline that will run. | required |
Returns:
Type | Description |
---|---|
str | The orchestrator run name. |
Source code in zenml/orchestrators/utils.py
def get_orchestrator_run_name(pipeline_name: str) -> str:
"""Gets an orchestrator run name.
This run name is not the same as the ZenML run name but can instead be
used to display in the orchestrator UI.
Args:
pipeline_name: Name of the pipeline that will run.
Returns:
The orchestrator run name.
"""
return f"{pipeline_name}_{random.Random().getrandbits(128):032x}"
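`get_orchestrator_run_name` appends 128 random bits, rendered as exactly 32 zero-padded lowercase hex digits, to the pipeline name. The snippet below reproduces that format with the standard library (the helper name `make_run_name` is hypothetical) and checks its shape. A fresh `random.Random()` instance is seeded from system randomness, so collisions are extremely unlikely, but the suffix is not meant to be a secret.

```python
import random
import re


def make_run_name(pipeline_name: str) -> str:
    # `032x` renders the 128-bit integer as exactly 32 lowercase hex digits.
    return f"{pipeline_name}_{random.Random().getrandbits(128):032x}"


name = make_run_name("training_pipeline")
assert re.fullmatch(r"training_pipeline_[0-9a-f]{32}", name)
print(name)
```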
get_run_name(run_name_template)
Fill out the run name template to get a complete run name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name_template | str | The run name template to fill out. | required |
Exceptions:
Type | Description |
---|---|
ValueError | If the run name is empty. |
Returns:
Type | Description |
---|---|
str | The run name derived from the template. |
Source code in zenml/orchestrators/utils.py
def get_run_name(run_name_template: str) -> str:
"""Fill out the run name template to get a complete run name.
Args:
run_name_template: The run name template to fill out.
Raises:
ValueError: If the run name is empty.
Returns:
The run name derived from the template.
"""
run_name = format_name_template(run_name_template)
if run_name == "":
raise ValueError("Empty run names are not allowed.")
return run_name
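`get_run_name` delegates placeholder substitution to `format_name_template` and only adds the empty-name check on top. The sketch below assumes, for illustration only, that templates use `str.format`-style `{date}` and `{time}` placeholders; the helper `fill_run_name` and the date and time formats it uses are hypothetical and may not match what ZenML produces.

```python
from datetime import datetime
from typing import Optional


def fill_run_name(run_name_template: str, now: Optional[datetime] = None) -> str:
    """Hypothetical stand-in for `format_name_template` plus the empty check."""
    now = now or datetime.now()
    # Assumed placeholders and formats; ZenML's exact set may differ.
    run_name = run_name_template.format(
        date=now.strftime("%Y_%m_%d"), time=now.strftime("%H_%M_%S_%f")
    )
    if run_name == "":
        raise ValueError("Empty run names are not allowed.")
    return run_name


print(fill_run_name("training-{date}-{time}", datetime(2024, 1, 2, 3, 4, 5)))
# training-2024_01_02-03_04_05_000000
```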
is_setting_enabled(is_enabled_on_step, is_enabled_on_pipeline)
Checks if a certain setting is enabled within a step run.
This is the case if:
- the setting is explicitly enabled for the step, or
- the setting is neither explicitly disabled for the step nor for the pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
is_enabled_on_step | Optional[bool] | The setting of the step. | required |
is_enabled_on_pipeline | Optional[bool] | The setting of the pipeline. | required |
Returns:
Type | Description |
---|---|
bool | True if the setting is enabled within the step run, False otherwise. |
Source code in zenml/orchestrators/utils.py
def is_setting_enabled(
is_enabled_on_step: Optional[bool],
is_enabled_on_pipeline: Optional[bool],
) -> bool:
"""Checks if a certain setting is enabled within a step run.
This is the case if:
- the setting is explicitly enabled for the step, or
- the setting is neither explicitly disabled for the step nor the pipeline.
Args:
is_enabled_on_step: The setting of the step.
is_enabled_on_pipeline: The setting of the pipeline.
Returns:
True if the setting is enabled within the step run, False otherwise.
"""
if is_enabled_on_step is not None:
return is_enabled_on_step
if is_enabled_on_pipeline is not None:
return is_enabled_on_pipeline
return True
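The precedence in `is_setting_enabled` is easiest to see as a truth table over the three-valued inputs (`True`, `False`, `None`). The sketch below copies the function body so it runs without ZenML installed and prints all nine combinations: an explicit step value always wins, an explicit pipeline value applies only when the step is unset, and an unset setting defaults to enabled.

```python
from typing import Optional


def is_setting_enabled(
    is_enabled_on_step: Optional[bool], is_enabled_on_pipeline: Optional[bool]
) -> bool:
    # Same precedence as above: explicit step value, then explicit pipeline
    # value, then the default of enabled.
    if is_enabled_on_step is not None:
        return is_enabled_on_step
    if is_enabled_on_pipeline is not None:
        return is_enabled_on_pipeline
    return True


# Print the full truth table for the three-valued inputs.
for step in (True, False, None):
    for pipeline in (True, False, None):
        result = is_setting_enabled(step, pipeline)
        print(f"step={step!s:5} pipeline={pipeline!s:5} -> {result}")
```

For example, `enable_cache=False` on the pipeline disables caching for every step that leaves the setting unset, while a step with `enable_cache=True` still caches.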
wheeled_orchestrator
Wheeled orchestrator class.
WheeledOrchestrator (BaseOrchestrator, ABC)
Base class for wheeled orchestrators.
Source code in zenml/orchestrators/wheeled_orchestrator.py
class WheeledOrchestrator(BaseOrchestrator, ABC):
    """Base class for wheeled orchestrators."""

    package_name = DEFAULT_PACKAGE_NAME
    package_version = __version__

    def copy_repository_to_temp_dir_and_add_setup_py(self) -> str:
        """Copy the repository to a temporary directory and add a setup.py file.

        Returns:
            Path to the temporary directory containing the copied repository.
        """
        repo_path = get_source_root()
        self.package_name = f"{DEFAULT_PACKAGE_NAME}_{self.sanitize_name(os.path.basename(repo_path))}"
        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        # Create a folder within the temporary directory
        temp_repo_path = os.path.join(temp_dir, self.package_name)
        fileio.mkdir(temp_repo_path)
        # Copy the repository to the temporary directory
        copy_dir(repo_path, temp_repo_path)
        # Create init file in the copied directory
        init_file_path = os.path.join(temp_repo_path, "__init__.py")
        with fileio.open(init_file_path, "w") as f:
            f.write("")
        # Create a setup.py file
        setup_py_content = f"""
from setuptools import setup, find_packages
setup(
    name="{self.package_name}",
    version="{self.package_version}",
    packages=find_packages(),
)
"""
        setup_py_path = os.path.join(temp_dir, "setup.py")
        with fileio.open(setup_py_path, "w") as f:
            f.write(setup_py_content)
        return temp_dir

    def create_wheel(self, temp_dir: str) -> str:
        """Create a wheel for the package in the given temporary directory.

        Args:
            temp_dir (str): Path to the temporary directory containing the package.

        Raises:
            RuntimeError: If the wheel file could not be created.

        Returns:
            str: Path to the created wheel file.
        """
        # Change to the temporary directory
        original_dir = os.getcwd()
        os.chdir(temp_dir)
        try:
            # Run the `pip wheel` command to create the wheel
            result = subprocess.run(
                ["pip", "wheel", "."], check=True, capture_output=True
            )
            logger.debug(f"Wheel creation stdout: {result.stdout.decode()}")
            logger.debug(f"Wheel creation stderr: {result.stderr.decode()}")
            # Find the created wheel file
            wheel_file = next(
                (
                    file
                    for file in os.listdir(temp_dir)
                    if file.endswith(".whl")
                ),
                None,
            )
            if wheel_file is None:
                raise RuntimeError("Failed to create wheel file.")
            wheel_path = os.path.join(temp_dir, wheel_file)
            # Verify the wheel file is a valid zip file
            import zipfile

            if not zipfile.is_zipfile(wheel_path):
                raise RuntimeError(
                    f"The file {wheel_path} is not a valid zip file."
                )
            return wheel_path
        finally:
            # Change back to the original directory
            os.chdir(original_dir)

    def sanitize_name(self, name: str) -> str:
        """Sanitize the value to be used in a cluster name.

        Args:
            name: Arbitrary input cluster name.

        Returns:
            Sanitized cluster name.
        """
        name = re.sub(
            r"[^a-z0-9-]", "-", name.lower()
        )  # replaces any character that is not a lowercase letter, digit, or hyphen with a hyphen
        name = re.sub(r"^[-]+", "", name)  # trim leading hyphens
        name = re.sub(r"[-]+$", "", name)  # trim trailing hyphens
        return name
copy_repository_to_temp_dir_and_add_setup_py(self)
Copy the repository to a temporary directory and add a setup.py file.
Returns:
Type | Description |
---|---|
str | Path to the temporary directory containing the copied repository. |
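The directory this method returns has a fixed layout: `setup.py` at the top level and the copied repository, turned into a package by an empty `__init__.py`, in a sibling folder named after `package_name`. The sketch below reproduces that layout with the standard library only. It assumes `shutil.copytree` as a stand-in for ZenML's `copy_dir` and takes the repository path, package name, and version as explicit arguments instead of reading them from `get_source_root()`, `sanitize_name`, and the class attributes; `build_package_layout` is a hypothetical name, not part of ZenML.

```python
import os
import shutil
import tempfile


def build_package_layout(repo_path: str, package_name: str, version: str) -> str:
    """Hypothetical helper mirroring the layout created by the method above."""
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    temp_repo_path = os.path.join(temp_dir, package_name)
    # Assumption: shutil.copytree stands in for ZenML's copy_dir; it also
    # creates temp_repo_path, so no separate mkdir is needed here.
    shutil.copytree(repo_path, temp_repo_path)
    # An empty __init__.py makes the copied repository a package.
    with open(os.path.join(temp_repo_path, "__init__.py"), "w") as f:
        f.write("")
    setup_py_content = f"""
from setuptools import setup, find_packages
setup(
    name="{package_name}",
    version="{version}",
    packages=find_packages(),
)
"""
    # setup.py sits next to the package folder, not inside it.
    with open(os.path.join(temp_dir, "setup.py"), "w") as f:
        f.write(setup_py_content)
    return temp_dir


# Usage: package a throwaway "repository" containing a single file.
repo = tempfile.mkdtemp()
with open(os.path.join(repo, "run.py"), "w") as f:
    f.write("print('hi')\n")
out = build_package_layout(repo, "my_pkg", "0.1.0")
print(sorted(os.listdir(out)))
print(sorted(os.listdir(os.path.join(out, "my_pkg"))))
```

Keeping `setup.py` outside the copied folder is what lets `find_packages()` discover the copied repository as a single top-level package.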
create_wheel(self, temp_dir)
Create a wheel for the package in the given temporary directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
temp_dir | str | Path to the temporary directory containing the package. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the wheel file could not be created. |
Returns:
Type | Description |
---|---|
str | Path to the created wheel file. |
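The method has two phases: run `pip wheel .` inside the package directory, then pick the first `*.whl` file and check that it is a valid zip archive. The sketch below separates the two into hypothetical helpers, `build_wheel` and `find_valid_wheel` (neither is a ZenML API). It deliberately deviates from the method in three labelled ways: it passes `cwd=` to `subprocess.run` instead of calling `os.chdir`, which keeps the working-directory change local to the child process; it invokes pip as `sys.executable -m pip` so the current interpreter's pip is used; and it adds `--no-deps` and `-w` so only the project's own wheel lands in a chosen output directory. Only the lookup phase is exercised below, so the demo does not need pip or network access.

```python
import os
import subprocess
import sys
import tempfile
import zipfile
from typing import Optional


def build_wheel(src_dir: str, out_dir: str) -> "subprocess.CompletedProcess[bytes]":
    # Hypothetical helper. cwd= scopes the directory change to the child
    # process; out_dir should be absolute because -w is resolved relative to cwd.
    return subprocess.run(
        [sys.executable, "-m", "pip", "wheel", ".", "--no-deps", "-w", out_dir],
        cwd=src_dir,
        check=True,
        capture_output=True,
    )


def find_valid_wheel(directory: str) -> str:
    # Same lookup as create_wheel: first *.whl entry, then a zip sanity check.
    wheel_file: Optional[str] = next(
        (f for f in os.listdir(directory) if f.endswith(".whl")), None
    )
    if wheel_file is None:
        raise RuntimeError("Failed to create wheel file.")
    wheel_path = os.path.join(directory, wheel_file)
    if not zipfile.is_zipfile(wheel_path):
        raise RuntimeError(f"The file {wheel_path} is not a valid zip file.")
    return wheel_path


# Demo without invoking pip: a hand-made zip named like a wheel passes the check.
demo_dir = tempfile.mkdtemp()
demo_wheel = os.path.join(demo_dir, "demo_pkg-0.1.0-py3-none-any.whl")
with zipfile.ZipFile(demo_wheel, "w") as zf:
    zf.writestr("demo_pkg/__init__.py", "")
print(find_valid_wheel(demo_dir))
```

Because `os.chdir` changes the working directory of the whole process, the `cwd=` variant may be preferable when other threads rely on the current directory; the original method compensates by restoring the directory in its `finally` block.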
sanitize_name(self, name)
Sanitize the value to be used in a cluster name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Arbitrary input cluster name. | required |
Returns:
Type | Description |
---|---|
str | Sanitized cluster name. |
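Besides cluster names, the class uses this method to derive `package_name` from the repository folder name. The standalone copy below (no ZenML import needed) shows the effect on a few inputs: the name is lowercased, every character outside `[a-z0-9-]` (including underscores, dots, spaces, and non-ASCII letters such as `é`) becomes a hyphen, and only leading and trailing hyphens are trimmed. Internal runs of hyphens are not collapsed, so `"a__b"` becomes `"a--b"`, and because hyphens are kept the result is not necessarily a valid Python identifier.

```python
import re


def sanitize_name(name: str) -> str:
    # Standalone copy of the method body above.
    name = re.sub(r"[^a-z0-9-]", "-", name.lower())  # anything outside [a-z0-9-] -> "-"
    name = re.sub(r"^[-]+", "", name)  # trim leading hyphens
    name = re.sub(r"[-]+$", "", name)  # trim trailing hyphens
    return name


for raw in ["My_Repo.v2", "  --Hello World!--", "a__b", "Café"]:
    print(repr(raw), "->", repr(sanitize_name(raw)))
```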