Airflow
zenml.integrations.airflow
special
Airflow integration for ZenML.
The Airflow integration powers an alternative orchestrator.
AirflowIntegration (Integration)
Definition of Airflow Integration for ZenML.
Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
    """Airflow integration definition for ZenML.

    Declares the integration name, its requirements and the stack component
    flavors that it provides.
    """

    NAME = AIRFLOW
    # No additional requirements are declared for this integration here.
    REQUIREMENTS = []

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the Airflow integration.

        Returns:
            List of stack component flavors for this integration.
        """
        # Function-scope import: the flavor module is only loaded once the
        # flavors are actually requested.
        from zenml.integrations.airflow.flavors import (
            AirflowOrchestratorFlavor,
        )

        integration_flavors: List[Type[Flavor]] = [AirflowOrchestratorFlavor]
        return integration_flavors
flavors()
classmethod
Declare the stack component flavors for the Airflow integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/airflow/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the Airflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    # Function-scope import: the flavor module is only loaded once the
    # flavors are actually requested.
    from zenml.integrations.airflow.flavors import (
        AirflowOrchestratorFlavor,
    )

    integration_flavors: List[Type[Flavor]] = [AirflowOrchestratorFlavor]
    return integration_flavors
flavors
special
Airflow integration flavors.
airflow_orchestrator_flavor
Airflow orchestrator flavor.
AirflowOrchestratorConfig (BaseOrchestratorConfig, AirflowOrchestratorSettings)
Configuration for the Airflow orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
local |
bool |
If the orchestrator is local or not. If this is True, will spin up a local Airflow server to run pipelines. |
Source code in zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py
class AirflowOrchestratorConfig(
    BaseOrchestratorConfig, AirflowOrchestratorSettings
):
    """Configuration for the Airflow orchestrator.

    Combines the base orchestrator configuration with all fields of the
    Airflow orchestrator settings.

    Attributes:
        local: If the orchestrator is local or not. If this is True, will spin
            up a local Airflow server to run pipelines.
    """

    local: bool = True

    @property
    def is_schedulable(self) -> bool:
        """Whether the orchestrator is schedulable or not.

        Returns:
            Always `True` for the Airflow orchestrator.
        """
        schedulable = True
        return schedulable
is_schedulable: bool
property
readonly
Whether the orchestrator is schedulable or not.
Returns:
Type | Description |
---|---|
bool |
Whether the orchestrator is schedulable or not. |
AirflowOrchestratorFlavor (BaseOrchestratorFlavor)
Flavor for the Airflow orchestrator.
Source code in zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py
class AirflowOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor for the Airflow orchestrator.

    Ties together the flavor name, its documentation/logo URLs, the config
    class and the orchestrator implementation class.
    """

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        flavor_name: str = AIRFLOW_ORCHESTRATOR_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        docs_url = self.generate_default_docs_url()
        return docs_url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        sdk_docs_url = self.generate_default_sdk_docs_url()
        return sdk_docs_url

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        # Adjacent literals are concatenated into one URL string.
        return (
            "https://public-flavor-logos.s3.eu-central-1.amazonaws.com"
            "/orchestrator/airflow.png"
        )

    @property
    def config_class(self) -> Type[AirflowOrchestratorConfig]:
        """Returns `AirflowOrchestratorConfig` config class.

        Returns:
            The config class.
        """
        config_cls: Type[AirflowOrchestratorConfig] = AirflowOrchestratorConfig
        return config_cls

    @property
    def implementation_class(self) -> Type["AirflowOrchestrator"]:
        """Implementation class.

        Returns:
            The implementation class.
        """
        # Function-scope import: the orchestrator module is only loaded when
        # the implementation class is requested.
        from zenml.integrations.airflow.orchestrators import (
            AirflowOrchestrator,
        )

        return AirflowOrchestrator
config_class: Type[zenml.integrations.airflow.flavors.airflow_orchestrator_flavor.AirflowOrchestratorConfig]
property
readonly
Returns `AirflowOrchestratorConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.airflow.flavors.airflow_orchestrator_flavor.AirflowOrchestratorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[AirflowOrchestrator]
property
readonly
Implementation class.
Returns:
Type | Description |
---|---|
Type[AirflowOrchestrator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
AirflowOrchestratorSettings (BaseSettings)
Settings for the Airflow orchestrator.
Attributes:
Name | Type | Description |
---|---|---|
dag_output_dir |
Optional[str] |
Output directory in which to write the Airflow DAG. |
dag_id |
Optional[str] |
Optional ID of the Airflow DAG to create. This value is only applied if the settings are defined on a ZenML pipeline and ignored if defined on a step. |
dag_tags |
List[str] |
Tags to add to the Airflow DAG. This value is only applied if the settings are defined on a ZenML pipeline and ignored if defined on a step. |
dag_args |
Dict[str, Any] |
Arguments for initializing the Airflow DAG. This value is only applied if the settings are defined on a ZenML pipeline and ignored if defined on a step. |
operator |
str |
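The operator to use for one or all steps. This can either be a `zenml.integrations.airflow.flavors.airflow_orchestrator_flavor.OperatorType` or a string representing the source of the operator class to use (e.g. `airflow.providers.docker.operators.docker.DockerOperator`). |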
operator_args |
Dict[str, Any] |
Arguments for initializing the Airflow operator. |
custom_dag_generator |
Optional[str] |
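Source string of a module to use for generating Airflow DAGs. This module must contain the same classes and constants as the `zenml.integrations.airflow.orchestrators.dag_generator` module. This value is only applied if the settings are defined on a ZenML pipeline and ignored if defined on a step. |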
Source code in zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py
class AirflowOrchestratorSettings(BaseSettings):
    """Settings for the Airflow orchestrator.

    Attributes:
        dag_output_dir: Output directory in which to write the Airflow DAG.
        dag_id: Optional ID of the Airflow DAG to create. This value is only
            applied if the settings are defined on a ZenML pipeline and
            ignored if defined on a step.
        dag_tags: Tags to add to the Airflow DAG. This value is only
            applied if the settings are defined on a ZenML pipeline and
            ignored if defined on a step.
        dag_args: Arguments for initializing the Airflow DAG. This
            value is only applied if the settings are defined on a ZenML
            pipeline and ignored if defined on a step.
        operator: The operator to use for one or all steps. This can either be
            a `zenml.integrations.airflow.flavors.airflow_orchestrator_flavor.OperatorType`
            or a string representing the source of the operator class to use
            (e.g. `airflow.providers.docker.operators.docker.DockerOperator`)
        operator_args: Arguments for initializing the Airflow
            operator.
        custom_dag_generator: Source string of a module to use for generating
            Airflow DAGs. This module must contain the same classes and
            constants as the
            `zenml.integrations.airflow.orchestrators.dag_generator` module.
            This value is only applied if the settings are defined on a ZenML
            pipeline and ignored if defined on a step.
    """

    dag_output_dir: Optional[str] = None

    dag_id: Optional[str] = None
    dag_tags: List[str] = []
    dag_args: Dict[str, Any] = {}

    # Stored as a source string; `_convert_operator` below normalizes
    # `OperatorType` members and their values into that form.
    operator: str = OperatorType.DOCKER.source
    operator_args: Dict[str, Any] = {}

    custom_dag_generator: Optional[str] = None

    @field_validator("operator", mode="before")
    @classmethod
    def _convert_operator(cls, value: Any) -> Any:
        """Converts operator types to source strings.

        Args:
            value: The operator type value.

        Returns:
            The operator source if the value is an `OperatorType` or one of
            its values, otherwise the value itself.
        """
        if isinstance(value, OperatorType):
            return value.source

        try:
            operator_type = OperatorType(value)
        except ValueError:
            # Not a known operator type value: pass it through unchanged.
            return value

        return operator_type.source
OperatorType (Enum)
Airflow operator types.
Source code in zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py
class OperatorType(Enum):
    """Airflow operator types.

    Each member maps to the source path of an Airflow operator class, which
    is available through the `source` property.
    """

    DOCKER = "docker"
    KUBERNETES_POD = "kubernetes_pod"
    GKE_START_POD = "gke_start_pod"

    @property
    def source(self) -> str:
        """Operator source.

        Returns:
            The operator source.

        Raises:
            KeyError: If no source is known for this member.
        """
        if self is OperatorType.DOCKER:
            return "airflow.providers.docker.operators.docker.DockerOperator"
        if self is OperatorType.KUBERNETES_POD:
            return "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator"
        if self is OperatorType.GKE_START_POD:
            return "airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator"
        # Every member is handled above; mirrors a failed mapping lookup.
        raise KeyError(self)
orchestrators
special
The Airflow integration enables the use of Airflow as a pipeline orchestrator.
airflow_orchestrator
Implementation of Airflow orchestrator integration.
AirflowOrchestrator (ContainerizedOrchestrator)
Orchestrator responsible for running pipelines using Airflow.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(ContainerizedOrchestrator):
    """Orchestrator responsible for running pipelines using Airflow.

    Preparing a pipeline with this orchestrator writes a zip archive to a
    DAG directory. The archive contains the DAG generator module (as
    `dag.py`) and a JSON file with the DAG configuration.
    """

    def __init__(self, **values: Any):
        """Initialize the orchestrator.

        Args:
            **values: Values to set in the orchestrator.
        """
        super().__init__(**values)
        # Default DAG directory, located inside the global config directory
        # and separated per orchestrator id.
        self.dags_directory = os.path.join(
            io_utils.get_global_config_directory(),
            "airflow",
            str(self.id),
            "dags",
        )

    @property
    def config(self) -> AirflowOrchestratorConfig:
        """Returns the orchestrator config.

        Returns:
            The configuration.
        """
        return cast(AirflowOrchestratorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Airflow orchestrator.

        Returns:
            The settings class.
        """
        return AirflowOrchestratorSettings

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Validates the stack.

        In the remote case, checks that the stack contains a container
        registry and an image builder and only remote components.

        Returns:
            A `StackValidator` instance, or `None` if the orchestrator is
            configured to run locally.
        """
        if self.config.local:
            # No container registry required if just running locally.
            return None

        def _validate_remote_components(
            stack: "Stack",
        ) -> Tuple[bool, str]:
            # Fail on the first local component that is found in the stack.
            for component in stack.components.values():
                if component.config.is_local:
                    return False, (
                        f"The Airflow orchestrator is configured to run "
                        f"pipelines remotely, but the '{component.name}' "
                        f"{component.type.value} is a local stack component "
                        f"and will not be available in the Airflow "
                        f"task.\nPlease ensure that you always use non-local "
                        f"stack components with a remote Airflow orchestrator, "
                        f"otherwise you may run into pipeline execution "
                        f"problems."
                    )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def prepare_pipeline_deployment(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
    ) -> None:
        """Checks the local paths of the stack when running locally.

        Nothing is done if the orchestrator is not configured as local.

        Args:
            deployment: The pipeline deployment configuration.
            stack: The stack on which the pipeline will be deployed.
        """
        if self.config.local:
            stack.check_local_paths()

    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponse",
        stack: "Stack",
        environment: Dict[str, str],
    ) -> Any:
        """Creates and writes an Airflow DAG zip file.

        Args:
            deployment: The pipeline deployment to prepare or run.
            stack: The stack the pipeline will run on.
            environment: Environment variables to set in the orchestration
                environment.
        """
        pipeline_settings = cast(
            AirflowOrchestratorSettings, self.get_settings(deployment)
        )
        dag_generator_values = get_dag_generator_values(
            custom_dag_generator_source=pipeline_settings.custom_dag_generator
        )

        command = StepEntrypointConfiguration.get_entrypoint_command()

        tasks = []
        for step_name, step in deployment.step_configurations.items():
            settings = cast(
                AirflowOrchestratorSettings, self.get_settings(step)
            )
            image = self.get_image(deployment=deployment, step_name=step_name)
            arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
            # Copy so that resource settings added below do not leak back
            # into the step settings.
            operator_args = settings.operator_args.copy()

            if self.requires_resources_in_orchestration_environment(step=step):
                if settings.operator == OperatorType.KUBERNETES_POD.source:
                    self._apply_resource_settings(
                        resource_settings=step.config.resource_settings,
                        operator_args=operator_args,
                    )
                else:
                    logger.warning(
                        "Specifying step resources is only supported when "
                        "using KubernetesPodOperators, ignoring resource "
                        "configuration for step %s.",
                        step_name,
                    )

            task = dag_generator_values.task_configuration_class(
                id=step_name,
                zenml_step_name=step_name,
                upstream_steps=step.spec.upstream_steps,
                docker_image=image,
                command=command,
                arguments=arguments,
                environment=environment,
                operator_source=settings.operator,
                operator_args=operator_args,
            )
            tasks.append(task)

        # The local stores path is only passed on for local orchestrators.
        local_stores_path = (
            os.path.expanduser(GlobalConfiguration().local_stores_path)
            if self.config.local
            else None
        )

        dag_id = pipeline_settings.dag_id or get_orchestrator_run_name(
            pipeline_name=deployment.pipeline_configuration.name
        )
        dag_config = dag_generator_values.dag_configuration_class(
            id=dag_id,
            local_stores_path=local_stores_path,
            tasks=tasks,
            tags=pipeline_settings.dag_tags,
            dag_args=pipeline_settings.dag_args,
            **self._translate_schedule(deployment.schedule),
        )

        self._write_dag(
            dag_config,
            dag_generator_values=dag_generator_values,
            output_dir=pipeline_settings.dag_output_dir or self.dags_directory,
        )

    def _apply_resource_settings(
        self,
        resource_settings: "ResourceSettings",
        operator_args: Dict[str, Any],
    ) -> None:
        """Adds resource settings to the operator args.

        If the operator args already contain `container_resources`, they take
        precedence and the resource settings are ignored with a warning.

        Args:
            resource_settings: The resource settings to add.
            operator_args: The operator args which will get modified in-place.
        """
        if "container_resources" in operator_args:
            logger.warning(
                "Received duplicate resources from ResourceSettings: `%s` "
                "and operator_args: `%s`. Ignoring the resources defined by "
                "the ResourceSettings.",
                resource_settings,
                operator_args["container_resources"],
            )
            return

        limits = {}

        if resource_settings.cpu_count is not None:
            limits["cpu"] = str(resource_settings.cpu_count)

        if resource_settings.memory is not None:
            # Drops the last character of the memory string.
            # NOTE(review): presumably this strips a trailing "B" (e.g.
            # "4GB" -> "4G") - confirm against the ResourceSettings format.
            memory_limit = resource_settings.memory[:-1]
            limits["memory"] = memory_limit

        if resource_settings.gpu_count is not None:
            logger.warning(
                "Specifying GPU resources is not supported for the Airflow "
                "orchestrator."
            )

        operator_args["container_resources"] = {"limits": limits}

    def _write_dag(
        self,
        dag_config: "DagConfiguration",
        dag_generator_values: DagGeneratorValues,
        output_dir: str,
    ) -> None:
        """Writes an Airflow DAG to disk.

        For a remote output directory, the zip file is first written to the
        local DAG directory and then copied. A failed copy is logged instead
        of raised, so that the local file can still be copied manually.

        Args:
            dag_config: Configuration of the DAG to write.
            dag_generator_values: Values of the DAG generator to use.
            output_dir: The directory in which to write the DAG.
        """
        io_utils.create_dir_recursive_if_not_exists(output_dir)

        if self.config.local and output_dir == self.dags_directory:
            logger.warning(
                "You're using a local Airflow orchestrator but have not "
                "specified a custom DAG output directory. Unless you've "
                "configured your Airflow server to look for DAGs in this "
                "directory (%s), this DAG will not be found automatically "
                "by your local Airflow server.",
                output_dir,
            )

        def _write_zip(path: str) -> None:
            # The archive holds the generator module as `dag.py` next to the
            # JSON-serialized DAG configuration.
            with zipfile.ZipFile(path, mode="w") as z:
                z.write(dag_generator_values.file, arcname="dag.py")
                z.writestr(
                    dag_generator_values.config_file_name,
                    dag_config.model_dump_json(),
                )

            logger.info("Writing DAG definition to `%s`.", path)

        dag_filename = f"{dag_config.id}.zip"
        if io_utils.is_remote(output_dir):
            io_utils.create_dir_recursive_if_not_exists(self.dags_directory)
            local_zip_path = os.path.join(self.dags_directory, dag_filename)
            remote_zip_path = os.path.join(output_dir, dag_filename)
            _write_zip(local_zip_path)
            try:
                fileio.copy(local_zip_path, remote_zip_path)
                logger.info("Copied DAG definition to `%s`.", remote_zip_path)
            except Exception as e:
                # Deliberately best-effort: report the failure and point the
                # user to the local copy instead of aborting.
                logger.exception(e)
                logger.error(
                    "Failed to upload DAG to remote path `%s`. To run the "
                    "pipeline in Airflow, please manually copy the file `%s` "
                    "to your Airflow DAG directory.",
                    remote_zip_path,
                    local_zip_path,
                )
        else:
            zip_path = os.path.join(output_dir, dag_filename)
            _write_zip(zip_path)

    def get_orchestrator_run_id(self) -> str:
        """Returns the active orchestrator run id.

        Raises:
            RuntimeError: If the environment variable specifying the run id
                is not set.

        Returns:
            The orchestrator run id.
        """
        from zenml.integrations.airflow.orchestrators.dag_generator import (
            ENV_ZENML_AIRFLOW_RUN_ID,
        )

        try:
            return os.environ[ENV_ZENML_AIRFLOW_RUN_ID]
        except KeyError as e:
            # Chain explicitly so the missing key is reported as the cause.
            raise RuntimeError(
                "Unable to read run id from environment variable "
                f"{ENV_ZENML_AIRFLOW_RUN_ID}."
            ) from e

    @staticmethod
    def _translate_schedule(
        schedule: Optional["Schedule"] = None,
    ) -> Dict[str, Any]:
        """Convert ZenML schedule into Airflow schedule.

        The Airflow schedule uses slightly different naming and needs some
        default entries for execution without a schedule.

        Args:
            schedule: Containing the interval, start and end date and
                a boolean flag that defines if past runs should be caught up
                on

        Returns:
            Airflow configuration dict.
        """
        if not schedule:
            return {
                "schedule": "@once",
                # set a start time in the past and disable catchup so airflow
                # runs the dag immediately
                "start_date": datetime.datetime.utcnow()
                - datetime.timedelta(7),
                "catchup": False,
            }

        if schedule.cron_expression:
            # Without an explicit start time, fall back to seven days ago.
            start_time = schedule.start_time or (
                datetime.datetime.utcnow() - datetime.timedelta(7)
            )
            return {
                "schedule": schedule.cron_expression,
                "start_date": start_time,
                "end_date": schedule.end_time,
                "catchup": schedule.catchup,
            }

        return {
            "schedule": schedule.interval_second,
            "start_date": schedule.start_time,
            "end_date": schedule.end_time,
            "catchup": schedule.catchup,
        }
config: AirflowOrchestratorConfig
property
readonly
Returns the orchestrator config.
Returns:
Type | Description |
---|---|
AirflowOrchestratorConfig |
The configuration. |
settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the Airflow orchestrator.
Returns:
Type | Description |
---|---|
Optional[Type[BaseSettings]] |
The settings class. |
validator: Optional[StackValidator]
property
readonly
Validates the stack.
In the remote case, checks that the stack contains a container registry and only remote components.
Returns:
Type | Description |
---|---|
Optional[StackValidator] |
A `StackValidator` instance. |
__init__(self, **values)
special
Initialize the orchestrator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**values |
Any |
Values to set in the orchestrator. |
{} |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
    """Initialize the orchestrator.

    Besides initializing the base class, this sets the default DAG
    directory, which lives inside the global config directory and is
    separated per orchestrator id.

    Args:
        **values: Values to set in the orchestrator.
    """
    super().__init__(**values)

    config_root = io_utils.get_global_config_directory()
    orchestrator_id = str(self.id)
    self.dags_directory = os.path.join(
        config_root, "airflow", orchestrator_id, "dags"
    )
get_orchestrator_run_id(self)
Returns the active orchestrator run id.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the environment variable specifying the run id is not set. |
Returns:
Type | Description |
---|---|
str |
The orchestrator run id. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Returns the active orchestrator run id.

    The id is read from the environment variable named by
    `ENV_ZENML_AIRFLOW_RUN_ID`.

    Raises:
        RuntimeError: If the environment variable specifying the run id
            is not set.

    Returns:
        The orchestrator run id.
    """
    from zenml.integrations.airflow.orchestrators.dag_generator import (
        ENV_ZENML_AIRFLOW_RUN_ID,
    )

    try:
        return os.environ[ENV_ZENML_AIRFLOW_RUN_ID]
    except KeyError as e:
        # Chain explicitly so the missing key is reported as the cause.
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_AIRFLOW_RUN_ID}."
        ) from e
prepare_or_run_pipeline(self, deployment, stack, environment)
Creates and writes an Airflow DAG zip file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponse |
The pipeline deployment to prepare or run. |
required |
stack |
Stack |
The stack the pipeline will run on. |
required |
environment |
Dict[str, str] |
Environment variables to set in the orchestration environment. |
required |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
    environment: Dict[str, str],
) -> Any:
    """Creates and writes an Airflow DAG zip file.

    One task configuration is created per step of the deployment. They are
    bundled into a DAG configuration which is then written to the DAG
    output directory from the pipeline settings, or to the default DAG
    directory of the orchestrator if none is set.

    Args:
        deployment: The pipeline deployment to prepare or run.
        stack: The stack the pipeline will run on.
        environment: Environment variables to set in the orchestration
            environment.
    """
    dag_settings = cast(
        AirflowOrchestratorSettings, self.get_settings(deployment)
    )
    generator = get_dag_generator_values(
        custom_dag_generator_source=dag_settings.custom_dag_generator
    )
    entrypoint_command = StepEntrypointConfiguration.get_entrypoint_command()

    task_configs = []
    for step_name, step in deployment.step_configurations.items():
        step_settings = cast(
            AirflowOrchestratorSettings, self.get_settings(step)
        )
        step_image = self.get_image(deployment=deployment, step_name=step_name)
        entrypoint_arguments = (
            StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
        )
        # Work on a copy: resource settings may get added to these args.
        step_operator_args = step_settings.operator_args.copy()

        if self.requires_resources_in_orchestration_environment(step=step):
            uses_kubernetes_pod = (
                step_settings.operator == OperatorType.KUBERNETES_POD.source
            )
            if not uses_kubernetes_pod:
                logger.warning(
                    "Specifying step resources is only supported when "
                    "using KubernetesPodOperators, ignoring resource "
                    "configuration for step %s.",
                    step_name,
                )
            else:
                self._apply_resource_settings(
                    resource_settings=step.config.resource_settings,
                    operator_args=step_operator_args,
                )

        task_configs.append(
            generator.task_configuration_class(
                id=step_name,
                zenml_step_name=step_name,
                upstream_steps=step.spec.upstream_steps,
                docker_image=step_image,
                command=entrypoint_command,
                arguments=entrypoint_arguments,
                environment=environment,
                operator_source=step_settings.operator,
                operator_args=step_operator_args,
            )
        )

    # The local stores path is only passed on for local orchestrators.
    if self.config.local:
        local_stores_path = os.path.expanduser(
            GlobalConfiguration().local_stores_path
        )
    else:
        local_stores_path = None

    # An explicitly configured DAG id wins over a generated run name.
    dag_id = dag_settings.dag_id or get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    dag_config = generator.dag_configuration_class(
        id=dag_id,
        local_stores_path=local_stores_path,
        tasks=task_configs,
        tags=dag_settings.dag_tags,
        dag_args=dag_settings.dag_args,
        **self._translate_schedule(deployment.schedule),
    )

    target_dir = dag_settings.dag_output_dir or self.dags_directory
    self._write_dag(
        dag_config,
        dag_generator_values=generator,
        output_dir=target_dir,
    )
prepare_pipeline_deployment(self, deployment, stack)
Checks the local paths of the stack if the orchestrator is configured to run locally.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponse |
The pipeline deployment configuration. |
required |
stack |
Stack |
The stack on which the pipeline will be deployed. |
required |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
) -> None:
    """Checks the local paths of the stack when running locally.

    Nothing is done if the orchestrator is not configured as local.

    Args:
        deployment: The pipeline deployment configuration.
        stack: The stack on which the pipeline will be deployed.
    """
    if not self.config.local:
        return

    stack.check_local_paths()
DagGeneratorValues (tuple)
Values from the DAG generator module.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class DagGeneratorValues(NamedTuple):
    """Values from the DAG generator module.

    Bundles the attributes that are read from a DAG generator module
    (either the default one or a custom module).
    """

    # Path of the DAG generator module file (the module's `__file__`).
    file: str
    # The module's `CONFIG_FILENAME` constant.
    config_file_name: str
    # The module's `ENV_ZENML_AIRFLOW_RUN_ID` constant.
    run_id_env_variable_name: str
    # The module's `DagConfiguration` class.
    dag_configuration_class: Type["DagConfiguration"]
    # The module's `TaskConfiguration` class.
    task_configuration_class: Type["TaskConfiguration"]
__getnewargs__(self)
special
Return self as a plain tuple. Used by copy and pickle.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __getnewargs__(self):
    """Return self as a plain tuple. Used by copy and pickle."""
    # NOTE(review): `_tuple` is not defined in this listing - presumably the
    # builtin `tuple` bound by the generated namedtuple code; confirm.
    return _tuple(self)
__new__(_cls, file, config_file_name, run_id_env_variable_name, dag_configuration_class, task_configuration_class)
special
staticmethod
Create new instance of DagGeneratorValues(file, config_file_name, run_id_env_variable_name, dag_configuration_class, task_configuration_class)
__repr__(self)
special
Return a nicely formatted representation string
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __repr__(self):
    """Return a nicely formatted representation string."""
    # NOTE(review): `repr_fmt` is not defined in this listing - presumably a
    # format string with one placeholder per field; confirm.
    return self.__class__.__name__ + repr_fmt % self
get_dag_generator_values(custom_dag_generator_source=None)
Gets values from the DAG generator module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
custom_dag_generator_source |
Optional[str] |
Source of a custom DAG generator module. |
None |
Returns:
Type | Description |
---|---|
DagGeneratorValues |
DAG generator module values. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def get_dag_generator_values(
    custom_dag_generator_source: Optional[str] = None,
) -> DagGeneratorValues:
    """Gets values from the DAG generator module.

    Args:
        custom_dag_generator_source: Source of a custom DAG generator module.
            If not given, the default
            `zenml.integrations.airflow.orchestrators.dag_generator` module
            is used.

    Returns:
        DAG generator module values.
    """
    if not custom_dag_generator_source:
        from zenml.integrations.airflow.orchestrators import (
            dag_generator as generator_module,
        )
    else:
        generator_module = importlib.import_module(custom_dag_generator_source)

    # The module file is needed later on, so it has to be set.
    assert generator_module.__file__

    return DagGeneratorValues(
        file=generator_module.__file__,
        config_file_name=generator_module.CONFIG_FILENAME,
        run_id_env_variable_name=generator_module.ENV_ZENML_AIRFLOW_RUN_ID,
        dag_configuration_class=generator_module.DagConfiguration,
        task_configuration_class=generator_module.TaskConfiguration,
    )
dag_generator
Module to generate an Airflow DAG from a config file.
DagConfiguration (BaseModel)
Airflow DAG configuration.
Source code in zenml/integrations/airflow/orchestrators/dag_generator.py
class DagConfiguration(BaseModel):
    """Airflow DAG configuration."""

    # ID of the DAG.
    id: str
    # Configurations of all tasks that belong to the DAG.
    tasks: List[TaskConfiguration]
    # Path of the local stores; optional.
    local_stores_path: Optional[str] = None
    # Schedule of the DAG, either a timedelta or a string. With the
    # `left_to_right` union mode, the union members are tried in order.
    schedule: Union[datetime.timedelta, str] = Field(
        union_mode="left_to_right"
    )
    # Start date of the DAG (required).
    start_date: datetime.datetime
    # Optional end date of the DAG.
    end_date: Optional[datetime.datetime] = None
    # Whether past runs should be caught up on.
    catchup: bool = False
    # Tags of the DAG.
    tags: List[str] = []
    # Additional arguments for initializing the DAG.
    dag_args: Dict[str, Any] = {}
TaskConfiguration (BaseModel)
Airflow task configuration.
Source code in zenml/integrations/airflow/orchestrators/dag_generator.py
class TaskConfiguration(BaseModel):
    """Airflow task configuration."""

    # ID of the task.
    id: str
    # Name of the ZenML step that this task belongs to.
    zenml_step_name: str
    # Names of the upstream steps of this task.
    upstream_steps: List[str]
    # Docker image of the task.
    docker_image: str
    # Command and arguments of the task; the Docker operator kwargs
    # concatenate them, the Kubernetes pod operator kwargs keep them apart.
    command: List[str]
    arguments: List[str]
    # Environment variables for the task.
    environment: Dict[str, str] = {}
    # Source string of the operator class to use for this task.
    operator_source: str
    # Additional arguments for initializing the operator.
    operator_args: Dict[str, Any] = {}
get_docker_operator_init_kwargs(dag_config, task_config)
Gets keyword arguments to pass to the DockerOperator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dag_config |
DagConfiguration |
The configuration of the DAG. |
required |
task_config |
TaskConfiguration |
The configuration of the task. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The init keyword arguments. |
Source code in zenml/integrations/airflow/orchestrators/dag_generator.py
def get_docker_operator_init_kwargs(
    dag_config: DagConfiguration, task_config: TaskConfiguration
) -> Dict[str, Any]:
    """Gets keyword arguments to pass to the DockerOperator.

    The environment of the task configuration is copied before the
    Airflow-specific variables are added, so the task configuration itself
    is left unchanged.

    Args:
        dag_config: The configuration of the DAG.
        task_config: The configuration of the task.

    Returns:
        The init keyword arguments.
    """
    mounts = []
    extra_hosts = {}

    # Copy to avoid mutating the task configuration as a side effect.
    environment = dict(task_config.environment)
    # NOTE(review): `{{run_id}}` is presumably rendered by the Airflow
    # templating at runtime - confirm.
    environment[ENV_ZENML_AIRFLOW_RUN_ID] = "{{run_id}}"

    if dag_config.local_stores_path:
        # Only import docker types if a mount actually has to be created.
        from docker.types import Mount

        environment[ENV_ZENML_LOCAL_STORES_PATH] = dag_config.local_stores_path
        # Bind-mount the local stores to the same path in the container.
        mounts = [
            Mount(
                target=dag_config.local_stores_path,
                source=dag_config.local_stores_path,
                type="bind",
            )
        ]
        extra_hosts = {"host.docker.internal": "host-gateway"}

    return {
        "image": task_config.docker_image,
        "command": task_config.command + task_config.arguments,
        "mounts": mounts,
        "environment": environment,
        "extra_hosts": extra_hosts,
    }
get_kubernetes_pod_operator_init_kwargs(dag_config, task_config)
Gets keyword arguments to pass to the KubernetesPodOperator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dag_config |
DagConfiguration |
The configuration of the DAG. |
required |
task_config |
TaskConfiguration |
The configuration of the task. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The init keyword arguments. |
Source code in zenml/integrations/airflow/orchestrators/dag_generator.py
def get_kubernetes_pod_operator_init_kwargs(
    dag_config: DagConfiguration, task_config: TaskConfiguration
) -> Dict[str, Any]:
    """Gets keyword arguments to pass to the KubernetesPodOperator.

    The environment of the task configuration is copied before the
    Airflow-specific variable is added, so the task configuration itself
    is left unchanged.

    Args:
        dag_config: The configuration of the DAG.
        task_config: The configuration of the task.

    Returns:
        The init keyword arguments.
    """
    from kubernetes.client.models import V1EnvVar

    # Copy to avoid mutating the task configuration as a side effect.
    environment = dict(task_config.environment)
    # NOTE(review): `{{run_id}}` is presumably rendered by the Airflow
    # templating at runtime - confirm.
    environment[ENV_ZENML_AIRFLOW_RUN_ID] = "{{run_id}}"

    return {
        "name": f"{dag_config.id}_{task_config.id}",
        "namespace": "default",
        "image": task_config.docker_image,
        "cmds": task_config.command,
        "arguments": task_config.arguments,
        "env_vars": [
            V1EnvVar(name=key, value=value)
            for key, value in environment.items()
        ],
    }
get_operator_init_kwargs(operator_class, dag_config, task_config)
Gets keyword arguments to pass to the operator init method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
operator_class |
Type[Any] |
The operator class for which to get the kwargs. |
required |
dag_config |
DagConfiguration |
The configuration of the DAG. |
required |
task_config |
TaskConfiguration |
The configuration of the task. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The init keyword arguments. |
Source code in zenml/integrations/airflow/orchestrators/dag_generator.py
def get_operator_init_kwargs(
    operator_class: Type[Any],
    dag_config: DagConfiguration,
    task_config: TaskConfiguration,
) -> Dict[str, Any]:
    """Gets keyword arguments to pass to the operator init method.

    The task id is always included. Docker and Kubernetes pod operators
    (and their subclasses) additionally get operator-specific arguments.
    The operator args of the task configuration are applied last and
    therefore override all other values.

    Args:
        operator_class: The operator class for which to get the kwargs.
        dag_config: The configuration of the DAG.
        task_config: The configuration of the task.

    Returns:
        The init keyword arguments.
    """
    kwargs = {"task_id": task_config.id}

    # Any ImportError raised inside these blocks is swallowed, which also
    # covers the case that the operator module cannot be imported.
    try:
        from airflow.providers.docker.operators.docker import DockerOperator

        if issubclass(operator_class, DockerOperator):
            docker_kwargs = get_docker_operator_init_kwargs(
                dag_config=dag_config, task_config=task_config
            )
            kwargs.update(docker_kwargs)
    except ImportError:
        pass

    try:
        from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
            KubernetesPodOperator,
        )

        if issubclass(operator_class, KubernetesPodOperator):
            pod_kwargs = get_kubernetes_pod_operator_init_kwargs(
                dag_config=dag_config, task_config=task_config
            )
            kwargs.update(pod_kwargs)
    except ImportError:
        pass

    # User-provided operator args take precedence over generated values.
    kwargs.update(task_config.operator_args)
    return kwargs
import_class_by_path(class_path)
Imports a class based on a given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_path |
str |
str, class_source e.g. this.module.Class |
required |
Returns:
Type | Description |
---|---|
Type[Any] |
the given class |
Source code in zenml/integrations/airflow/orchestrators/dag_generator.py
def import_class_by_path(class_path: str) -> Type[Any]:
    """Imports a class based on a given path.

    Args:
        class_path: str, class_source e.g. this.module.Class

    Raises:
        ValueError: If the path contains no `.` that separates the module
            from the class name.

    Returns:
        the given class
    """
    module_name, separator, class_name = class_path.rpartition(".")
    if not separator:
        # Fail with a descriptive message instead of a tuple unpacking error.
        raise ValueError(
            f"Invalid class path `{class_path}`: expected a dotted path "
            "of the form `this.module.Class`."
        )

    module = importlib.import_module(module_name)
    return getattr(module, class_name)  # type: ignore[no-any-return]