Tensorboard

zenml.integrations.tensorboard

Initialization for TensorBoard integration.

Attributes

TENSORBOARD = 'tensorboard' module-attribute

Classes

Integration

Base class for integration in ZenML.

Functions
activate() -> None classmethod

Abstract method to activate the integration.

Source code in src/zenml/integrations/integration.py
@classmethod
def activate(cls) -> None:
    """Abstract method to activate the integration."""
check_installation() -> bool classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in src/zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    for r in cls.get_requirements():
        try:
            # First check if the base package is installed
            dist = pkg_resources.get_distribution(r)

            # Next, check if the dependencies (including extras) are
            # installed
            deps: List[Requirement] = []

            _, extras = parse_requirement(r)
            if extras:
                extra_list = extras[1:-1].split(",")
                for extra in extra_list:
                    try:
                        requirements = dist.requires(extras=[extra])  # type: ignore[arg-type]
                    except pkg_resources.UnknownExtra as e:
                        logger.debug(f"Unknown extra: {str(e)}")
                        return False
                    deps.extend(requirements)
            else:
                deps = dist.requires()

            for ri in deps:
                try:
                    # Remove the "extra == ..." part from the requirement string
                    cleaned_req = re.sub(
                        r"; extra == \"\w+\"", "", str(ri)
                    )
                    pkg_resources.get_distribution(cleaned_req)
                except pkg_resources.DistributionNotFound as e:
                    logger.debug(
                        f"Unable to find required dependency "
                        f"'{e.req}' for requirement '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False
                except pkg_resources.VersionConflict as e:
                    logger.debug(
                        f"Package version '{e.dist}' does not match "
                        f"version '{e.req}' required by '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False

        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"Package version '{e.dist}' does not match version "
                f"'{e.req}' necessary for integration {cls.NAME}."
            )
            return False

    logger.debug(
        f"Integration {cls.NAME} is installed correctly with "
        f"requirements {cls.get_requirements()}."
    )
    return True
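
For example, the check can be invoked directly on any integration class. A minimal sketch using the TensorBoard integration documented below:

from zenml.integrations.tensorboard import TensorBoardIntegration

# True only if every requirement (including extras) resolves to an
# installed, version-compatible distribution
if TensorBoardIntegration.check_installation():
    print("TensorBoard integration requirements are satisfied.")
else:
    print("Some requirements are missing or have version conflicts.")
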
flavors() -> List[Type[Flavor]] classmethod

Abstract method to declare new stack component flavors.

Returns:

Type Description
List[Type[Flavor]]

A list of new stack component flavors.

Source code in src/zenml/integrations/integration.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Abstract method to declare new stack component flavors.

    Returns:
        A list of new stack component flavors.
    """
    return []
get_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Method to get the requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
@classmethod
def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
    """Method to get the requirements for the integration.

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A list of requirements.
    """
    return cls.REQUIREMENTS
get_uninstall_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Method to get the uninstall requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
@classmethod
def get_uninstall_requirements(
    cls, target_os: Optional[str] = None
) -> List[str]:
    """Method to get the uninstall requirements for the integration.

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A list of requirements.
    """
    ret = []
    for each in cls.get_requirements(target_os=target_os):
        is_ignored = False
        for ignored in cls.REQUIREMENTS_IGNORED_ON_UNINSTALL:
            if each.startswith(ignored):
                is_ignored = True
                break
        if not is_ignored:
            ret.append(each)
    return ret
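
A short sketch contrasting the two requirement lists; they are identical unless the integration lists packages in REQUIREMENTS_IGNORED_ON_UNINSTALL (for example, packages shared with other integrations):

from zenml.integrations.tensorboard import TensorBoardIntegration

# Everything needed to install the integration
print(TensorBoardIntegration.get_requirements())

# The subset that is safe to remove again on uninstall
print(TensorBoardIntegration.get_uninstall_requirements())
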
plugin_flavors() -> List[Type[BasePluginFlavor]] classmethod

Abstract method to declare new plugin flavors.

Returns:

Type Description
List[Type[BasePluginFlavor]]

A list of new plugin flavors.

Source code in src/zenml/integrations/integration.py
183
184
185
186
187
188
189
190
@classmethod
def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]:
    """Abstract method to declare new plugin flavors.

    Returns:
        A list of new plugin flavors.
    """
    return []

TensorBoardIntegration

Bases: Integration

Definition of TensorBoard integration for ZenML.

Functions
activate() -> None classmethod

Activates the integration.

Source code in src/zenml/integrations/tensorboard/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.tensorboard import services  # noqa
get_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Defines platform-specific requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/tensorboard/__init__.py
@classmethod
def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
    """Defines platform specific requirements for the integration.

    Args:
        target_os: The target operating system.

    Returns:
        A list of requirements.
    """
    requirements = ["tensorboard>=2.12,<2.15"]
    return requirements
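
In practice the pinned requirement is usually installed through the ZenML CLI (zenml integration install tensorboard) rather than by hand; to inspect the pin from Python:

from zenml.integrations.tensorboard import TensorBoardIntegration

for requirement in TensorBoardIntegration.get_requirements():
    print(requirement)  # tensorboard>=2.12,<2.15
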

Modules

services

Initialization for TensorBoard services.

Modules
tensorboard_service

Implementation of the TensorBoard service.

Classes
TensorboardService(config: Union[TensorboardServiceConfig, Dict[str, Any]], **attrs: Any)

Bases: LocalDaemonService

TensorBoard service.

This can be used to start a local TensorBoard server for one or more models.

Attributes:

Name Type Description
SERVICE_TYPE

a service type descriptor with information describing the TensorBoard service class

config TensorboardServiceConfig

service configuration

endpoint LocalDaemonServiceEndpoint

optional service endpoint

Initialization for TensorBoard service.

Parameters:

Name Type Description Default
config Union[TensorboardServiceConfig, Dict[str, Any]]

service configuration

required
**attrs Any

additional attributes

{}
Source code in src/zenml/integrations/tensorboard/services/tensorboard_service.py
def __init__(
    self,
    config: Union[TensorboardServiceConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Initialization for TensorBoard service.

    Args:
        config: service configuration
        **attrs: additional attributes
    """
    # ensure that the endpoint is created before the service is initialized
    # TODO [ENG-697]: implement a service factory or builder for TensorBoard
    #   deployment services
    if (
        isinstance(config, TensorboardServiceConfig)
        and "endpoint" not in attrs
    ):
        endpoint = LocalDaemonServiceEndpoint(
            config=LocalDaemonServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
            ),
            monitor=HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path="",
                    use_head_request=True,
                )
            ),
        )
        attrs["endpoint"] = endpoint
    if "uuid" not in attrs:
        attrs["uuid"] = uuid.uuid4()
    super().__init__(config=config, **attrs)
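
A usage sketch mirroring how the visualizer further below constructs and starts the service; the logdir path is hypothetical:

from zenml.integrations.tensorboard.services.tensorboard_service import (
    TensorboardService,
    TensorboardServiceConfig,
)

service = TensorboardService(
    TensorboardServiceConfig(
        logdir="/tmp/tensorboard-logs",  # hypothetical log directory
        name="zenml-tensorboard-example",
    )
)
service.start(timeout=60)  # daemonizes the server (not supported on Windows)
if service.endpoint.status.port:
    print(f"TensorBoard listening on port {service.endpoint.status.port}")
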
Functions
run() -> None

Initialize and run the TensorBoard server.

Source code in src/zenml/integrations/tensorboard/services/tensorboard_service.py
def run(self) -> None:
    """Initialize and run the TensorBoard server."""
    logger.info(
        "Starting TensorBoard service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()

    try:
        tensorboard = program.TensorBoard(
            plugins=default.get_plugins(),
        )
        tensorboard.configure(
            logdir=self.config.logdir,
            port=self.endpoint.status.port,
            host="localhost",
            max_reload_threads=self.config.max_reload_threads,
            reload_interval=self.config.reload_interval,
        )
        tensorboard.main()
    except KeyboardInterrupt:
        logger.info(
            "TensorBoard service stopped. Resuming normal execution."
        )
TensorboardServiceConfig(**data: Any)

Bases: LocalDaemonServiceConfig

TensorBoard service configuration.

Attributes:

Name Type Description
logdir str

location of TensorBoard log files.

max_reload_threads int

the max number of threads that TensorBoard can use to reload runs. Each thread reloads one run at a time.

reload_interval int

how often the backend should load more data, in seconds. Set to 0 to load just once at startup.

Source code in src/zenml/services/service.py
def __init__(self, **data: Any):
    """Initialize the service configuration.

    Args:
        **data: keyword arguments.

    Raises:
        ValueError: if neither 'name' nor 'model_name' is set.
    """
    super().__init__(**data)
    if self.name or self.model_name:
        self.service_name = data.get(
            "service_name",
            f"{ZENM_ENDPOINT_PREFIX}{self.name or self.model_name}",
        )
    else:
        raise ValueError("Either 'name' or 'model_name' must be set.")
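
A configuration sketch exercising the tuning attributes listed above; all values are illustrative:

from zenml.integrations.tensorboard.services.tensorboard_service import (
    TensorboardServiceConfig,
)

config = TensorboardServiceConfig(
    logdir="/tmp/tensorboard-logs",  # hypothetical log directory
    name="zenml-tensorboard-example",
    max_reload_threads=2,  # reload up to two runs in parallel
    reload_interval=30,  # re-scan the logdir every 30 seconds
)
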

visualizers

Initialization for TensorBoard visualizer.

Modules
tensorboard_visualizer

Implementation of a TensorBoard visualizer step.

Classes
TensorboardVisualizer

The implementation of a TensorBoard Visualizer.

Functions
find_running_tensorboard_server(logdir: str) -> Optional[TensorBoardInfo] classmethod

Find a local TensorBoard server instance.

Finds a server running for the supplied logdir location and returns its TCP port.

Parameters:

Name Type Description Default
logdir str

The logdir location where the TensorBoard server is running.

required

Returns:

Type Description
Optional[TensorBoardInfo]

The TensorBoardInfo describing the running TensorBoard server or

Optional[TensorBoardInfo]

None if no server is running for the supplied logdir location.

Source code in src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
@classmethod
def find_running_tensorboard_server(
    cls, logdir: str
) -> Optional[TensorBoardInfo]:
    """Find a local TensorBoard server instance.

    Finds a server running for the supplied logdir location and returns its
    TCP port.

    Args:
        logdir: The logdir location where the TensorBoard server is running.

    Returns:
        The TensorBoardInfo describing the running TensorBoard server or
        None if no server is running for the supplied logdir location.
    """
    for server in get_all():
        if (
            server.logdir == logdir
            and server.pid
            and psutil.pid_exists(server.pid)
        ):
            return server
    return None
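
A lookup sketch; the TensorBoardInfo result exposes at least the pid and port fields used elsewhere in this module, and the logdir path is hypothetical:

from zenml.integrations.tensorboard.visualizers.tensorboard_visualizer import (
    TensorboardVisualizer,
)

info = TensorboardVisualizer.find_running_tensorboard_server(
    "/tmp/tensorboard-logs"
)
if info:
    print(f"TensorBoard already serving on port {info.port} (PID {info.pid})")
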
stop(object: StepRunResponse) -> None

Stop the TensorBoard server previously started for a pipeline step.

Parameters:

Name Type Description Default
object StepRunResponse

StepRunResponseModel fetched from get_step().

required
Source code in src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def stop(
    self,
    object: "StepRunResponse",
) -> None:
    """Stop the TensorBoard server previously started for a pipeline step.

    Args:
        object: StepRunResponseModel fetched from get_step().
    """
    for output in object.outputs.values():
        for artifact_view in output:
            # filter out anything but model artifacts
            if artifact_view.type == ArtifactType.MODEL:
                logdir = os.path.dirname(artifact_view.uri)

                # check whether a TensorBoard server is running for the
                # same logdir location; if so, stop it
                running_server = self.find_running_tensorboard_server(
                    logdir
                )
                if not running_server:
                    return

                logger.debug(
                    "Stopping tensorboard server with PID '%d' ...",
                    running_server.pid,
                )
                try:
                    p = psutil.Process(running_server.pid)
                except psutil.Error:
                    logger.error(
                        "Could not find process for PID '%d' ...",
                        running_server.pid,
                    )
                    continue
                p.kill()
                return
visualize(object: StepRunResponse, height: int = 800, *args: Any, **kwargs: Any) -> None

Start a TensorBoard server.

Allows for the visualization of all models logged as artifacts by the indicated step. The server will monitor and display all the models logged by past and future step runs.

Parameters:

Name Type Description Default
object StepRunResponse

StepRunResponseModel fetched from get_step().

required
height int

Height of the generated visualization.

800
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def visualize(
    self,
    object: "StepRunResponse",
    height: int = 800,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Start a TensorBoard server.

    Allows for the visualization of all models logged as artifacts by the
    indicated step. The server will monitor and display all the models
    logged by past and future step runs.

    Args:
        object: StepRunResponseModel fetched from get_step().
        height: Height of the generated visualization.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for output in object.outputs.values():
        for artifact_view in output:
            # filter out anything but model artifacts
            if artifact_view.type == ArtifactType.MODEL:
                logdir = os.path.dirname(artifact_view.uri)

                # first check if a TensorBoard server is already running for
                # the same logdir location and use that one
                running_server = self.find_running_tensorboard_server(
                    logdir
                )
                if running_server:
                    self.visualize_tensorboard(running_server.port, height)
                    return

                if sys.platform == "win32":
                    # Daemon service functionality is currently not supported
                    # on Windows
                    print(
                        "You can run:\n"
                        f"[italic green]    tensorboard --logdir {logdir}"
                        "[/italic green]\n"
                        "...to visualize the TensorBoard logs for your trained model."
                    )
                else:
                    # start a new TensorBoard server
                    service = TensorboardService(
                        TensorboardServiceConfig(
                            logdir=logdir,
                            name=f"zenml-tensorboard-{logdir}",
                        )
                    )
                    service.start(timeout=60)
                    if service.endpoint.status.port:
                        self.visualize_tensorboard(
                            service.endpoint.status.port, height
                        )
                return
visualize_tensorboard(port: int, height: int) -> None

Generate a visualization of a TensorBoard.

Parameters:

Name Type Description Default
port int

the TCP port where the TensorBoard server is listening for requests.

required
height int

Height of the generated visualization.

required
Source code in src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(
    self,
    port: int,
    height: int,
) -> None:
    """Generate a visualization of a TensorBoard.

    Args:
        port: the TCP port where the TensorBoard server is listening for
            requests.
        height: Height of the generated visualization.
    """
    if Environment.in_notebook():
        notebook.display(port, height=height)
        return

    print(
        "You can visit:\n"
        f"[italic green]    http://localhost:{port}/[/italic green]\n"
        "...to visualize the TensorBoard logs for your trained model."
    )
Functions
get_step(pipeline_name: str, step_name: str) -> StepRunResponse

Get the StepRunResponseModel for the specified pipeline and step name.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required
step_name str

The name of the step.

required

Returns:

Type Description
StepRunResponse

The StepRunResponseModel for the specified pipeline and step name.

Raises:

Type Description
RuntimeError

If the step is not found.

Source code in src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def get_step(pipeline_name: str, step_name: str) -> "StepRunResponse":
    """Get the StepRunResponseModel for the specified pipeline and step name.

    Args:
        pipeline_name: The name of the pipeline.
        step_name: The name of the step.

    Returns:
        The StepRunResponseModel for the specified pipeline and step name.

    Raises:
        RuntimeError: If the step is not found.
    """
    runs = Client().list_pipeline_runs(pipeline=pipeline_name)
    if runs.total == 0:
        raise RuntimeError(
            f"No pipeline runs for pipeline `{pipeline_name}` were found"
        )

    last_run = runs[0]
    if step_name not in last_run.steps:
        raise RuntimeError(
            f"No pipeline step with name `{step_name}` was found in "
            f"pipeline `{pipeline_name}`"
        )
    step = last_run.steps[step_name]
    return step
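
A minimal lookup sketch; the pipeline and step names are hypothetical:

from zenml.integrations.tensorboard.visualizers.tensorboard_visualizer import (
    get_step,
)

step = get_step(pipeline_name="training_pipeline", step_name="trainer")
print(list(step.outputs))  # names of the step's output artifacts
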
stop_tensorboard_server(pipeline_name: str, step_name: str) -> None

Stop the TensorBoard server previously started for a pipeline step.

Parameters:

Name Type Description Default
pipeline_name str

the name of the pipeline

required
step_name str

pipeline step name

required
Source code in src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def stop_tensorboard_server(pipeline_name: str, step_name: str) -> None:
    """Stop the TensorBoard server previously started for a pipeline step.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    step = get_step(pipeline_name, step_name)
    TensorboardVisualizer().stop(step)
visualize_tensorboard(pipeline_name: str, step_name: str) -> None

Start a TensorBoard server.

Allows for the visualization of all models logged as output by the named pipeline step. The server will monitor and display all the models logged by past and future step runs.

Parameters:

Name Type Description Default
pipeline_name str

the name of the pipeline

required
step_name str

pipeline step name

required
Source code in src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(pipeline_name: str, step_name: str) -> None:
    """Start a TensorBoard server.

    Allows for the visualization of all models logged as output by the named
    pipeline step. The server will monitor and display all the models logged by
    past and future step runs.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    step = get_step(pipeline_name, step_name)
    TensorboardVisualizer().visualize(step)
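
Putting the two module-level helpers together; the pipeline and step names are hypothetical, and the imports assume the helpers are re-exported by the visualizers package (otherwise import them from tensorboard_visualizer directly):

from zenml.integrations.tensorboard.visualizers import (
    stop_tensorboard_server,
    visualize_tensorboard,
)

# Start (or attach to) a TensorBoard server for the step's model artifacts
visualize_tensorboard(pipeline_name="training_pipeline", step_name="trainer")

# ...later, shut the daemonized server down again
stop_tensorboard_server(pipeline_name="training_pipeline", step_name="trainer")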