Skip to content

Dash

zenml.integrations.dash special

Initialization of the Dash integration.

DashIntegration (Integration)

Definition of Dash integration for ZenML.

Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
    """ZenML integration entry for the Dash visualization stack.

    The requirement list names the Dash packages that the pipeline run
    lineage visualizer of this integration imports (dash, dash-cytoscape,
    dash-bootstrap-components and jupyter-dash).
    """

    NAME = DASH
    # Pip requirement specifiers for this integration, in install order.
    REQUIREMENTS = [
        "dash>=2.0.0,!=2.9.0",  # release 2.9.0 is explicitly excluded
        "dash-cytoscape>=0.3.0",  # Cytoscape graph component
        "dash-bootstrap-components>=1.0.1",  # Bootstrap layout components
        "jupyter-dash>=0.4.2",  # inline rendering inside notebooks
    ]

visualizers special

Initialization of the Pipeline Run Visualizer.

pipeline_run_lineage_visualizer

Implementation of the pipeline run lineage visualizer.

PipelineRunLineageVisualizer (BaseVisualizer)

Implementation of a lineage diagram via the dash and dash-cytoscape libraries.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BaseVisualizer):
    """Implementation of a lineage diagram via the dash and dash-cytoscape libraries."""

    ARTIFACT_PREFIX = "artifact_"
    STEP_PREFIX = "step_"
    # Maps an execution status to the CSS class used to color nodes/edges.
    STATUS_CLASS_MAPPING = {
        ExecutionStatus.CACHED: "green",
        ExecutionStatus.FAILED: "red",
        ExecutionStatus.RUNNING: "yellow",
        ExecutionStatus.COMPLETED: "blue",
    }

    def visualize(
        self,
        object: PipelineRunView,
        magic: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dash.Dash:
        """Method to visualize pipeline runs via the Dash library.

        The layout puts every layer of the dag in a column.

        Args:
            object: The pipeline run to visualize.
            magic: If True and running inside a notebook, the app is created
                with `JupyterDash` and served inline. Outside a notebook a
                warning is emitted and a regular Dash app is used instead.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The Dash application.
        """
        external_stylesheets = [
            dbc.themes.BOOTSTRAP,
            dbc.icons.BOOTSTRAP,
        ]

        use_jupyter = magic and Environment.in_notebook()
        if magic and not use_jupyter:
            cli_utils.warning(
                "Cannot set magic flag in non-notebook environments."
            )

        # Both branches bind `app` and `mode`, so the rest of the method
        # never sees an unbound name (the non-notebook magic case used to).
        if use_jupyter:
            # Only import jupyter_dash in this case
            from jupyter_dash import JupyterDash  # noqa

            JupyterDash.infer_jupyter_proxy_config()

            app = JupyterDash(
                __name__,
                external_stylesheets=external_stylesheets,
            )
            mode = "inline"
        else:
            app = dash.Dash(
                __name__,
                external_stylesheets=external_stylesheets,
            )
            mode = None

        graph = LineageGraph()
        graph.generate_run_nodes_and_edges(object)
        first_step_id = graph.root_step_id

        # Parse lineage graph nodes: flatten each node's `data` into the node
        # dict and attach status/shape CSS classes for the stylesheet.
        nodes = []
        for node in graph.nodes:
            node_dict = node.dict()
            node_data = node_dict.pop("data")
            node_dict = {**node_dict, **node_data}
            node_dict["label"] = node_dict["name"]
            classes = self.STATUS_CLASS_MAPPING[node.data.status]
            if isinstance(node, ArtifactNode):
                classes = "rectangle " + classes
                node_dict["label"] += f" ({node_dict['artifact_data_type']})"
            dash_node = {"data": node_dict, "classes": classes}
            nodes.append(dash_node)

        # Parse lineage graph edges. Each edge links a step and an artifact;
        # if the source is a step, the artifact is that step's output.
        node_mapping = {node.id: node for node in graph.nodes}
        edges = []
        for edge in graph.edges:
            source_node = node_mapping[edge.source]
            target_node = node_mapping[edge.target]
            is_input_artifact = not isinstance(source_node, StepNode)
            if is_input_artifact:
                step_node, artifact_node = target_node, source_node
            else:
                step_node, artifact_node = source_node, target_node
            assert isinstance(artifact_node, ArtifactNode)
            artifact_is_cached = artifact_node.data.is_cached
            # An input edge from a cached artifact is colored as cached,
            # regardless of the consuming step's status.
            if is_input_artifact and artifact_is_cached:
                edge_status = self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]
            else:
                edge_status = self.STATUS_CLASS_MAPPING[step_node.data.status]
            edge_style = "dashed" if artifact_is_cached else "solid"
            edges.append(
                {
                    "data": edge.dict(),
                    "classes": f"edge-arrow {edge_status} {edge_style}",
                }
            )

        app.layout = dbc.Row(
            [
                dbc.Container(f"Run: {object.name}", class_name="h2"),
                *[
                    dbc.Container(f"- {k}: {v} ({type_})" + "\n\n")
                    for k, v, type_ in graph.run_metadata
                ],
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row(
                                    [
                                        html.Span(
                                            [
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-circle-fill me-1"
                                                        ),
                                                        "Step",
                                                    ],
                                                    className="me-2",
                                                ),
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-square-fill me-1"
                                                        ),
                                                        "Artifact",
                                                    ],
                                                    className="me-4",
                                                ),
                                                dbc.Badge(
                                                    "Completed",
                                                    color=COLOR_BLUE,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Cached",
                                                    color=COLOR_GREEN,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Running",
                                                    color=COLOR_YELLOW,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Failed",
                                                    color=COLOR_RED,
                                                    className="me-1",
                                                ),
                                            ]
                                        ),
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        cyto.Cytoscape(
                                            id="cytoscape",
                                            layout={
                                                "name": "breadthfirst",
                                                "roots": f'[id = "{first_step_id}"]',
                                            },
                                            elements=edges + nodes,
                                            stylesheet=STYLESHEET,
                                            style={
                                                "width": "100%",
                                                "height": "800px",
                                            },
                                            zoom=1,
                                        )
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        dbc.Button(
                                            "Reset",
                                            id="bt-reset",
                                            color="primary",
                                            className="me-1",
                                        )
                                    ]
                                ),
                            ]
                        ),
                        dbc.Col(
                            [
                                dcc.Markdown(id="markdown-selected-node-data"),
                            ]
                        ),
                    ]
                ),
            ],
            className="p-5",
        )

        # Attribute keys shown for each node type in the details panel.
        artifact_attributes = (
            "execution_id",
            "status",
            "artifact_data_type",
            "producer_step_id",
            "parent_step_id",
            "uri",
        )
        step_attributes = (
            "execution_id",
            "status",
        )

        @app.callback(  # type: ignore[misc]
            Output("markdown-selected-node-data", "children"),
            Input("cytoscape", "selectedNodeData"),
        )
        def display_data(data_list: List[Dict[str, Any]]) -> str:
            """Callback for the Markdown panel describing selected nodes.

            Args:
                data_list: The data dicts of the selected nodes, or None when
                    nothing has been selected yet.

            Returns:
                str: Markdown text describing the selected nodes.
            """
            if data_list is None:
                return "Click on a node in the diagram."

            # Every entry is emitted as its own Markdown paragraph.
            paragraphs: List[str] = []

            def add_attributes(data: Dict[str, Any], keys: Any) -> None:
                """Append the attribute header and the listed attributes."""
                paragraphs.append("#### Attributes:")
                paragraphs.extend(f"**{item}**: {data[item]}" for item in keys)

            def add_mapping(title: str, mapping: Any) -> None:
                """Append a titled key/value section if `mapping` is non-empty."""
                if mapping:
                    paragraphs.append(f"#### {title}:")
                    paragraphs.extend(
                        f"**{k}**: {v}" for k, v in mapping.items()
                    )

            def add_metadata(metadata: Any) -> None:
                """Append (key, value, type) metadata triples if any exist."""
                if metadata:
                    paragraphs.append("#### Metadata:")
                    paragraphs.extend(
                        f"**{k}**: {v} ({type_})" for k, v, type_ in metadata
                    )

            for data in data_list:
                if data["type"] == "artifact":
                    paragraphs.append(f"### Artifact '{data['name']}'")
                    add_attributes(data, artifact_attributes)
                    add_metadata(data["metadata"])
                elif data["type"] == "step":
                    paragraphs.append(f"### Step '{data['name']}'")
                    add_attributes(data, step_attributes)
                    add_mapping("Inputs", data["inputs"])
                    add_mapping("Outputs", data["outputs"])
                    add_mapping("Parameters", data["parameters"])
                    add_mapping("Configuration", data["configuration"])
                    add_metadata(data["metadata"])
            return "".join(paragraph + "\n\n" for paragraph in paragraphs)

        @app.callback(  # type: ignore[misc]
            [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
            [Input("bt-reset", "n_clicks")],
        )
        def reset_layout(
            n_clicks: int,
        ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
            """Resets the zoom level and the graph elements.

            Args:
                n_clicks: The number of clicks on the reset button.

            Returns:
                The zoom and the elements.
            """
            # The message must be the first argument; `n_clicks` is a
            # %-style argument (passing it first broke log formatting).
            logger.debug("Reset button clicked (n_clicks=%s).", n_clicks)
            return [1, edges + nodes]

        # Start exactly one server: inline for JupyterDash, default otherwise.
        if mode is not None:
            app.run_server(mode=mode)
        else:
            app.run_server()
        return app
visualize(self, object, magic=False, *args, **kwargs)

Method to visualize pipeline runs via the Dash library.

The layout puts every layer of the dag in a column.

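Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `object` | `PipelineRunView` | The pipeline run to visualize. | *required* |
| `magic` | `bool` | If True and running inside a notebook, the app is rendered inline via `JupyterDash`; outside a notebook a warning is emitted. | `False` |
| `*args` | `Any` | Additional positional arguments. | `()` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |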

Returns:

| Type | Description |
| --- | --- |
| `Dash` | The Dash application. |

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
    self,
    object: PipelineRunView,
    magic: bool = False,
    *args: Any,
    **kwargs: Any,
) -> dash.Dash:
    """Method to visualize pipeline runs via the Dash library.

    The layout puts every layer of the dag in a column.

    Args:
        object: The pipeline run to visualize.
        magic: If True and running inside a notebook, the app is created
            with `JupyterDash` and served inline. Outside a notebook a
            warning is emitted and a regular Dash app is used instead.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        The Dash application.
    """
    external_stylesheets = [
        dbc.themes.BOOTSTRAP,
        dbc.icons.BOOTSTRAP,
    ]

    use_jupyter = magic and Environment.in_notebook()
    if magic and not use_jupyter:
        cli_utils.warning(
            "Cannot set magic flag in non-notebook environments."
        )

    # Both branches bind `app` and `mode`, so the rest of the function
    # never sees an unbound name (the non-notebook magic case used to).
    if use_jupyter:
        # Only import jupyter_dash in this case
        from jupyter_dash import JupyterDash  # noqa

        JupyterDash.infer_jupyter_proxy_config()

        app = JupyterDash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = "inline"
    else:
        app = dash.Dash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = None

    graph = LineageGraph()
    graph.generate_run_nodes_and_edges(object)
    first_step_id = graph.root_step_id

    # Parse lineage graph nodes: flatten each node's `data` into the node
    # dict and attach status/shape CSS classes for the stylesheet.
    nodes = []
    for node in graph.nodes:
        node_dict = node.dict()
        node_data = node_dict.pop("data")
        node_dict = {**node_dict, **node_data}
        node_dict["label"] = node_dict["name"]
        classes = self.STATUS_CLASS_MAPPING[node.data.status]
        if isinstance(node, ArtifactNode):
            classes = "rectangle " + classes
            node_dict["label"] += f" ({node_dict['artifact_data_type']})"
        dash_node = {"data": node_dict, "classes": classes}
        nodes.append(dash_node)

    # Parse lineage graph edges. Each edge links a step and an artifact;
    # if the source is a step, the artifact is that step's output.
    node_mapping = {node.id: node for node in graph.nodes}
    edges = []
    for edge in graph.edges:
        source_node = node_mapping[edge.source]
        target_node = node_mapping[edge.target]
        is_input_artifact = not isinstance(source_node, StepNode)
        if is_input_artifact:
            step_node, artifact_node = target_node, source_node
        else:
            step_node, artifact_node = source_node, target_node
        assert isinstance(artifact_node, ArtifactNode)
        artifact_is_cached = artifact_node.data.is_cached
        # An input edge from a cached artifact is colored as cached,
        # regardless of the consuming step's status.
        if is_input_artifact and artifact_is_cached:
            edge_status = self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]
        else:
            edge_status = self.STATUS_CLASS_MAPPING[step_node.data.status]
        edge_style = "dashed" if artifact_is_cached else "solid"
        edges.append(
            {
                "data": edge.dict(),
                "classes": f"edge-arrow {edge_status} {edge_style}",
            }
        )

    app.layout = dbc.Row(
        [
            dbc.Container(f"Run: {object.name}", class_name="h2"),
            *[
                dbc.Container(f"- {k}: {v} ({type_})" + "\n\n")
                for k, v, type_ in graph.run_metadata
            ],
            dbc.Row(
                [
                    dbc.Col(
                        [
                            dbc.Row(
                                [
                                    html.Span(
                                        [
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-circle-fill me-1"
                                                    ),
                                                    "Step",
                                                ],
                                                className="me-2",
                                            ),
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-square-fill me-1"
                                                    ),
                                                    "Artifact",
                                                ],
                                                className="me-4",
                                            ),
                                            dbc.Badge(
                                                "Completed",
                                                color=COLOR_BLUE,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Cached",
                                                color=COLOR_GREEN,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Running",
                                                color=COLOR_YELLOW,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Failed",
                                                color=COLOR_RED,
                                                className="me-1",
                                            ),
                                        ]
                                    ),
                                ]
                            ),
                            dbc.Row(
                                [
                                    cyto.Cytoscape(
                                        id="cytoscape",
                                        layout={
                                            "name": "breadthfirst",
                                            "roots": f'[id = "{first_step_id}"]',
                                        },
                                        elements=edges + nodes,
                                        stylesheet=STYLESHEET,
                                        style={
                                            "width": "100%",
                                            "height": "800px",
                                        },
                                        zoom=1,
                                    )
                                ]
                            ),
                            dbc.Row(
                                [
                                    dbc.Button(
                                        "Reset",
                                        id="bt-reset",
                                        color="primary",
                                        className="me-1",
                                    )
                                ]
                            ),
                        ]
                    ),
                    dbc.Col(
                        [
                            dcc.Markdown(id="markdown-selected-node-data"),
                        ]
                    ),
                ]
            ),
        ],
        className="p-5",
    )

    # Attribute keys shown for each node type in the details panel.
    artifact_attributes = (
        "execution_id",
        "status",
        "artifact_data_type",
        "producer_step_id",
        "parent_step_id",
        "uri",
    )
    step_attributes = (
        "execution_id",
        "status",
    )

    @app.callback(  # type: ignore[misc]
        Output("markdown-selected-node-data", "children"),
        Input("cytoscape", "selectedNodeData"),
    )
    def display_data(data_list: List[Dict[str, Any]]) -> str:
        """Callback for the Markdown panel describing selected nodes.

        Args:
            data_list: The data dicts of the selected nodes, or None when
                nothing has been selected yet.

        Returns:
            str: Markdown text describing the selected nodes.
        """
        if data_list is None:
            return "Click on a node in the diagram."

        # Every entry is emitted as its own Markdown paragraph.
        paragraphs: List[str] = []

        def add_attributes(data: Dict[str, Any], keys: Any) -> None:
            """Append the attribute header and the listed attributes."""
            paragraphs.append("#### Attributes:")
            paragraphs.extend(f"**{item}**: {data[item]}" for item in keys)

        def add_mapping(title: str, mapping: Any) -> None:
            """Append a titled key/value section if `mapping` is non-empty."""
            if mapping:
                paragraphs.append(f"#### {title}:")
                paragraphs.extend(f"**{k}**: {v}" for k, v in mapping.items())

        def add_metadata(metadata: Any) -> None:
            """Append (key, value, type) metadata triples if any exist."""
            if metadata:
                paragraphs.append("#### Metadata:")
                paragraphs.extend(
                    f"**{k}**: {v} ({type_})" for k, v, type_ in metadata
                )

        for data in data_list:
            if data["type"] == "artifact":
                paragraphs.append(f"### Artifact '{data['name']}'")
                add_attributes(data, artifact_attributes)
                add_metadata(data["metadata"])
            elif data["type"] == "step":
                paragraphs.append(f"### Step '{data['name']}'")
                add_attributes(data, step_attributes)
                add_mapping("Inputs", data["inputs"])
                add_mapping("Outputs", data["outputs"])
                add_mapping("Parameters", data["parameters"])
                add_mapping("Configuration", data["configuration"])
                add_metadata(data["metadata"])
        return "".join(paragraph + "\n\n" for paragraph in paragraphs)

    @app.callback(  # type: ignore[misc]
        [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
        [Input("bt-reset", "n_clicks")],
    )
    def reset_layout(
        n_clicks: int,
    ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
        """Resets the zoom level and the graph elements.

        Args:
            n_clicks: The number of clicks on the reset button.

        Returns:
            The zoom and the elements.
        """
        # The message must be the first argument; `n_clicks` is a
        # %-style argument (passing it first broke log formatting).
        logger.debug("Reset button clicked (n_clicks=%s).", n_clicks)
        return [1, edges + nodes]

    # Start exactly one server: inline for JupyterDash, default otherwise.
    if mode is not None:
        app.run_server(mode=mode)
    else:
        app.run_server()
    return app