Wandb
zenml.integrations.wandb
special
Initialization for the wandb integration.
The wandb integration currently enables you to use wandb tracking as a convenient way to visualize your experiment runs within the wandb UI.
WandbIntegration (Integration)
Definition of the Weights & Biases (wandb) integration for ZenML.
Source code in zenml/integrations/wandb/__init__.py
class WandbIntegration(Integration):
    """Definition of the Weights & Biases (wandb) integration for ZenML.

    Declares the integration name, the pip requirements it needs, and the
    stack component flavors it contributes.
    """

    NAME = WANDB
    REQUIREMENTS = ["wandb>=0.12.12", "Pillow>=9.1.0"]
    # NOTE(review): going by the attribute name, Pillow is presumably left
    # installed when the integration is uninstalled — confirm against the
    # `Integration` base class.
    REQUIREMENTS_IGNORED_ON_UNINSTALL = ["Pillow"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the Weights and Biases integration.

        Returns:
            List of stack component flavors for this integration.
        """
        # The flavor is imported inside the method rather than at module
        # import time.
        from zenml.integrations.wandb.flavors import (
            WandbExperimentTrackerFlavor,
        )

        return [WandbExperimentTrackerFlavor]
flavors()
classmethod
Declare the stack component flavors for the Weights and Biases integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/wandb/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Return the flavors contributed by the Weights and Biases integration.

    Returns:
        A single-element list containing the wandb experiment tracker
        flavor class.
    """
    # Function-scope import, as in the original implementation.
    from zenml.integrations.wandb.flavors import (
        WandbExperimentTrackerFlavor as tracker_flavor,
    )

    contributed: List[Type[Flavor]] = [tracker_flavor]
    return contributed
experiment_trackers
special
Initialization for the wandb experiment tracker.
wandb_experiment_tracker
Implementation for the wandb experiment tracker.
WandbExperimentTracker (BaseExperimentTracker)
Track experiments using Wandb.
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
class WandbExperimentTracker(BaseExperimentTracker):
    """Track experiments using Wandb.

    Before a step runs, the configured API key is exported to the process
    environment and a wandb run is initialized. After the step, the run is
    finished and the API key is removed from the environment again.
    """

    @property
    def config(self) -> WandbExperimentTrackerConfig:
        """Returns the `WandbExperimentTrackerConfig` config.

        Returns:
            The configuration.
        """
        return cast(WandbExperimentTrackerConfig, self._config)

    @property
    def settings_class(self) -> Type[WandbExperimentTrackerSettings]:
        """Settings class for the Wandb experiment tracker.

        Returns:
            The settings class.
        """
        return WandbExperimentTrackerSettings

    def prepare_step_run(self, info: "StepRunInfo") -> None:
        """Configures a Wandb run.

        Exports the API key to the environment, then initializes a wandb run
        whose tags are the configured tags plus the run name and the
        pipeline name.

        Args:
            info: Info about the step that will be executed.
        """
        os.environ[WANDB_API_KEY] = self.config.api_key
        settings = cast(
            WandbExperimentTrackerSettings, self.get_settings(info)
        )

        tags = settings.tags + [info.run_name, info.pipeline.name]
        # An explicitly configured run name wins over the generated
        # "<run name>_<step name>" default.
        wandb_run_name = (
            settings.run_name or f"{info.run_name}_{info.pipeline_step_name}"
        )
        self._initialize_wandb(
            run_name=wandb_run_name, tags=tags, settings=settings.settings
        )

    def get_step_run_metadata(
        self, info: "StepRunInfo"
    ) -> Dict[str, "MetadataType"]:
        """Get component- and step-specific metadata after a step ran.

        Args:
            info: Info about the step that was executed.

        Returns:
            A dictionary of metadata holding the experiment tracker URL and
            the wandb run name.
        """
        run_url: Optional[str] = None
        run_name: Optional[str] = None

        # Try to get the run name and URL from WandB directly
        current_wandb_run = wandb.run
        if current_wandb_run:
            run_url = current_wandb_run.get_url()
            run_name = current_wandb_run.name

        # If the URL cannot be retrieved, use the default run URL
        default_run_url = (
            f"https://wandb.ai/{self.config.entity}/"
            f"{self.config.project_name}/runs/"
        )
        run_url = run_url or default_run_url

        # If the run name cannot be retrieved, use the default run name
        default_run_name = f"{info.run_name}_{info.pipeline_step_name}"
        settings = cast(
            WandbExperimentTrackerSettings, self.get_settings(info)
        )
        run_name = run_name or settings.run_name or default_run_name

        return {
            METADATA_EXPERIMENT_TRACKER_URL: Uri(run_url),
            "wandb_run_name": run_name,
        }

    def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None:
        """Stops the Wandb run.

        The API key is removed from the environment even if finishing the
        run raises; the exception still propagates to the caller.

        Args:
            info: Info about the step that was executed.
            step_failed: Whether the step failed or not.
        """
        try:
            if step_failed:
                # Mark the wandb run as failed.
                wandb.finish(exit_code=1)
            else:
                wandb.finish()
        finally:
            # Always drop the key that `prepare_step_run` exported.
            os.environ.pop(WANDB_API_KEY, None)

    def _initialize_wandb(
        self,
        run_name: str,
        tags: List[str],
        settings: Union["Settings", Dict[str, Any], None] = None,
    ) -> None:
        """Initializes a wandb run.

        Logs the target entity, project and run name, then calls
        `wandb.init` with the configured entity and project.

        Args:
            run_name: Name of the wandb run to create.
            tags: Tags to attach to the wandb run.
            settings: Additional settings for the wandb run.
        """
        logger.info(
            f"Initializing wandb with entity {self.config.entity}, project "
            f"name: {self.config.project_name}, run_name: {run_name}."
        )
        wandb.init(
            entity=self.config.entity,
            project=self.config.project_name,
            name=run_name,
            tags=tags,
            settings=settings,
        )
config: WandbExperimentTrackerConfig
property
readonly
Returns the WandbExperimentTrackerConfig
config.
Returns:
Type | Description |
---|---|
WandbExperimentTrackerConfig |
The configuration. |
settings_class: Type[zenml.integrations.wandb.flavors.wandb_experiment_tracker_flavor.WandbExperimentTrackerSettings]
property
readonly
Settings class for the Wandb experiment tracker.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.wandb.flavors.wandb_experiment_tracker_flavor.WandbExperimentTrackerSettings] |
The settings class. |
cleanup_step_run(self, info, step_failed)
Stops the Wandb run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Info about the step that was executed. |
required |
step_failed |
bool |
Whether the step failed or not. |
required |
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None:
    """Stops the Wandb run.

    The API key is removed from the environment even if finishing the run
    raises; the exception still propagates to the caller.

    Args:
        info: Info about the step that was executed.
        step_failed: Whether the step failed or not.
    """
    try:
        if step_failed:
            # Mark the wandb run as failed.
            wandb.finish(exit_code=1)
        else:
            wandb.finish()
    finally:
        # Always drop the key from the process environment.
        os.environ.pop(WANDB_API_KEY, None)
get_step_run_metadata(self, info)
Get component- and step-specific metadata after a step ran.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Info about the step that was executed. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
def get_step_run_metadata(
    self, info: "StepRunInfo"
) -> Dict[str, "MetadataType"]:
    """Collect wandb-related metadata once a step has finished.

    Args:
        info: Info about the step that was executed.

    Returns:
        A dictionary with the experiment tracker URL and the wandb run name.
    """
    # Prefer what the active wandb run reports, when there is one.
    live_run = wandb.run
    if live_run:
        reported_url: Optional[str] = live_run.get_url()
        reported_name: Optional[str] = live_run.name
    else:
        reported_url = None
        reported_name = None

    # Fallback URL built from the configured entity and project.
    fallback_url = (
        f"https://wandb.ai/{self.config.entity}/"
        f"{self.config.project_name}/runs/"
    )

    # Fallback name: the configured run name, else "<run name>_<step name>".
    fallback_name = f"{info.run_name}_{info.pipeline_step_name}"
    tracker_settings = cast(
        WandbExperimentTrackerSettings, self.get_settings(info)
    )

    metadata: Dict[str, "MetadataType"] = {}
    metadata[METADATA_EXPERIMENT_TRACKER_URL] = Uri(
        reported_url or fallback_url
    )
    metadata["wandb_run_name"] = (
        reported_name or tracker_settings.run_name or fallback_name
    )
    return metadata
prepare_step_run(self, info)
Configures a Wandb run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
info |
StepRunInfo |
Info about the step that will be executed. |
required |
Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
def prepare_step_run(self, info: "StepRunInfo") -> None:
    """Set up a Wandb run for the step that is about to execute.

    Exports the configured API key to the environment and initializes a
    wandb run with the resolved run name, tags and settings.

    Args:
        info: Info about the step that will be executed.
    """
    os.environ[WANDB_API_KEY] = self.config.api_key

    step_settings = cast(
        WandbExperimentTrackerSettings, self.get_settings(info)
    )

    # Configured tags first, then the run name and the pipeline name.
    run_tags = [*step_settings.tags, info.run_name, info.pipeline.name]

    # Use the configured run name, else fall back to a generated one.
    resolved_name = step_settings.run_name
    if not resolved_name:
        resolved_name = f"{info.run_name}_{info.pipeline_step_name}"

    self._initialize_wandb(
        run_name=resolved_name,
        tags=run_tags,
        settings=step_settings.settings,
    )
flavors
special
Weights & Biases integration flavors.
wandb_experiment_tracker_flavor
Weights & Biases experiment tracker flavor.
WandbExperimentTrackerConfig (BaseExperimentTrackerConfig, WandbExperimentTrackerSettings)
Config for the Wandb experiment tracker.
Attributes:
Name | Type | Description |
---|---|---|
entity |
Optional[str] |
Name of an existing wandb entity. |
project_name |
Optional[str] |
Name of an existing wandb project to log to. |
api_key |
str |
API key that should be authorized to log to the configured wandb entity and project. |
Source code in zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py
class WandbExperimentTrackerConfig(
    BaseExperimentTrackerConfig, WandbExperimentTrackerSettings
):
    """Config for the Wandb experiment tracker.

    Attributes:
        entity: Name of an existing wandb entity.
        project_name: Name of an existing wandb project to log to.
        api_key: API key that should be authorized to log to the configured
            wandb entity and project.
    """

    # Declared through `SecretField()`, not as a plain string default.
    api_key: str = SecretField()
    entity: Optional[str] = None
    project_name: Optional[str] = None
WandbExperimentTrackerFlavor (BaseExperimentTrackerFlavor)
Flavor for the Wandb experiment tracker.
Source code in zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py
class WandbExperimentTrackerFlavor(BaseExperimentTrackerFlavor):
    """Flavor describing the Wandb experiment tracker stack component."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The name of the flavor.
        """
        flavor_name = WANDB_EXPERIMENT_TRACKER_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation for this flavor.

        Returns:
            A flavor docs url.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Link to the logo shown for this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/wandb.png"
        return logo

    @property
    def config_class(self) -> Type[WandbExperimentTrackerConfig]:
        """Config class associated with this flavor.

        Returns:
            The config class.
        """
        return WandbExperimentTrackerConfig

    @property
    def implementation_class(self) -> Type["WandbExperimentTracker"]:
        """Class implementing the experiment tracker for this flavor.

        Returns:
            The implementation class.
        """
        # Function-scope import, as in the original implementation.
        from zenml.integrations.wandb.experiment_trackers import (
            WandbExperimentTracker as tracker_cls,
        )

        return tracker_cls
config_class: Type[zenml.integrations.wandb.flavors.wandb_experiment_tracker_flavor.WandbExperimentTrackerConfig]
property
readonly
Returns WandbExperimentTrackerConfig
config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.wandb.flavors.wandb_experiment_tracker_flavor.WandbExperimentTrackerConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A URL to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[WandbExperimentTracker]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[WandbExperimentTracker] |
The implementation class. |
logo_url: str
property
readonly
A URL to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A URL to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
WandbExperimentTrackerSettings (BaseSettings)
Settings for the Wandb experiment tracker.
Attributes:
Name | Type | Description |
---|---|---|
run_name |
Optional[str] |
The Wandb run name. |
tags |
List[str] |
Tags for the Wandb run. |
settings |
Dict[str, Any] |
Settings for the Wandb run. |
Source code in zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py
class WandbExperimentTrackerSettings(BaseSettings):
    """Per-run settings accepted by the Wandb experiment tracker.

    Attributes:
        run_name: The Wandb run name.
        tags: Tags for the Wandb run.
        settings: Settings for the Wandb run.
    """

    run_name: Optional[str] = None
    tags: List[str] = []
    settings: Dict[str, Any] = {}

    @field_validator("settings", mode="before")
    @classmethod
    def _convert_settings(cls, value: Any) -> Any:
        """Turn a `wandb.Settings` object into a plain dictionary.

        Values that are not `wandb.Settings` instances are passed through
        untouched.

        Args:
            value: The settings.

        Raises:
            ValueError: If converting the settings failed.

        Returns:
            Dict representation of the settings.
        """
        import wandb

        # Anything that is not a wandb settings object is returned as-is.
        if not isinstance(value, wandb.Settings):
            return value

        # Depending on the wandb version, either `model_dump`,
        # `make_static` or `to_dict` is available to convert the settings
        # to a dictionary; they are tried in that order.
        if isinstance(value, BaseModel):
            return value.model_dump()
        if hasattr(value, "make_static"):
            return cast(Dict[str, Any], value.make_static())
        if hasattr(value, "to_dict"):
            return value.to_dict()

        raise ValueError("Unable to convert wandb settings to dict.")