Label Studio
zenml.integrations.label_studio
special
Initialization of the Label Studio integration.
LabelStudioIntegration (Integration)
Definition of Label Studio integration for ZenML.
Source code in zenml/integrations/label_studio/__init__.py
class LabelStudioIntegration(Integration):
    """ZenML integration that makes Label Studio available as an annotator."""

    NAME = LABEL_STUDIO
    REQUIREMENTS = ["label-studio==1.6.0", "label-studio-sdk==0.0.15"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """List the stack component flavors shipped by this integration.

        Returns:
            A one-element list holding the Label Studio annotator flavor.
        """
        # Imported inside the method rather than at module load time.
        # NOTE(review): presumably so the flavor module is only loaded when
        # flavors are actually requested - confirm.
        from zenml.integrations.label_studio.flavors import (
            LabelStudioAnnotatorFlavor as annotator_flavor,
        )

        provided: List[Type[Flavor]] = [annotator_flavor]
        return provided
flavors()
classmethod
Declare the stack component flavors for the Label Studio integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/label_studio/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """List the stack component flavors shipped by this integration.

    Returns:
        A one-element list holding the Label Studio annotator flavor.
    """
    # Imported inside the method rather than at module load time.
    # NOTE(review): presumably so the flavor module is only loaded when
    # flavors are actually requested - confirm.
    from zenml.integrations.label_studio.flavors import (
        LabelStudioAnnotatorFlavor as annotator_flavor,
    )

    provided: List[Type[Flavor]] = [annotator_flavor]
    return provided
annotators
special
Initialization of the Label Studio annotators submodule.
label_studio_annotator
Implementation of the Label Studio annotation integration.
LabelStudioAnnotator (BaseAnnotator, AuthenticationMixin)
Class to interact with the Label Studio annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
class LabelStudioAnnotator(BaseAnnotator, AuthenticationMixin):
"""Class to interact with the Label Studio annotation interface."""
@property
def config(self) -> LabelStudioAnnotatorConfig:
    """Expose the component configuration under its concrete type.

    Returns:
        `self._config`, typed as `LabelStudioAnnotatorConfig`.
    """
    # `cast` only informs the type checker; no conversion happens at runtime.
    typed_config = cast(LabelStudioAnnotatorConfig, self._config)
    return typed_config
@property
def validator(self) -> Optional["StackValidator"]:
    """Build the validator a stack has to pass to use this annotator.

    The validator requires a secrets manager component and an artifact
    store whose flavor is one of the Azure, GCP or S3 flavors.

    Returns:
        StackValidator: Validator for the stack.
    """
    supported_flavors = [
        AZURE_ARTIFACT_STORE_FLAVOR,
        GCP_ARTIFACT_STORE_FLAVOR,
        S3_ARTIFACT_STORE_FLAVOR,
    ]

    def _ensure_cloud_artifact_stores(stack: Stack) -> Tuple[bool, str]:
        # For now this only works on cloud artifact stores.
        is_cloud_store = stack.artifact_store.flavor in supported_flavors
        return (
            is_cloud_store,
            "Only cloud artifact stores are currently supported",
        )

    return StackValidator(
        custom_validation_function=_ensure_cloud_artifact_stores,
        required_components={StackComponentType.SECRETS_MANAGER},
    )
def get_url(self) -> str:
    """Compose the base URL of the annotation interface.

    Returns:
        The configured instance URL and port joined by a colon.
    """
    settings = self.config
    return "{}:{}".format(settings.instance_url, settings.port)
def get_url_for_dataset(self, dataset_name: str) -> str:
    """Compose the URL of a dataset's page in the annotation interface.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The base URL followed by `/projects/<id>/`, where `<id>` is whatever
        `get_id_from_name` returned for the dataset name.
    """
    dataset_id = self.get_id_from_name(dataset_name)
    base_url = self.get_url()
    return f"{base_url}/projects/{dataset_id}/"
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
    """Look up the ID of the dataset whose title equals the given name.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The ID of the first dataset with a matching title, or `None` when
        no dataset has that title.
    """
    for candidate in self.get_datasets():
        if candidate.get_params()["title"] != dataset_name:
            continue
        return cast(int, candidate.get_params()["id"])
    return None
def get_datasets(self) -> List[Any]:
    """Fetch the datasets known to the Label Studio client.

    Returns:
        The result of the client's `get_projects()` call.
    """
    return cast(List[Any], self._get_client().get_projects())
def get_dataset_names(self) -> List[str]:
    """Collect the titles of all available datasets.

    Returns:
        A list of dataset names, in the order `get_datasets` returns them.
    """
    names: List[str] = []
    for project in self.get_datasets():
        names.append(project.get_params()["title"])
    return names
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
    """Gets the statistics of the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        A tuple containing (labeled_task_count, unlabeled_task_count) for
        the dataset.

    Raises:
        IndexError: If the dataset does not exist.
    """
    for project in self.get_datasets():
        # Compare titles for equality, as `get_id_from_name` does. A
        # substring test would report the stats of "foobar" when asked
        # about "foo".
        if project.get_params()["title"] == dataset_name:
            labeled_task_count = len(project.get_labeled_tasks())
            unlabeled_task_count = len(project.get_unlabeled_tasks())
            return (labeled_task_count, unlabeled_task_count)
    raise IndexError(
        f"Dataset {dataset_name} not found. Please use "
        f"`zenml annotator dataset list` to list all available datasets."
    )
def launch(self, url: Optional[str] = None) -> None:
    """Launches the annotation interface in a web browser.

    Args:
        url: The URL of the annotation interface. When empty or `None`,
            the URL returned by `get_url()` is used.
    """
    if not url:
        url = self.get_url()
    if self._connection_available():
        webbrowser.open(url, new=1, autoraise=True)
    else:
        # The trailing space keeps the two literals from running together
        # ("interfacebecause").
        logger.warning(
            "Could not launch annotation interface "
            "because the connection could not be established."
        )
def _get_client(self) -> Client:
    """Gets a Label Studio client authenticated with the stored API key.

    Returns:
        Label Studio client pointing at `get_url()`.

    Raises:
        ValueError: when unable to access the Label Studio API key.
    """
    secret = self.get_authentication_secret(ArbitrarySecretSchema)
    if not secret:
        # `secret` is falsy here, so interpolating it would always print
        # 'None'. Report the configured secret name instead.
        # NOTE(review): assumes the config exposes `authentication_secret`
        # (as the authentication mixin's config does) - confirm. `getattr`
        # keeps the error path safe if it does not.
        secret_name = getattr(self.config, "authentication_secret", None)
        raise ValueError(
            f"Unable to access predefined secret '{secret_name}' to access Label Studio API key."
        )
    api_key = secret.content["api_key"]
    return Client(url=self.get_url(), api_key=api_key)
def _connection_available(self) -> bool:
    """Probe whether the annotation server can be reached.

    Returns:
        True if the client's connection check reports status "UP", False
        if it reports anything else or if any exception is raised.
    """
    try:
        status = self._get_client().check_connection().get("status")
        return status == "UP"  # type: ignore[no-any-return]
    # TODO: [HIGH] refactor to use a more specific exception
    except Exception:
        logger.error(
            "Connection error: No connection was able to be established to the Label Studio backend."
        )
        return False
def add_dataset(self, **kwargs: Any) -> Any:
    """Create a new Label Studio project for annotation.

    Args:
        **kwargs: Must contain `dataset_name` (used as the project title)
            and `label_config`; other keys are ignored.

    Returns:
        A Label Studio Project object.

    Raises:
        ValueError: if `dataset_name` or `label_config` is missing or empty.
    """
    dataset_name = kwargs.get("dataset_name")
    label_config = kwargs.get("label_config")

    # Guard clauses, checked in the same order as the arguments are read.
    if not dataset_name:
        raise ValueError("`dataset_name` keyword argument is required.")
    if not label_config:
        raise ValueError("`label_config` keyword argument is required.")

    client = self._get_client()
    return client.start_project(title=dataset_name, label_config=label_config)
def delete_dataset(self, **kwargs: Any) -> None:
    """Placeholder for deleting a dataset; currently always raises.

    Args:
        **kwargs: Accepted for interface compatibility; not used yet.

    Raises:
        NotImplementedError: Always, until dataset deletion is supported.
    """
    raise NotImplementedError("Awaiting Label Studio release.")
    # TODO: Awaiting a new Label Studio version to be released with this
    # method. Intended implementation once available:
    #
    #   ls = self._get_client()
    #   dataset_name = kwargs.get("dataset_name")
    #   if not dataset_name:
    #       raise ValueError("`dataset_name` keyword argument is required.")
    #   dataset_id = self.get_id_from_name(dataset_name)
    #   if not dataset_id:
    #       raise ValueError(
    #           f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
    #       )
    #   ls.delete_project(dataset_id)
def get_dataset(self, **kwargs: Any) -> Any:
    """Fetch the Label Studio project registered under a dataset name.

    Args:
        **kwargs: Must contain `dataset_name`; other keys are ignored.

    Returns:
        The LabelStudio Dataset object (a 'Project') for the given name.

    Raises:
        ValueError: If the dataset name is not provided or if no dataset
            ID can be resolved for it.
    """
    # TODO: check for and raise error if client unavailable
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if project_id:
        return self._get_client().get_project(project_id)

    raise ValueError(
        f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
    )
def get_converted_dataset(
    self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
    """Export a dataset's tasks in the requested format.

    Args:
        dataset_name: Name of the dataset (resolved through `get_dataset`).
        output_format: Export type handed to the project's `export_tasks`.

    Returns:
        A dictionary containing the converted dataset.
    """
    exported = self.get_dataset(dataset_name=dataset_name).export_tasks(
        export_type=output_format
    )
    return exported  # type: ignore[no-any-return]
def get_labeled_data(self, **kwargs: Any) -> Any:
    """Fetch the labeled tasks of a dataset.

    Args:
        **kwargs: Must contain `dataset_name`; other keys are ignored.

    Returns:
        The labeled data, as returned by the project's
        `get_labeled_tasks()`.

    Raises:
        ValueError: If the dataset name is not provided or if no dataset
            ID can be resolved for it.
    """
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    project = self._get_client().get_project(project_id)
    return project.get_labeled_tasks()
def get_unlabeled_data(self, **kwargs: Any) -> Any:
    """Gets the unlabeled data for the given dataset.

    Args:
        **kwargs: Must contain `dataset_name`; other keys are ignored.

    Returns:
        The unlabeled data, as returned by the project's
        `get_unlabeled_tasks()`.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # `get_dataset` performs the same name validation and ID lookup (and
    # raises the same `ValueError`s) that used to be duplicated here.
    project = self.get_dataset(dataset_name=kwargs.get("dataset_name"))
    return project.get_unlabeled_tasks()
def register_dataset_for_annotation(
    self,
    params: LabelStudioDatasetRegistrationParameters,
) -> Any:
    """Return the project for a dataset, creating it when it is missing.

    Args:
        params: Parameters for the dataset; `dataset_name` and
            `label_config` are read from it.

    Returns:
        A Label Studio Project object.
    """
    existing_id = self.get_id_from_name(params.dataset_name)
    if not existing_id:
        # No project with that title yet, so create one.
        return self.add_dataset(
            dataset_name=params.dataset_name,
            label_config=params.label_config,
        )
    return self._get_client().get_project(existing_id)
def _get_azure_import_storage_sources(
    self, dataset_id: int
) -> List[Dict[str, Any]]:
    """Query the Azure import storage sources attached to a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        The decoded JSON body of the response.

    Raises:
        ConnectionError: If the response status code is not 200.
    """
    # TODO: check if client actually is connected etc
    response = self._get_client().make_request(
        method="GET", url=f"/api/storages/azure?project={dataset_id}"
    )
    if response.status_code != 200:
        raise ConnectionError(
            f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
        )
    return cast(List[Dict[str, Any]], response.json())
def _get_gcs_import_storage_sources(
    self, dataset_id: int
) -> List[Dict[str, Any]]:
    """Query the Google Cloud Storage import sources attached to a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        The decoded JSON body of the response.

    Raises:
        ConnectionError: If the response status code is not 200.
    """
    # TODO: check if client actually is connected etc
    response = self._get_client().make_request(
        method="GET", url=f"/api/storages/gcs?project={dataset_id}"
    )
    if response.status_code != 200:
        raise ConnectionError(
            f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
        )
    return cast(List[Dict[str, Any]], response.json())
def _get_s3_import_storage_sources(
    self, dataset_id: int
) -> List[Dict[str, Any]]:
    """Query the AWS S3 import storage sources attached to a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        The decoded JSON body of the response.

    Raises:
        ConnectionError: If the response status code is not 200.
    """
    # TODO: check if client actually is connected etc
    response = self._get_client().make_request(
        method="GET", url=f"/api/storages/s3?project={dataset_id}"
    )
    if response.status_code != 200:
        raise ConnectionError(
            f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
        )
    return cast(List[Dict[str, Any]], response.json())
def _storage_source_already_exists(
    self,
    uri: str,
    params: LabelStudioDatasetSyncParameters,
    dataset: Project,
) -> bool:
    """Check whether an equivalent import storage source is registered.

    Args:
        uri: URI of the storage source; compared to each source's `bucket`.
        params: Parameters for the dataset.
        dataset: Label Studio dataset.

    Returns:
        True if any existing source of the requested storage type matches
        on every compared field, False otherwise.

    Raises:
        NotImplementedError: If the storage source type is not supported.
    """
    # TODO: check we are already connected
    dataset_id = int(dataset.get_params()["id"])

    storage_type = params.storage_type
    if storage_type == "azure":
        existing_sources = self._get_azure_import_storage_sources(dataset_id)
    elif storage_type == "gcs":
        existing_sources = self._get_gcs_import_storage_sources(dataset_id)
    elif storage_type == "s3":
        existing_sources = self._get_s3_import_storage_sources(dataset_id)
    else:
        raise NotImplementedError(
            f"Storage type '{params.storage_type}' not implemented."
        )

    # (source key, lazily evaluated expected value) pairs. Evaluated in
    # order, and `all` stops at the first mismatch for each source.
    field_checks = (
        ("presign", lambda: params.presign),
        ("bucket", lambda: uri),
        ("regex_filter", lambda: params.regex_filter),
        ("use_blob_urls", lambda: params.use_blob_urls),
        ("title", lambda: dataset.get_params()["title"]),
        ("description", lambda: params.description),
        ("presign_ttl", lambda: params.presign_ttl),
        ("project", lambda: dataset_id),
    )
    for source in existing_sources:
        if all(source.get(key) == expected() for key, expected in field_checks):
            return True
    return False
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
    """Fetch a dataset's parsed Label Studio label config.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        The `parsed_label_config` attribute of the project.

    Raises:
        ValueError: If the client returns a falsy project for the id.
    """
    # TODO: check if client actually is connected etc
    project = self._get_client().get_project(dataset_id)
    if not project:
        raise ValueError("No dataset found for the given id.")
    return cast(Dict[str, Any], project.parsed_label_config)
def connect_and_sync_external_storage(
    self,
    uri: str,
    params: LabelStudioDatasetSyncParameters,
    dataset: Project,
) -> Optional[Dict[str, Any]]:
    """Connects an external import storage to the project and syncs it.

    Args:
        uri: URI of the storage source (passed as the Azure container or
            as the GCS / S3 bucket).
        params: Parameters for the dataset; `storage_type` selects the
            backend and must be one of 'azure', 'gcs' or 's3'.
        dataset: Label Studio dataset.

    Returns:
        A dictionary containing the sync result.

    Raises:
        ValueError: If the storage type is not supported.
    """
    # TODO: check if proposed storage source has differing / new data
    # if self._storage_source_already_exists(uri, config, dataset):
    #     return None

    # Arguments shared by all three storage backends.
    storage_connection_args = {
        "prefix": params.prefix,
        "regex_filter": params.regex_filter,
        "use_blob_urls": params.use_blob_urls,
        "presign": params.presign,
        "presign_ttl": params.presign_ttl,
        "title": dataset.get_params()["title"],
        "description": params.description,
    }
    if params.storage_type == "azure":
        # Missing credentials only trigger a warning; the storage is still
        # connected so it can be completed in the Label Studio web UI.
        if not params.azure_account_name or not params.azure_account_key:
            logger.warning(
                "Authentication credentials for Azure aren't fully "
                "provided. Please update the storage synchronization "
                "settings in the Label Studio web UI as per your needs."
            )
        storage = dataset.connect_azure_import_storage(
            container=uri,
            account_name=params.azure_account_name,
            account_key=params.azure_account_key,
            **storage_connection_args,
        )
    elif params.storage_type == "gcs":
        if not params.google_application_credentials:
            logger.warning(
                "Authentication credentials for Google Cloud Storage "
                "aren't fully provided. Please update the storage "
                "synchronization settings in the Label Studio web UI as "
                "per your needs."
            )
        storage = dataset.connect_google_import_storage(
            bucket=uri,
            google_application_credentials=params.google_application_credentials,
            **storage_connection_args,
        )
    elif params.storage_type == "s3":
        if (
            not params.aws_access_key_id
            or not params.aws_secret_access_key
        ):
            # Spacing between the concatenated literals fixed (the message
            # used to read "provided.Please ... the  Label Studio").
            logger.warning(
                "Authentication credentials for S3 aren't fully provided. "
                "Please update the storage synchronization settings in the "
                "Label Studio web UI as per your needs."
            )
        storage = dataset.connect_s3_import_storage(
            bucket=uri,
            aws_access_key_id=params.aws_access_key_id,
            aws_secret_access_key=params.aws_secret_access_key,
            aws_session_token=params.aws_session_token,
            region_name=params.s3_region_name,
            s3_endpoint=params.s3_endpoint,
            **storage_connection_args,
        )
    else:
        # The accepted value is 's3' (see the branch above), not 'aws'.
        raise ValueError(
            f"Invalid storage type. '{params.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 's3'."
        )

    synced_storage = self._get_client().sync_storage(
        storage_id=storage["id"], storage_type=storage["type"]
    )
    return cast(Dict[str, Any], synced_storage)
@property
def root_directory(self) -> str:
    """Directory reserved for this annotator instance.

    Returns:
        `<global config directory>/annotators/<component id>`.
    """
    config_root = io_utils.get_global_config_directory()
    return os.path.join(config_root, "annotators", str(self.id))
@property
def _pid_file_path(self) -> str:
    """Location of the daemon PID file.

    Returns:
        Path to `label_studio_daemon.pid` inside the root directory.
    """
    pid_file_name = "label_studio_daemon.pid"
    return os.path.join(self.root_directory, pid_file_name)
@property
def _log_file(self) -> str:
    """Location of the daemon log file.

    Returns:
        Path to `label_studio_daemon.log` inside the root directory.
    """
    log_file_name = "label_studio_daemon.log"
    return os.path.join(self.root_directory, log_file_name)
@property
def is_provisioned(self) -> bool:
    """Whether local resources for the component have been provisioned.

    Returns:
        True if the component's root directory exists.
    """
    root = self.root_directory
    return fileio.exists(root)
@property
def is_running(self) -> bool:
    """Whether the annotation server counts as running.

    A non-local instance is always reported as running. For a local
    instance the daemon PID file is consulted, except on Windows.

    Returns:
        True if the component is running locally, False otherwise.
    """
    if not self.is_local_instance:
        return True

    if sys.platform == "win32":
        # Daemon functionality is not supported on Windows, so the PID
        # file won't exist. This platform check exists just for mypy to
        # not complain about missing functions.
        return True

    from zenml.utils.daemon import check_if_daemon_is_running

    return bool(check_if_daemon_is_running(self._pid_file_path))
@property
def is_local_instance(self) -> bool:
    """Whether the configured Label Studio instance is the local default.

    Returns:
        True if the configured instance URL equals
        `DEFAULT_LOCAL_INSTANCE_URL`, False otherwise.
    """
    configured_url = self.config.instance_url
    return configured_url == DEFAULT_LOCAL_INSTANCE_URL
def provision(self) -> None:
    """Create the component's root directory on disk."""
    target_directory = self.root_directory
    fileio.makedirs(target_directory)
def deprovision(self) -> None:
    """Remove the daemon log file if it exists."""
    if not fileio.exists(self._log_file):
        return
    fileio.remove(self._log_file)
def resume(self) -> None:
    """Start the local annotation daemon unless it is already running."""
    if self.is_running:
        logger.info("Local annotation deployment already running.")
    elif self.is_local_instance:
        self.start_annotator_daemon()
def suspend(self) -> None:
    """Stop the local annotation daemon if it is running."""
    if not self.is_running:
        logger.info("Local annotation server is not running.")
    elif self.is_local_instance:
        self.stop_annotator_daemon()
def start_annotator_daemon(self) -> None:
    """Starts the Label Studio server as a background daemon.

    On Windows no daemon is started; a warning with the command to run
    manually is logged instead.

    Raises:
        ProvisioningError: If the configured port is already occupied.
    """
    command = [
        "label-studio",
        "start",
        "--no-browser",
        "--port",
        f"{self.config.port}",
    ]

    if sys.platform == "win32":
        # The message has a single `%s` placeholder, so exactly one
        # argument (the command line) is passed. Passing the port as well
        # made the logging call fail to format.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Label Studio server locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.config.port):
        raise ProvisioningError(
            f"Unable to port-forward Label Studio to local "
            f"port {self.config.port} because the port is occupied. In order to "
            f"access Label Studio locally, please "
            f"change the configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Runs the Label Studio start command and waits for it."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        # Both values are passed as lazy %-args. Embedding the URL through
        # an f-string would let a literal '%' in it break the formatting.
        logger.info(
            "Started Label Studio daemon (check the daemon "
            "logs at `%s` in case you're not able to access the annotation "
            "interface). Please visit `%s/` to use the Label Studio interface.",
            self._log_file,
            self.get_url(),
        )
def stop_annotator_daemon(self) -> None:
    """Stop the daemon and delete its PID file, when a PID file exists."""
    if not fileio.exists(self._pid_file_path):
        return

    if sys.platform != "win32":
        from zenml.utils import daemon

        daemon.stop_daemon(self._pid_file_path)
        fileio.remove(self._pid_file_path)
    # else: Daemon functionality is not supported on Windows, so the PID
    # file won't exist. The platform check exists just for mypy to not
    # complain about missing functions.
config: LabelStudioAnnotatorConfig
property
readonly
Returns the LabelStudioAnnotatorConfig
config.
Returns:
Type | Description |
---|---|
LabelStudioAnnotatorConfig |
The configuration. |
is_local_instance: bool
property
readonly
Determines if the Label Studio instance is running locally.
Returns:
Type | Description |
---|---|
bool |
True if the component is running locally, False otherwise. |
is_provisioned: bool
property
readonly
If the component provisioned resources to run locally.
Returns:
Type | Description |
---|---|
bool |
True if the component provisioned resources to run locally. |
is_running: bool
property
readonly
If the component is running locally.
Returns:
Type | Description |
---|---|
bool |
True if the component is running locally, False otherwise. |
root_directory: str
property
readonly
Returns path to the root directory.
Returns:
Type | Description |
---|---|
str |
Path to the root directory. |
validator: Optional[StackValidator]
property
readonly
Validates that the stack contains a cloud artifact store.
Returns:
Type | Description |
---|---|
StackValidator |
Validator for the stack. |
add_dataset(self, **kwargs)
Registers a dataset for annotation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
A Label Studio Project object. |
Exceptions:
Type | Description |
---|---|
ValueError |
If 'dataset_name' or 'label_config' isn't provided. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def add_dataset(self, **kwargs: Any) -> Any:
    """Create a new Label Studio project for annotation.

    Args:
        **kwargs: Must contain `dataset_name` (used as the project title)
            and `label_config`; other keys are ignored.

    Returns:
        A Label Studio Project object.

    Raises:
        ValueError: if `dataset_name` or `label_config` is missing or empty.
    """
    dataset_name = kwargs.get("dataset_name")
    label_config = kwargs.get("label_config")

    # Guard clauses, checked in the same order as the arguments are read.
    if not dataset_name:
        raise ValueError("`dataset_name` keyword argument is required.")
    if not label_config:
        raise ValueError("`label_config` keyword argument is required.")

    client = self._get_client()
    return client.start_project(title=dataset_name, label_config=label_config)
connect_and_sync_external_storage(self, uri, params, dataset)
Syncs the external storage for the given project.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
URI of the storage source. |
required |
params |
LabelStudioDatasetSyncParameters |
Parameters for the dataset. |
required |
dataset |
Project |
Label Studio dataset. |
required |
Returns:
Type | Description |
---|---|
Optional[Dict[str, Any]] |
A dictionary containing the sync result. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the storage type is not supported. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def connect_and_sync_external_storage(
    self,
    uri: str,
    params: LabelStudioDatasetSyncParameters,
    dataset: Project,
) -> Optional[Dict[str, Any]]:
    """Connects an external import storage to the project and syncs it.

    Args:
        uri: URI of the storage source (passed as the Azure container or
            as the GCS / S3 bucket).
        params: Parameters for the dataset; `storage_type` selects the
            backend and must be one of 'azure', 'gcs' or 's3'.
        dataset: Label Studio dataset.

    Returns:
        A dictionary containing the sync result.

    Raises:
        ValueError: If the storage type is not supported.
    """
    # TODO: check if proposed storage source has differing / new data
    # if self._storage_source_already_exists(uri, config, dataset):
    #     return None

    # Arguments shared by all three storage backends.
    storage_connection_args = {
        "prefix": params.prefix,
        "regex_filter": params.regex_filter,
        "use_blob_urls": params.use_blob_urls,
        "presign": params.presign,
        "presign_ttl": params.presign_ttl,
        "title": dataset.get_params()["title"],
        "description": params.description,
    }
    if params.storage_type == "azure":
        # Missing credentials only trigger a warning; the storage is still
        # connected so it can be completed in the Label Studio web UI.
        if not params.azure_account_name or not params.azure_account_key:
            logger.warning(
                "Authentication credentials for Azure aren't fully "
                "provided. Please update the storage synchronization "
                "settings in the Label Studio web UI as per your needs."
            )
        storage = dataset.connect_azure_import_storage(
            container=uri,
            account_name=params.azure_account_name,
            account_key=params.azure_account_key,
            **storage_connection_args,
        )
    elif params.storage_type == "gcs":
        if not params.google_application_credentials:
            logger.warning(
                "Authentication credentials for Google Cloud Storage "
                "aren't fully provided. Please update the storage "
                "synchronization settings in the Label Studio web UI as "
                "per your needs."
            )
        storage = dataset.connect_google_import_storage(
            bucket=uri,
            google_application_credentials=params.google_application_credentials,
            **storage_connection_args,
        )
    elif params.storage_type == "s3":
        if (
            not params.aws_access_key_id
            or not params.aws_secret_access_key
        ):
            # Spacing between the concatenated literals fixed (the message
            # used to read "provided.Please ... the  Label Studio").
            logger.warning(
                "Authentication credentials for S3 aren't fully provided. "
                "Please update the storage synchronization settings in the "
                "Label Studio web UI as per your needs."
            )
        storage = dataset.connect_s3_import_storage(
            bucket=uri,
            aws_access_key_id=params.aws_access_key_id,
            aws_secret_access_key=params.aws_secret_access_key,
            aws_session_token=params.aws_session_token,
            region_name=params.s3_region_name,
            s3_endpoint=params.s3_endpoint,
            **storage_connection_args,
        )
    else:
        # The accepted value is 's3' (see the branch above), not 'aws'.
        raise ValueError(
            f"Invalid storage type. '{params.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 's3'."
        )

    synced_storage = self._get_client().sync_storage(
        storage_id=storage["id"], storage_type=storage["type"]
    )
    return cast(Dict[str, Any], synced_storage)
delete_dataset(self, **kwargs)
Deletes a dataset from the annotation interface.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
If the deletion of a dataset is not supported. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def delete_dataset(self, **kwargs: Any) -> None:
    """Placeholder for deleting a dataset; currently always raises.

    Args:
        **kwargs: Accepted for interface compatibility; not used yet.

    Raises:
        NotImplementedError: Always, until dataset deletion is supported.
    """
    raise NotImplementedError("Awaiting Label Studio release.")
    # TODO: Awaiting a new Label Studio version to be released with this
    # method. Intended implementation once available:
    #
    #   ls = self._get_client()
    #   dataset_name = kwargs.get("dataset_name")
    #   if not dataset_name:
    #       raise ValueError("`dataset_name` keyword argument is required.")
    #   dataset_id = self.get_id_from_name(dataset_name)
    #   if not dataset_id:
    #       raise ValueError(
    #           f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
    #       )
    #   ls.delete_project(dataset_id)
deprovision(self)
Spins down the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def deprovision(self) -> None:
    """Remove the daemon log file if it exists."""
    if not fileio.exists(self._log_file):
        return
    fileio.remove(self._log_file)
get_converted_dataset(self, dataset_name, output_format)
Extract annotated tasks in a specific converted format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
Name of the dataset. |
required |
output_format |
str |
Output format. |
required |
Returns:
Type | Description |
---|---|
Dict[Any, Any] |
A dictionary containing the converted dataset. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_converted_dataset(
    self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
    """Export a dataset's tasks in the requested format.

    Args:
        dataset_name: Name of the dataset (resolved through `get_dataset`).
        output_format: Export type handed to the project's `export_tasks`.

    Returns:
        A dictionary containing the converted dataset.
    """
    exported = self.get_dataset(dataset_name=dataset_name).export_tasks(
        export_type=output_format
    )
    return exported  # type: ignore[no-any-return]
get_dataset(self, **kwargs)
Gets the dataset with the given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The LabelStudio Dataset object (a 'Project') for the given name. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset(self, **kwargs: Any) -> Any:
    """Fetch the Label Studio project registered under a dataset name.

    Args:
        **kwargs: Must contain `dataset_name`; other keys are ignored.

    Returns:
        The LabelStudio Dataset object (a 'Project') for the given name.

    Raises:
        ValueError: If the dataset name is not provided or if no dataset
            ID can be resolved for it.
    """
    # TODO: check for and raise error if client unavailable
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if project_id:
        return self._get_client().get_project(project_id)

    raise ValueError(
        f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
    )
get_dataset_names(self)
Gets the names of the datasets.
Returns:
Type | Description |
---|---|
List[str] |
A list of dataset names. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_names(self) -> List[str]:
    """Collect the titles of all available datasets.

    Returns:
        A list of dataset names, in the order `get_datasets` returns them.
    """
    names: List[str] = []
    for project in self.get_datasets():
        names.append(project.get_params()["title"])
    return names
get_dataset_stats(self, dataset_name)
Gets the statistics of the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
Tuple[int, int] |
A tuple containing (labeled_task_count, unlabeled_task_count) for the dataset. |
Exceptions:
Type | Description |
---|---|
IndexError |
If the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
    """Gets the statistics of the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        A tuple containing (labeled_task_count, unlabeled_task_count) for
        the dataset.

    Raises:
        IndexError: If the dataset does not exist.
    """
    for project in self.get_datasets():
        # Compare titles for equality. A substring test would report the
        # stats of "foobar" when asked about "foo".
        if project.get_params()["title"] == dataset_name:
            labeled_task_count = len(project.get_labeled_tasks())
            unlabeled_task_count = len(project.get_unlabeled_tasks())
            return (labeled_task_count, unlabeled_task_count)
    raise IndexError(
        f"Dataset {dataset_name} not found. Please use "
        f"`zenml annotator dataset list` to list all available datasets."
    )
get_datasets(self)
Gets the datasets currently available for annotation.
Returns:
Type | Description |
---|---|
List[Any] |
A list of datasets. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_datasets(self) -> List[Any]:
    """Lists the datasets (Label Studio projects) available for annotation.

    Returns:
        The projects reported by the Label Studio client.
    """
    client = self._get_client()
    return cast(List[Any], client.get_projects())
get_id_from_name(self, dataset_name)
Gets the ID of the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
Optional[int] |
The ID of the dataset. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
    """Looks up the ID of the dataset whose title equals `dataset_name`.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The ID of the first dataset with a matching title, or `None` when
        no dataset has that title.
    """
    for candidate in self.get_datasets():
        if candidate.get_params()["title"] != dataset_name:
            continue
        return cast(int, candidate.get_params()["id"])
    return None
get_labeled_data(self, **kwargs)
Gets the labeled data for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The labeled data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_labeled_data(self, **kwargs: Any) -> Any:
    """Fetches the labeled tasks of the dataset named in `kwargs`.

    Args:
        **kwargs: Keyword arguments; `dataset_name` is required and selects
            the dataset to read from.

    Returns:
        The labeled data.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    project = self._get_client().get_project(project_id)
    return project.get_labeled_tasks()
get_parsed_label_config(self, dataset_id)
Returns the parsed Label Studio label config for a dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_id |
int |
Id of the dataset. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the parsed label config. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no dataset is found for the given id. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
    """Returns the parsed Label Studio label config for a dataset.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        A dictionary containing the parsed label config.

    Raises:
        ValueError: If no dataset is found for the given id.
    """
    # TODO: check if client actually is connected etc
    project = self._get_client().get_project(dataset_id)
    if not project:
        raise ValueError("No dataset found for the given id.")
    return cast(Dict[str, Any], project.parsed_label_config)
get_unlabeled_data(self, **kwargs)
Gets the unlabeled data for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
str |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The unlabeled data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_unlabeled_data(self, **kwargs: str) -> Any:
    """Fetches the unlabeled tasks of the dataset named in `kwargs`.

    Args:
        **kwargs: Keyword arguments; `dataset_name` is required and selects
            the dataset to read from.

    Returns:
        The unlabeled data.

    Raises:
        ValueError: If the dataset name is not provided or no dataset with
            that name exists in Label Studio.
    """
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    project = self._get_client().get_project(project_id)
    return project.get_unlabeled_tasks()
get_url(self)
Gets the top-level URL of the annotation interface.
Returns:
Type | Description |
---|---|
str |
The URL of the annotation interface. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url(self) -> str:
    """Builds the top-level URL of the annotation interface.

    Returns:
        The configured instance URL and port joined by a colon.
    """
    settings = self.config
    host = settings.instance_url
    port = settings.port
    return f"{host}:{port}"
get_url_for_dataset(self, dataset_name)
Gets the URL of the annotation interface for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
str |
The URL of the annotation interface. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url_for_dataset(self, dataset_name: str) -> str:
    """Builds the annotation interface URL for the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The top-level URL followed by `/projects/<id>/`, where `<id>` is
        whatever `get_id_from_name` returns for `dataset_name`.
    """
    base_url = self.get_url()
    dataset_id = self.get_id_from_name(dataset_name)
    return f"{base_url}/projects/{dataset_id}/"
launch(self, url)
Launches the annotation interface.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
Optional[str] |
The URL of the annotation interface. |
required |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def launch(self, url: Optional[str] = None) -> None:
    """Launches the annotation interface in a web browser.

    If no connection to the annotation server is available, nothing is
    opened and a warning is logged instead.

    Args:
        url: The URL of the annotation interface. When empty or `None`,
            the top-level URL from `get_url()` is used.
    """
    target_url = url or self.get_url()

    if not self._connection_available():
        # The two literals are concatenated, so the first one needs the
        # trailing space to keep the words apart.
        logger.warning(
            "Could not launch annotation interface "
            "because the connection could not be established."
        )
        return

    webbrowser.open(target_url, new=1, autoraise=True)
provision(self)
Spins up the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def provision(self) -> None:
    """Creates the annotator's root directory via `fileio.makedirs`."""
    root_dir = self.root_directory
    fileio.makedirs(root_dir)
register_dataset_for_annotation(self, params)
Registers a dataset for annotation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
LabelStudioDatasetRegistrationParameters |
Parameters for the dataset. |
required |
Returns:
Type | Description |
---|---|
Any |
A Label Studio Project object. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def register_dataset_for_annotation(
    self,
    params: LabelStudioDatasetRegistrationParameters,
) -> Any:
    """Returns the dataset named in `params`, creating it if necessary.

    Args:
        params: Parameters for the dataset.

    Returns:
        A Label Studio Project object: the existing project when one with
        the requested name is found, otherwise the newly added one.
    """
    existing_id = self.get_id_from_name(params.dataset_name)
    if not existing_id:
        return self.add_dataset(
            dataset_name=params.dataset_name,
            label_config=params.label_config,
        )
    return self._get_client().get_project(existing_id)
resume(self)
Resumes the annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def resume(self) -> None:
    """Starts the local annotator daemon unless it is already running."""
    if not self.is_running:
        # Only a local instance has a daemon this component can start.
        if self.is_local_instance:
            self.start_annotator_daemon()
        return
    logger.info("Local annotation deployment already running.")
start_annotator_daemon(self)
Starts the annotation server backend.
Exceptions:
Type | Description |
---|---|
ProvisioningError |
If the annotation server backend is already running or the port is already occupied. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def start_annotator_daemon(self) -> None:
    """Starts the annotation server backend as a background daemon.

    On Windows no daemon is started; a warning containing the command to
    run manually is logged instead.

    Raises:
        ProvisioningError: If the configured port is already occupied.
    """
    command = [
        "label-studio",
        "start",
        "--no-browser",
        "--port",
        f"{self.config.port}",
    ]

    if sys.platform == "win32":
        # Exactly one lazy argument for the single `%s` placeholder; a
        # surplus argument makes the logging call fail to format.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Label Studio server locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.config.port):
        raise ProvisioningError(
            f"Unable to port-forward Label Studio to local "
            f"port {self.config.port} because the port is occupied. In order to "
            f"access Label Studio locally, please "
            f"change the configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Runs the Label Studio server command inside the daemon."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        # Both values are passed as lazy %-arguments so that a `%` inside
        # the URL cannot be misread as a format directive.
        logger.info(
            "Started Label Studio daemon (check the daemon "
            "logs at `%s` in case you're not able to access the annotation "
            "interface). Please visit `%s/` to use the Label Studio interface.",
            self._log_file,
            self.get_url(),
        )
stop_annotator_daemon(self)
Stops the annotation server backend.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def stop_annotator_daemon(self) -> None:
    """Stops the annotation server daemon and removes its PID file."""
    if not fileio.exists(self._pid_file_path):
        return

    if sys.platform == "win32":
        # Daemon functionality is not supported on Windows, so the PID
        # file won't exist. This branch exists just for mypy to not
        # complain about missing functions
        return

    from zenml.utils import daemon

    pid_file = self._pid_file_path
    daemon.stop_daemon(pid_file)
    fileio.remove(pid_file)
suspend(self)
Suspends the annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def suspend(self) -> None:
    """Stops the local annotator daemon if it is currently running."""
    if self.is_running:
        # Only a local instance has a daemon this component can stop.
        if self.is_local_instance:
            self.stop_annotator_daemon()
        return
    logger.info("Local annotation server is not running.")
flavors
special
Label Studio integration flavors.
label_studio_annotator_flavor
Label Studio annotator flavor.
LabelStudioAnnotatorConfig (BaseAnnotatorConfig, AuthenticationConfigMixin)
pydantic-model
Config for the Label Studio annotator.
Attributes:
Name | Type | Description |
---|---|---|
instance_url |
str |
URL of the Label Studio instance. |
port |
int |
The port to use for the annotation interface. |
Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorConfig(
    BaseAnnotatorConfig, AuthenticationConfigMixin
):
    """Config for the Label Studio annotator.

    Attributes:
        instance_url: URL of the Label Studio instance.
        port: The port to use for the annotation interface.
    """

    # Falls back to the module-level default for a local instance.
    instance_url: str = DEFAULT_LOCAL_INSTANCE_URL
    # Falls back to the module-level default Label Studio port.
    port: int = DEFAULT_LOCAL_LABEL_STUDIO_PORT
LabelStudioAnnotatorFlavor (BaseAnnotatorFlavor)
Label Studio annotator flavor.
Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorFlavor(BaseAnnotatorFlavor):
    """Label Studio annotator flavor.

    Ties together the flavor name, its config class and the annotator
    implementation class.
    """

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return LABEL_STUDIO_ANNOTATOR_FLAVOR

    @property
    def config_class(self) -> Type[LabelStudioAnnotatorConfig]:
        """Returns `LabelStudioAnnotatorConfig` config class.

        Returns:
            The config class.
        """
        return LabelStudioAnnotatorConfig

    @property
    def implementation_class(self) -> Type["LabelStudioAnnotator"]:
        """Implementation class for this flavor.

        Returns:
            The implementation class.
        """
        # Imported on access rather than at module level; presumably so the
        # annotator's dependencies are only needed when the flavor is
        # actually used (TODO confirm).
        from zenml.integrations.label_studio.annotators import (
            LabelStudioAnnotator,
        )

        return LabelStudioAnnotator
config_class: Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig]
property
readonly
Returns LabelStudioAnnotatorConfig
config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig] |
The config class. |
implementation_class: Type[LabelStudioAnnotator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[LabelStudioAnnotator] |
The implementation class. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
label_config_generators
special
Initialization of the Label Studio config generators submodule.
label_config_generators
Implementation of label config generators for Label Studio.
generate_basic_object_detection_bounding_boxes_label_config(labels)
Generates a Label Studio config for object detection with bounding boxes.
This is based on the basic config example shown at https://labelstud.io/templates/image_bbox.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_object_detection_bounding_boxes_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for object detection with bounding boxes.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_bbox.html.

    Label names are XML-escaped before being embedded as attribute values,
    so names containing `&`, `<`, `>` or `'` still yield a well-formed
    config. Names without those characters produce unchanged output.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OBJECT_DETECTION_BOUNDING_BOXES

    label_config_start = (
        "<View>\n"
        '<Image name="image" value="$image"/>\n'
        '<RectangleLabels name="label" toName="image">\n'
    )
    # The attribute is single-quoted, so `'` must be escaped in addition to
    # the characters `escape` handles by default.
    label_config_choices = "".join(
        "<Label value='{}' />\n".format(escape(label, {"'": "&apos;"}))
        for label in labels
    )
    label_config_end = "</RectangleLabels>\n</View>"

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )
generate_basic_ocr_label_config(labels)
Generates a Label Studio config for optical character recognition (OCR) labeling task.
This is based on the basic config example shown at https://labelstud.io/templates/optical_character_recognition.html
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_ocr_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for optical character recognition (OCR) labeling task.

    This is based on the basic config example shown at
    https://labelstud.io/templates/optical_character_recognition.html

    Label names are XML-escaped before being embedded as attribute values,
    so names containing `&`, `<`, `>` or `'` still yield a well-formed
    config. Names without those characters produce unchanged output.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OCR

    label_config_start = (
        "\n"
        "<View>\n"
        '<Image name="image" value="$ocr" zoom="true" zoomControl="true" rotateControl="true"/>\n'
        "<View>\n"
        '<Filter toName="label" minlength="0" name="filter"/>\n'
        '<Labels name="label" toName="image">\n'
    )
    # The attribute is single-quoted, so `'` must be escaped in addition to
    # the characters `escape` handles by default.
    label_config_choices = "".join(
        "<Label value='{}' />\n".format(escape(label, {"'": "&apos;"}))
        for label in labels
    )
    label_config_end = (
        "\n"
        "</Labels>\n"
        "</View>\n"
        '<Rectangle name="bbox" toName="image" strokeWidth="3"/>\n'
        '<Polygon name="poly" toName="image" strokeWidth="3"/>\n'
        '<TextArea name="transcription" toName="image" editable="true" perRegion="true" required="true" maxSubmissions="1" rows="5" placeholder="Recognized Text" displayMode="region-list"/>\n'
        "</View>\n"
    )

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )
generate_image_classification_label_config(labels)
Generates a Label Studio label config for image classification.
This is based on the basic config example shown at https://labelstud.io/templates/image_classification.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_image_classification_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio label config for image classification.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_classification.html.

    Label names are XML-escaped before being embedded as attribute values,
    so names containing `&`, `<`, `>` or `'` still yield a well-formed
    config. Names without those characters produce unchanged output.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.IMAGE_CLASSIFICATION

    label_config_start = (
        "<View>\n"
        '<Image name="image" value="$image"/>\n'
        '<Choices name="choice" toName="image">\n'
    )
    # The attribute is single-quoted, so `'` must be escaped in addition to
    # the characters `escape` handles by default.
    label_config_choices = "".join(
        "<Choice value='{}' />\n".format(escape(label, {"'": "&apos;"}))
        for label in labels
    )
    label_config_end = "</Choices>\n</View>"

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )
label_studio_utils
Utility functions for the Label Studio annotator integration.
convert_pred_filenames_to_task_ids(preds, tasks, filename_reference, storage_type)
Converts a list of predictions from local file references to task id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preds |
List[Dict[str, Any]] |
List of predictions. |
required |
tasks |
List[Dict[str, Any]] |
List of tasks. |
required |
filename_reference |
str |
Name of the file reference in the predictions. |
required |
storage_type |
str |
Storage type of the predictions. |
required |
Returns:
Type | Description |
---|---|
List[Dict[str, Any]] |
List of predictions using task ids as reference. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def convert_pred_filenames_to_task_ids(
    preds: List[Dict[str, Any]],
    tasks: List[Dict[str, Any]],
    filename_reference: str,
    storage_type: str,
) -> List[Dict[str, Any]]:
    """Re-keys predictions from file names to Label Studio task ids.

    Tasks and predictions are matched on the base name of their file path.

    Args:
        preds: List of predictions.
        tasks: List of tasks.
        filename_reference: Name of the file reference in the predictions.
        storage_type: Storage type of the predictions.

    Returns:
        List of predictions using task ids as reference.
    """
    task_id_by_basename = {}
    for task in tasks:
        task_path = urlparse(task["data"][filename_reference]).path
        task_id_by_basename[os.path.basename(task_path)] = task["id"]

    # GCS and S3 URL encodes filenames containing spaces, requiring this
    # separate encoding step
    if storage_type in {"gcs", "s3"}:
        encoded_preds = []
        for pred in preds:
            encoded_preds.append(
                {"filename": quote(pred["filename"]), "result": pred["result"]}
            )
        preds = encoded_preds

    converted = []
    for pred in preds:
        basename = os.path.basename(pred["filename"])
        converted.append(
            {
                "task": int(task_id_by_basename[basename]),
                "result": pred["result"],
            }
        )
    return converted
get_file_extension(path_str)
Return the file extension of the given filename.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path_str |
str |
Path to the file. |
required |
Returns:
Type | Description |
---|---|
str |
File extension. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def get_file_extension(path_str: str) -> str:
    """Extracts the file extension from a path or URL.

    Only the path component of `path_str` is inspected.

    Args:
        path_str: Path to the file.

    Returns:
        File extension, including the leading dot, or an empty string when
        the path has none.
    """
    path_component = urlparse(path_str).path
    _, extension = os.path.splitext(path_component)
    return extension
is_azure_url(url)
Return whether the given URL is an Azure URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is an Azure URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_azure_url(url: str) -> bool:
    """Checks whether a URL's host contains the Azure blob storage domain.

    Args:
        url: URL to check.

    Returns:
        True if the URL is an Azure URL, False otherwise.
    """
    host = urlparse(url).netloc
    return "blob.core.windows.net" in host
is_gcs_url(url)
Return whether the given URL is a GCS URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is a GCS URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_gcs_url(url: str) -> bool:
    """Checks whether a URL's host contains the Google Cloud Storage domain.

    Args:
        url: URL to check.

    Returns:
        True if the URL is a GCS URL, False otherwise.
    """
    host = urlparse(url).netloc
    return "storage.googleapis.com" in host
is_s3_url(url)
Return whether the given URL is an S3 URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is an S3 URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_s3_url(url: str) -> bool:
    """Checks whether a URL's host contains the `s3.amazonaws` marker.

    Args:
        url: URL to check.

    Returns:
        True if the URL is an S3 URL, False otherwise.
    """
    host = urlparse(url).netloc
    return "s3.amazonaws" in host
steps
special
Standard steps to be used with the Label Studio annotator integration.
label_studio_standard_steps
Implementation of standard steps for the Label Studio annotator integration.
LabelStudioDatasetRegistrationParameters (BaseParameters)
pydantic-model
Step parameters when registering a dataset with Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
label_config |
str |
The label config to use for the annotation interface. |
dataset_name |
str |
Name of the dataset to register. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationParameters(BaseParameters):
    """Step parameters when registering a dataset with Label Studio.

    Both fields are declared without a default value.

    Attributes:
        label_config: The label config to use for the annotation interface.
        dataset_name: Name of the dataset to register.
    """

    label_config: str
    dataset_name: str
LabelStudioDatasetSyncParameters (BaseParameters)
pydantic-model
Step parameters when syncing data to Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
str |
The type of storage to sync to. |
label_config_type |
str |
The type of label config to use. |
prefix |
Optional[str] |
Specify the prefix within the cloud store to import your data from. |
regex_filter |
Optional[str] |
Specify a regex filter to filter the files to import. |
use_blob_urls |
Optional[bool] |
Specify whether your data is raw image or video data, or JSON tasks. |
presign |
Optional[bool] |
Specify whether or not to create presigned URLs. |
presign_ttl |
Optional[int] |
Specify how long to keep presigned URLs active. |
description |
Optional[str] |
Specify a description for the dataset. |
azure_account_name |
Optional[str] |
Specify the Azure account name to use for the storage. |
azure_account_key |
Optional[str] |
Specify the Azure account key to use for the storage. |
google_application_credentials |
Optional[str] |
Specify the Google application credentials to use for the storage. |
aws_access_key_id |
Optional[str] |
Specify the AWS access key ID to use for the storage. |
aws_secret_access_key |
Optional[str] |
Specify the AWS secret access key to use for the storage. |
aws_session_token |
Optional[str] |
Specify the AWS session token to use for the storage. |
s3_region_name |
Optional[str] |
Specify the S3 region name to use for the storage. |
s3_endpoint |
Optional[str] |
Specify the S3 endpoint to use for the storage. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncParameters(BaseParameters):
    """Step parameters when syncing data to Label Studio.

    Every cloud credential field defaults explicitly to `None`, so only the
    credentials relevant to the chosen storage need to be supplied.

    Attributes:
        storage_type: The type of storage to sync to.
        label_config_type: The type of label config to use.
        prefix: Specify the prefix within the cloud store to import your data
            from.
        regex_filter: Specify a regex filter to filter the files to import.
        use_blob_urls: Specify whether your data is raw image or video data, or
            JSON tasks.
        presign: Specify whether or not to create presigned URLs.
        presign_ttl: Specify how long to keep presigned URLs active.
        description: Specify a description for the dataset.
        azure_account_name: Specify the Azure account name to use for the
            storage.
        azure_account_key: Specify the Azure account key to use for the
            storage.
        google_application_credentials: Specify the Google application
            credentials to use for the storage.
        aws_access_key_id: Specify the AWS access key ID to use for the
            storage.
        aws_secret_access_key: Specify the AWS secret access key to use for the
            storage.
        aws_session_token: Specify the AWS session token to use for the
            storage.
        s3_region_name: Specify the S3 region name to use for the storage.
        s3_endpoint: Specify the S3 endpoint to use for the storage.
    """

    storage_type: str
    label_config_type: str

    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # credentials specific to the main cloud providers
    # Explicit `None` defaults keep these optional regardless of how the
    # underlying model library treats a bare `Optional[...]` annotation.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
get_labeled_data (BaseStep)
Gets labeled data from the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
Name of the dataset. |
required | |
context |
The StepContext. |
required |
Returns:
Type | Description |
---|---|
List of labeled data. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
entrypoint(dataset_name, context)
staticmethod
Gets labeled data from the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
Name of the dataset. |
required |
context |
StepContext |
The StepContext. |
required |
Returns:
Type | Description |
---|---|
List |
List of labeled data. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_labeled_data(dataset_name: str, context: StepContext) -> List:  # type: ignore[type-arg]
    """Fetches the labeled tasks of a dataset from the active annotator.

    Args:
        dataset_name: Name of the dataset.
        context: The StepContext.

    Returns:
        List of labeled data.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found
            or the annotator cannot be reached.
    """
    # TODO [MEDIUM]: have this check for new data *since the last time this step ran*
    annotator = context.stack.annotator  # type: ignore[union-attr]
    if not annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    project = annotator.get_dataset(dataset_name=dataset_name)
    return project.get_labeled_tasks()  # type: ignore[no-any-return]
get_or_create_dataset (BaseStep)
Gets preexisting dataset or creates a new one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
Step parameters. |
required | |
context |
Step context. |
required |
Returns:
Type | Description |
---|---|
The dataset name. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Step parameters when registering a dataset with Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
label_config |
str |
The label config to use for the annotation interface. |
dataset_name |
str |
Name of the dataset to register. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationParameters(BaseParameters):
    """Step parameters when registering a dataset with Label Studio.

    Both fields are declared without a default value.

    Attributes:
        label_config: The label config to use for the annotation interface.
        dataset_name: Name of the dataset to register.
    """

    label_config: str
    dataset_name: str
entrypoint(params, context)
staticmethod
Gets preexisting dataset or creates a new one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
LabelStudioDatasetRegistrationParameters |
Step parameters. |
required |
context |
StepContext |
Step context. |
required |
Returns:
Type | Description |
---|---|
str |
The dataset name. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_or_create_dataset(
    params: LabelStudioDatasetRegistrationParameters,
    context: StepContext,
) -> str:
    """Gets preexisting dataset or creates a new one.

    Args:
        params: Step parameters.
        context: Step context.

    Returns:
        The dataset name.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found
            or the annotator cannot be reached.
    """
    annotator = context.stack.annotator  # type: ignore[union-attr]
    # Check for a missing annotator first: `isinstance(None, ...)` is False,
    # so running the type check earlier would report a missing annotator as
    # a `TypeError` instead of the documented error.
    if not annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    # Reuse a dataset whose title matches exactly, if there is one.
    for dataset in annotator.get_datasets():
        if dataset.get_params()["title"] == params.dataset_name:
            return cast(str, dataset.get_params()["title"])

    dataset = annotator.register_dataset_for_annotation(params)
    return cast(str, dataset.get_params()["title"])
sync_new_data_to_label_studio (BaseStep)
Syncs new data to Label Studio.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
The URI of the data to sync. |
required | |
dataset_name |
The name of the dataset to sync to. |
required | |
predictions |
The predictions to sync. |
required | |
params |
The parameters for the sync. |
required | |
context |
The StepContext. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
ValueError |
If you are trying to sync from outside ZenML. |
StackComponentInterfaceError |
If no active annotator could be found. |
PARAMETERS_CLASS (BaseParameters)
pydantic-model
Step parameters when syncing data to Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
str |
The type of storage to sync to. |
label_config_type |
str |
The type of label config to use. |
prefix |
Optional[str] |
Specify the prefix within the cloud store to import your data from. |
regex_filter |
Optional[str] |
Specify a regex filter to filter the files to import. |
use_blob_urls |
Optional[bool] |
Specify whether your data is raw image or video data, or JSON tasks. |
presign |
Optional[bool] |
Specify whether or not to create presigned URLs. |
presign_ttl |
Optional[int] |
Specify how long to keep presigned URLs active. |
description |
Optional[str] |
Specify a description for the dataset. |
azure_account_name |
Optional[str] |
Specify the Azure account name to use for the storage. |
azure_account_key |
Optional[str] |
Specify the Azure account key to use for the storage. |
google_application_credentials |
Optional[str] |
Specify the Google application credentials to use for the storage. |
aws_access_key_id |
Optional[str] |
Specify the AWS access key ID to use for the storage. |
aws_secret_access_key |
Optional[str] |
Specify the AWS secret access key to use for the storage. |
aws_session_token |
Optional[str] |
Specify the AWS session token to use for the storage. |
s3_region_name |
Optional[str] |
Specify the S3 region name to use for the storage. |
s3_endpoint |
Optional[str] |
Specify the S3 endpoint to use for the storage. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncParameters(BaseParameters):
    """Step parameters when syncing data to Label Studio.

    Attributes:
        storage_type: The type of storage to sync to.
        label_config_type: The type of label config to use.
        prefix: Specify the prefix within the cloud store to import your data
            from.
        regex_filter: Specify a regex filter to filter the files to import.
        use_blob_urls: Specify whether your data is raw image or video data, or
            JSON tasks.
        presign: Specify whether or not to create presigned URLs.
        presign_ttl: Specify how long to keep presigned URLs active.
        description: Specify a description for the dataset.
        azure_account_name: Specify the Azure account name to use for the
            storage.
        azure_account_key: Specify the Azure account key to use for the
            storage.
        google_application_credentials: Specify the Google application
            credentials to use for the storage.
        aws_access_key_id: Specify the AWS access key ID to use for the
            storage.
        aws_secret_access_key: Specify the AWS secret access key to use for the
            storage.
        aws_session_token: Specify the AWS session token to use for the
            storage.
        s3_region_name: Specify the S3 region name to use for the storage.
        s3_endpoint: Specify the S3 endpoint to use for the storage.
    """

    # Required: no default is declared for these two fields.
    storage_type: str
    label_config_type: str

    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # credentials specific to the main cloud providers
    # Explicit `None` defaults keep these fields optional without relying on
    # how the validation layer interprets a bare `Optional[...]` annotation,
    # and match the way `prefix` is declared above.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
entrypoint(uri, dataset_name, predictions, params, context)
staticmethod
Syncs new data to Label Studio.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The URI of the data to sync. |
required |
dataset_name |
str |
The name of the dataset to sync to. |
required |
predictions |
List[Dict[str, Any]] |
The predictions to sync. |
required |
params |
LabelStudioDatasetSyncParameters |
The parameters for the sync. |
required |
context |
StepContext |
The StepContext. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
ValueError |
If you are trying to sync from outside ZenML. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def sync_new_data_to_label_studio(
    uri: str,
    dataset_name: str,
    predictions: List[Dict[str, Any]],
    params: LabelStudioDatasetSyncParameters,
    context: StepContext,
) -> None:
    """Syncs new data to Label Studio.

    The step fills in the storage prefix and the cloud credentials on
    `params` (mutating it in place), connects the dataset to the external
    storage and, when predictions are supplied, uploads them to the dataset.

    Args:
        uri: The URI of the data to sync.
        dataset_name: The name of the dataset to sync to.
        predictions: The predictions to sync.
        params: The parameters for the sync.
        context: The StepContext.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio, if the artifact store does not inherit from
            `AuthenticationMixin` for `azure`/`gcs` storage, or if the AWS
            secret is not an `aws` schema secret.
        ValueError: If you are trying to sync from outside ZenML, or if the
            secret needed to authenticate the cloud storage is missing.
        StackComponentInterfaceError: If no active annotator could be found.
    """
    annotator = context.stack.annotator  # type: ignore[union-attr]
    artifact_store = context.stack.artifact_store  # type: ignore[union-attr]
    secrets_manager = context.stack.secrets_manager  # type: ignore[union-attr]
    if not annotator or not artifact_store or not secrets_manager:
        raise StackComponentInterfaceError(
            "An active annotator, artifact store and secrets manager are required to run this step."
        )

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # Verify the connection before the first query against the annotator, so
    # an unreachable annotator is reported with the documented error instead
    # of failing inside `get_dataset`.
    if not annotator._connection_available():
        raise StackComponentInterfaceError("No active annotator.")

    dataset = annotator.get_dataset(dataset_name=dataset_name)
    if not uri.startswith(artifact_store.path):
        raise ValueError(
            "ZenML only currently supports syncing data passed from other ZenML steps and via the Artifact Store."
        )

    # Parse the URI once; the path (without its leading slashes) becomes the
    # storage prefix and the network location is what gets connected.
    parsed_uri = urlparse(uri)
    params.prefix = parsed_uri.path.lstrip("/")
    base_uri = parsed_uri.netloc

    # gets the secret used for authentication
    if params.storage_type == "azure":
        if not isinstance(artifact_store, AuthenticationMixin):
            raise TypeError(
                "The artifact store must inherit from "
                f"{AuthenticationMixin.__name__} to work with a Label Studio "
                f"`{params.storage_type}` storage."
            )

        azure_secret = artifact_store.get_authentication_secret(
            expected_schema_type=AzureSecretSchema
        )
        if not azure_secret:
            raise ValueError(
                "Missing secret to authenticate cloud storage for Label Studio."
            )

        params.azure_account_name = azure_secret.account_name
        params.azure_account_key = azure_secret.account_key
    elif params.storage_type == "gcs":
        if not isinstance(artifact_store, AuthenticationMixin):
            raise TypeError(
                "The artifact store must inherit from "
                f"{AuthenticationMixin.__name__} to work with a Label Studio "
                f"`{params.storage_type}` storage."
            )

        gcp_secret = artifact_store.get_authentication_secret(
            expected_schema_type=GCPSecretSchema
        )
        if not gcp_secret:
            raise ValueError(
                "Missing secret to authenticate cloud storage for Label Studio."
            )

        params.google_application_credentials = gcp_secret.token
    elif params.storage_type == "s3":
        # Unlike azure/gcs, the AWS credentials are read from the secrets
        # manager rather than from the artifact store.
        aws_secret = secrets_manager.get_secret(LABEL_STUDIO_AWS_SECRET_NAME)
        if not isinstance(aws_secret, AWSSecretSchema):
            raise TypeError(
                f"The secret `{LABEL_STUDIO_AWS_SECRET_NAME}` needs to be "
                f"an `aws` schema secret."
            )

        params.aws_access_key_id = aws_secret.aws_access_key_id
        params.aws_secret_access_key = aws_secret.aws_secret_access_key
        params.aws_session_token = aws_secret.aws_session_token
    # NOTE(review): any other storage type falls through with no credentials
    # set on `params`.

    # TODO: get existing (CHECK!) or create the sync connection
    annotator.connect_and_sync_external_storage(
        uri=base_uri,
        params=params,
        dataset=dataset,
    )

    if predictions:
        filename_reference = TASK_TO_FILENAME_REFERENCE_MAPPING[
            params.label_config_type
        ]
        preds_with_task_ids = convert_pred_filenames_to_task_ids(
            predictions,
            dataset.tasks,
            filename_reference,
            params.storage_type,
        )
        # TODO: filter out any predictions that exist + have already been
        # made (maybe?). Only pass in preds for tasks without pre-annotations.
        dataset.create_predictions(preds_with_task_ids)