Databricks
zenml.integrations.databricks
special
Initialization of the Databricks integration for ZenML.
DatabricksIntegration (Integration)
Definition of Databricks Integration for ZenML.
Source code in zenml/integrations/databricks/__init__.py
class DatabricksIntegration(Integration):
    """ZenML integration that registers the Databricks stack component flavors."""

    NAME = DATABRICKS
    REQUIREMENTS = ["databricks-sdk==0.28.0"]
    REQUIREMENTS_IGNORED_ON_UNINSTALL = ["numpy", "pandas"]

    @classmethod
    def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
        """Collect the pip requirements of this integration.

        The Databricks SDK pin is combined with the requirements of the
        numpy and pandas integrations (in that order).

        Args:
            target_os: Operating system the requirements are resolved for;
                forwarded unchanged to the numpy and pandas integrations.

        Returns:
            A new list of requirement strings.
        """
        from zenml.integrations.numpy import NumpyIntegration
        from zenml.integrations.pandas import PandasIntegration

        numpy_reqs = NumpyIntegration.get_requirements(target_os=target_os)
        pandas_reqs = PandasIntegration.get_requirements(target_os=target_os)
        return [*cls.REQUIREMENTS, *numpy_reqs, *pandas_reqs]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """List the stack component flavors contributed by this integration.

        Returns:
            The Databricks orchestrator flavor followed by the Databricks
            model deployer flavor.
        """
        from zenml.integrations.databricks.flavors import (
            DatabricksModelDeployerFlavor,
            DatabricksOrchestratorFlavor,
        )

        contributed: List[Type[Flavor]] = [DatabricksOrchestratorFlavor]
        contributed.append(DatabricksModelDeployerFlavor)
        return contributed
flavors()
classmethod
Declare the stack component flavors for the Databricks integration.
Returns:
Type | Description |
---|---|
List[Type[zenml.stack.flavor.Flavor]] |
List of stack component flavors for this integration. |
Source code in zenml/integrations/databricks/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """List the stack component flavors contributed by this integration.

    Returns:
        The Databricks orchestrator flavor followed by the Databricks model
        deployer flavor.
    """
    from zenml.integrations.databricks.flavors import (
        DatabricksModelDeployerFlavor,
        DatabricksOrchestratorFlavor,
    )

    contributed: List[Type[Flavor]] = [DatabricksOrchestratorFlavor]
    contributed.append(DatabricksModelDeployerFlavor)
    return contributed
get_requirements(target_os=None)
classmethod
Method to get the requirements for the integration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
target_os |
Optional[str] |
The target operating system to get the requirements for. |
None |
Returns:
Type | Description |
---|---|
List[str] |
A list of requirements. |
Source code in zenml/integrations/databricks/__init__.py
@classmethod
def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
    """Collect the pip requirements of this integration.

    The Databricks SDK pin is combined with the requirements of the numpy
    and pandas integrations (in that order).

    Args:
        target_os: Operating system the requirements are resolved for;
            forwarded unchanged to the numpy and pandas integrations.

    Returns:
        A new list of requirement strings.
    """
    from zenml.integrations.numpy import NumpyIntegration
    from zenml.integrations.pandas import PandasIntegration

    numpy_reqs = NumpyIntegration.get_requirements(target_os=target_os)
    pandas_reqs = PandasIntegration.get_requirements(target_os=target_os)
    return [*cls.REQUIREMENTS, *numpy_reqs, *pandas_reqs]
flavors
special
Databricks integration flavors.
databricks_model_deployer_flavor
Databricks model deployer flavor.
DatabricksBaseConfig (BaseModel)
Databricks Inference Endpoint configuration.
Source code in zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py
class DatabricksBaseConfig(BaseModel):
    """Databricks Inference Endpoint configuration.

    Attributes:
        workload_size: Workload size of the inference endpoint. Required;
            the value is passed through as a plain string and is not
            validated here.
        scale_to_zero_enabled: Scale-to-zero flag for the endpoint.
            Defaults to False.
        env_vars: Optional mapping of environment variable names to values
            for the endpoint.
        workload_type: Optional workload type of the endpoint.
        endpoint_secret_name: Optional name of a secret associated with the
            endpoint.
    """

    workload_size: str
    scale_to_zero_enabled: bool = False
    env_vars: Optional[Dict[str, str]] = None
    workload_type: Optional[str] = None
    endpoint_secret_name: Optional[str] = None
DatabricksModelDeployerConfig (BaseModelDeployerConfig)
Configuration for the Databricks model deployer.
Attributes:
Name | Type | Description |
---|---|---|
host |
str |
Databricks host. |
secret_name |
Optional[str] |
Secret name to use for authentication. |
client_id |
Optional[str] |
Databricks client id. |
client_secret |
Optional[str] |
Databricks client secret. |
Source code in zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py
class DatabricksModelDeployerConfig(BaseModelDeployerConfig):
    """Configuration for the Databricks model deployer.

    The Databricks model deployer's stack validator accepts this
    configuration only when either `secret_name` is set or both
    `client_id` and `client_secret` are set.

    Attributes:
        host: Databricks host.
        secret_name: Secret name to use for authentication.
        client_id: Databricks client id.
        client_secret: Databricks client secret.
    """

    host: str
    secret_name: Optional[str] = None
    # NOTE(review): declared via ZenML's SecretField, presumably so these
    # values are treated as sensitive / may reference a stored secret —
    # confirm against SecretField's definition.
    client_id: Optional[str] = SecretField(default=None)
    client_secret: Optional[str] = SecretField(default=None)
DatabricksModelDeployerFlavor (BaseModelDeployerFlavor)
Databricks Endpoint model deployer flavor.
Source code in zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py
class DatabricksModelDeployerFlavor(BaseModelDeployerFlavor):
    """Flavor that exposes the Databricks endpoint model deployer to ZenML."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The Databricks model deployer flavor name.
        """
        flavor_name = DATABRICKS_MODEL_DEPLOYER_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            The default docs url generated by the base flavor.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK reference documentation of this flavor.

        Returns:
            The default SDK docs url generated by the base flavor.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Image shown for this flavor in the dashboard.

        Returns:
            The public URL of the Databricks model deployer logo.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/databricks.png"
        return logo

    @property
    def config_class(self) -> Type[DatabricksModelDeployerConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `DatabricksModelDeployerConfig` class.
        """
        model_cls = DatabricksModelDeployerConfig
        return model_cls

    @property
    def implementation_class(self) -> Type["DatabricksModelDeployer"]:
        """Stack component class that implements this flavor.

        The import is deferred to call time.

        Returns:
            The `DatabricksModelDeployer` class.
        """
        from zenml.integrations.databricks.model_deployers.databricks_model_deployer import (
            DatabricksModelDeployer as deployer_cls,
        )

        return deployer_cls
config_class: Type[zenml.integrations.databricks.flavors.databricks_model_deployer_flavor.DatabricksModelDeployerConfig]
property
readonly
Returns the `DatabricksModelDeployerConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.databricks.flavors.databricks_model_deployer_flavor.DatabricksModelDeployerConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[DatabricksModelDeployer]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[DatabricksModelDeployer] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
databricks_orchestrator_flavor
Databricks orchestrator base config and settings.
DatabricksAvailabilityType (StrEnum)
Databricks availability type.
Source code in zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py
class DatabricksAvailabilityType(StrEnum):
    """Databricks availability type.

    Allowed values for `DatabricksOrchestratorSettings.availability_type`.
    Each member's value is identical to its name.
    """

    # NOTE(review): presumably these mirror the availability options of
    # Databricks cluster nodes (on-demand, spot, spot with on-demand
    # fallback) — confirm against the Databricks SDK.
    ON_DEMAND = "ON_DEMAND"
    SPOT = "SPOT"
    SPOT_WITH_FALLBACK = "SPOT_WITH_FALLBACK"
DatabricksOrchestratorConfig (BaseOrchestratorConfig, DatabricksOrchestratorSettings)
Databricks orchestrator base config.
Attributes:
Name | Type | Description |
---|---|---|
host |
str |
Databricks host. |
client_id |
str |
Databricks client id. |
client_secret |
str |
Databricks client secret. |
Source code in zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py
class DatabricksOrchestratorConfig(
    BaseOrchestratorConfig, DatabricksOrchestratorSettings
):
    """Databricks orchestrator base config.

    Attributes:
        host: Databricks host.
        client_id: Databricks client id. Optional; defaults to None.
        client_secret: Databricks client secret. Optional; defaults to None.
    """

    host: str
    # Both credentials default to None, so they are annotated Optional to
    # match that default (and the Databricks model deployer config).
    client_id: Optional[str] = SecretField(default=None)
    client_secret: Optional[str] = SecretField(default=None)

    @property
    def is_local(self) -> bool:
        """Checks if this stack component is running locally.

        Returns:
            Always False: the orchestrator runs pipelines on Databricks.
        """
        return False

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        Returns:
            Always True: the orchestrator runs pipelines on Databricks.
        """
        return True

    @property
    def is_schedulable(self) -> bool:
        """Whether the orchestrator is schedulable or not.

        Returns:
            Always True.
        """
        return True
is_local: bool
property
readonly
Checks if this stack component is running locally.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a local component, False otherwise. |
is_remote: bool
property
readonly
Checks if this stack component is running remotely.
Returns:
Type | Description |
---|---|
bool |
True if this config is for a remote component, False otherwise. |
is_schedulable: bool
property
readonly
Whether the orchestrator is schedulable or not.
Returns:
Type | Description |
---|---|
bool |
Whether the orchestrator is schedulable or not. |
DatabricksOrchestratorFlavor (BaseOrchestratorFlavor)
Databricks orchestrator flavor.
Source code in zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py
class DatabricksOrchestratorFlavor(BaseOrchestratorFlavor):
    """Databricks orchestrator flavor."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return DATABRICKS_ORCHESTRATOR_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/databricks.png"

    @property
    def config_class(self) -> Type[DatabricksOrchestratorConfig]:
        """Returns the `DatabricksOrchestratorConfig` config class.

        Returns:
            The config class.
        """
        return DatabricksOrchestratorConfig

    @property
    def implementation_class(self) -> Type["DatabricksOrchestrator"]:
        """Implementation class for this flavor.

        The import is deferred to call time.

        Returns:
            The `DatabricksOrchestrator` class.
        """
        from zenml.integrations.databricks.orchestrators import (
            DatabricksOrchestrator,
        )

        return DatabricksOrchestrator
config_class: Type[zenml.integrations.databricks.flavors.databricks_orchestrator_flavor.DatabricksOrchestratorConfig]
property
readonly
Returns the `DatabricksOrchestratorConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.databricks.flavors.databricks_orchestrator_flavor.DatabricksOrchestratorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[DatabricksOrchestrator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[DatabricksOrchestrator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
DatabricksOrchestratorSettings (BaseSettings)
Databricks orchestrator base settings.
Attributes:
Name | Type | Description |
---|---|---|
spark_version |
Optional[str] |
Spark version. |
num_workers |
Optional[int] |
Number of workers. |
node_type_id |
Optional[str] |
Node type id. |
policy_id |
Optional[str] |
Policy id. |
autotermination_minutes |
Optional[int] |
Autotermination minutes. |
autoscale |
Tuple[int, int] |
Autoscale. |
single_user_name |
Optional[str] |
Single user name. |
spark_conf |
Optional[Dict[str, str]] |
Spark configuration. |
spark_env_vars |
Optional[Dict[str, str]] |
Spark environment variables. |
schedule_timezone |
Optional[str] |
Schedule timezone. |
Source code in zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py
class DatabricksOrchestratorSettings(BaseSettings):
    """Databricks orchestrator base settings.

    Every field is optional; unset fields keep the defaults shown below.

    Attributes:
        spark_version: Spark version.
        num_workers: Number of workers.
        node_type_id: Node type id.
        policy_id: Policy id.
        autotermination_minutes: Autotermination minutes.
        autoscale: Autoscale bounds as a pair of ints, defaulting to
            `(0, 1)`. NOTE(review): presumably `(min_workers, max_workers)`
            — confirm.
        single_user_name: Single user name.
        spark_conf: Spark configuration.
        spark_env_vars: Spark environment variables.
        schedule_timezone: Schedule timezone.
        availability_type: Availability type, one of the
            `DatabricksAvailabilityType` values (`ON_DEMAND`, `SPOT`,
            `SPOT_WITH_FALLBACK`), or None.
    """

    # Resources
    spark_version: Optional[str] = None
    num_workers: Optional[int] = None
    node_type_id: Optional[str] = None
    policy_id: Optional[str] = None
    autotermination_minutes: Optional[int] = None
    autoscale: Tuple[int, int] = (0, 1)
    single_user_name: Optional[str] = None
    spark_conf: Optional[Dict[str, str]] = None
    spark_env_vars: Optional[Dict[str, str]] = None
    schedule_timezone: Optional[str] = None
    availability_type: Optional[DatabricksAvailabilityType] = None
model_deployers
special
Initialization of the Databricks model deployers.
databricks_model_deployer
Implementation of the Databricks Model Deployer.
DatabricksModelDeployer (BaseModelDeployer)
Databricks endpoint model deployer.
Source code in zenml/integrations/databricks/model_deployers/databricks_model_deployer.py
class DatabricksModelDeployer(BaseModelDeployer):
    """Databricks endpoint model deployer."""

    NAME: ClassVar[str] = "Databricks"
    FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = (
        DatabricksModelDeployerFlavor
    )

    @property
    def config(self) -> DatabricksModelDeployerConfig:
        """The configuration of this Databricks model deployer.

        Returns:
            The configuration, cast to `DatabricksModelDeployerConfig`.
        """
        return cast(DatabricksModelDeployerConfig, self._config)

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that checks that this model deployer is configured
            with either a secret name or both a client id and a client
            secret.
        """

        def _validate_if_secret_or_token_is_present(
            stack: "Stack",
        ) -> Tuple[bool, str]:
            """Check if client id and client secret or secret name is present.

            Args:
                stack: The stack to validate. It is not inspected; only this
                    model deployer's own configuration is checked.

            Returns:
                A tuple with a boolean indicating whether the configuration
                is valid and the message to report if it is not.
            """
            return bool(
                (self.config.client_id and self.config.client_secret)
                or self.config.secret_name
            ), (
                "The Databricks model deployer requires either a secret name"
                " or a client id and client secret to be present in the stack."
            )

        return StackValidator(
            custom_validation_function=_validate_if_secret_or_token_is_present,
        )

    def _create_new_service(
        self, id: UUID, timeout: int, config: DatabricksDeploymentConfig
    ) -> DatabricksDeploymentService:
        """Creates and starts a new DatabricksDeploymentService.

        Args:
            id: the UUID of the model to be deployed with Databricks model deployer.
            timeout: the timeout in seconds to wait for the Databricks inference endpoint
                to be provisioned and successfully started or updated.
            config: the configuration of the model to be deployed with Databricks model deployer.

        Returns:
            The started `DatabricksDeploymentService` that can be used to
            interact with the Databricks inference endpoint.
        """
        # create a new service for the new model
        service = DatabricksDeploymentService(uuid=id, config=config)
        logger.info(
            f"Creating an artifact {DATABRICKS_SERVICE_ARTIFACT} with service instance attached as metadata."
            " If there's an active pipeline and/or model this artifact will be associated with it."
        )
        service.start(timeout=timeout)
        return service

    def _clean_up_existing_service(
        self,
        timeout: int,
        force: bool,
        existing_service: DatabricksDeploymentService,
    ) -> None:
        """Stop an existing service.

        Args:
            timeout: the timeout in seconds to wait for the Databricks
                deployment to be stopped.
            force: if True, force the service to stop
            existing_service: Existing Databricks deployment service
        """
        # stop the older service
        existing_service.stop(timeout=timeout, force=force)

    def perform_deploy_model(
        self,
        id: UUID,
        config: ServiceConfig,
        timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create and start a new Databricks deployment service.

        This should serve the supplied model and deployment configuration.

        Args:
            id: the UUID of the model to be deployed with Databricks.
            config: the configuration of the model to be deployed with Databricks.
            timeout: the timeout in seconds to wait for the Databricks endpoint
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the Databricks
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML Databricks deployment service object that can be used to
            interact with the remote Databricks inference endpoint server.
        """
        with track_handler(AnalyticsEvent.MODEL_DEPLOYED) as analytics_handler:
            config = cast(DatabricksDeploymentConfig, config)
            # create a new DatabricksDeploymentService instance
            service = self._create_new_service(
                id=id, timeout=timeout, config=config
            )
            logger.info(
                f"Creating a new Databricks inference endpoint service: {service}"
            )
            # Attach the active stack's component flavors and the store type
            # to the analytics event. One client serves both lookups.
            client = Client()
            stack = client.active_stack
            stack_metadata = {
                component_type.value: component.flavor
                for component_type, component in stack.components.items()
            }
            analytics_handler.metadata = {
                "store_type": client.zen_store.type.value,
                **stack_metadata,
            }
            return service

    def perform_stop_model(
        self,
        service: BaseService,
        timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> BaseService:
        """Method to stop a model server.

        Args:
            service: The service to stop.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.

        Returns:
            The stopped service.
        """
        service.stop(timeout=timeout, force=force)
        return service

    def perform_start_model(
        self,
        service: BaseService,
        timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Method to start a model server.

        Args:
            service: The service to start.
            timeout: Timeout in seconds to wait for the service to start.

        Returns:
            The started service.
        """
        service.start(timeout=timeout)
        return service

    def perform_delete_model(
        self,
        service: BaseService,
        timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Delete a model server by stopping its service.

        Args:
            service: The service to delete.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        service = cast(DatabricksDeploymentService, service)
        self._clean_up_existing_service(
            existing_service=service, timeout=timeout, force=force
        )

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "DatabricksDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information that might be relevant to the user.

        Args:
            service_instance: Instance of a DatabricksDeploymentService

        Returns:
            A dict with the keys `PREDICTION_URL` and `HEALTH_CHECK_URL`.
        """
        return {
            "PREDICTION_URL": service_instance.get_prediction_url(),
            "HEALTH_CHECK_URL": service_instance.get_healthcheck_url(),
        }
config: DatabricksModelDeployerConfig
property
readonly
Config class for the Databricks Model deployer settings class.
Returns:
Type | Description |
---|---|
DatabricksModelDeployerConfig |
The configuration. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A validator that checks that either a secret name or both a client id and a client secret are configured. |
FLAVOR (BaseModelDeployerFlavor)
Databricks Endpoint model deployer flavor.
Source code in zenml/integrations/databricks/model_deployers/databricks_model_deployer.py
class DatabricksModelDeployerFlavor(BaseModelDeployerFlavor):
    """Flavor that exposes the Databricks endpoint model deployer to ZenML."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The Databricks model deployer flavor name.
        """
        flavor_name = DATABRICKS_MODEL_DEPLOYER_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            The default docs url generated by the base flavor.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK reference documentation of this flavor.

        Returns:
            The default SDK docs url generated by the base flavor.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Image shown for this flavor in the dashboard.

        Returns:
            The public URL of the Databricks model deployer logo.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/databricks.png"
        return logo

    @property
    def config_class(self) -> Type[DatabricksModelDeployerConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `DatabricksModelDeployerConfig` class.
        """
        model_cls = DatabricksModelDeployerConfig
        return model_cls

    @property
    def implementation_class(self) -> Type["DatabricksModelDeployer"]:
        """Stack component class that implements this flavor.

        The import is deferred to call time.

        Returns:
            The `DatabricksModelDeployer` class.
        """
        from zenml.integrations.databricks.model_deployers.databricks_model_deployer import (
            DatabricksModelDeployer as deployer_cls,
        )

        return deployer_cls
config_class: Type[zenml.integrations.databricks.flavors.databricks_model_deployer_flavor.DatabricksModelDeployerConfig]
property
readonly
Returns the `DatabricksModelDeployerConfig` config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.databricks.flavors.databricks_model_deployer_flavor.DatabricksModelDeployerConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[DatabricksModelDeployer]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[DatabricksModelDeployer] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
get_model_server_info(service_instance)
staticmethod
Return implementation specific information that might be relevant to the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service_instance |
DatabricksDeploymentService |
Instance of a DatabricksDeploymentService |
required |
Returns:
Type | Description |
---|---|
Dict[str, Optional[str]] |
Model server information. |
Source code in zenml/integrations/databricks/model_deployers/databricks_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "DatabricksDeploymentService",
) -> Dict[str, Optional[str]]:
    """Describe where a deployed Databricks endpoint can be reached.

    Args:
        service_instance: The deployment service to describe.

    Returns:
        A dict with the keys ``PREDICTION_URL`` and ``HEALTH_CHECK_URL``,
        holding whatever the service's URL getters return.
    """
    prediction_url = service_instance.get_prediction_url()
    health_check_url = service_instance.get_healthcheck_url()
    return {
        "PREDICTION_URL": prediction_url,
        "HEALTH_CHECK_URL": health_check_url,
    }
perform_delete_model(self, service, timeout=300, force=False)
Method to delete all configuration of a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service |
BaseService |
The service to delete. |
required |
timeout |
int |
Timeout in seconds to wait for the service to stop. |
300 |
force |
bool |
If True, force the service to stop. |
False |
Source code in zenml/integrations/databricks/model_deployers/databricks_model_deployer.py
def perform_delete_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Tear down a deployed Databricks model server.

    Deletion is delegated to `_clean_up_existing_service`.

    Args:
        service: The deployment service to remove.
        timeout: Seconds to wait for the service to stop.
        force: Whether to force the service to stop.
    """
    databricks_service = cast(DatabricksDeploymentService, service)
    self._clean_up_existing_service(
        timeout=timeout,
        force=force,
        existing_service=databricks_service,
    )
perform_deploy_model(self, id, config, timeout=300)
Create and start a new Databricks deployment service.
This should serve the supplied model and deployment configuration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id |
UUID |
the UUID of the model to be deployed with Databricks. |
required |
config |
ServiceConfig |
the configuration of the model to be deployed with Databricks. |
required |
timeout |
int |
the timeout in seconds to wait for the Databricks endpoint to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the Databricks server is provisioned, without waiting for it to fully start. |
300 |
Returns:
Type | Description |
---|---|
BaseService |
The ZenML Databricks deployment service object that can be used to interact with the remote Databricks inference endpoint server. |
Source code in zenml/integrations/databricks/model_deployers/databricks_model_deployer.py
def perform_deploy_model(
    self,
    id: UUID,
    config: ServiceConfig,
    timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
    """Create and start a new Databricks deployment service.

    This should serve the supplied model and deployment configuration.

    Args:
        id: the UUID of the model to be deployed with Databricks.
        config: the configuration of the model to be deployed with Databricks.
        timeout: the timeout in seconds to wait for the Databricks endpoint
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the Databricks
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML Databricks deployment service object that can be used to
        interact with the remote Databricks inference endpoint server.
    """
    with track_handler(AnalyticsEvent.MODEL_DEPLOYED) as analytics_handler:
        config = cast(DatabricksDeploymentConfig, config)
        # create a new DatabricksDeploymentService instance
        service = self._create_new_service(
            id=id, timeout=timeout, config=config
        )
        logger.info(
            f"Creating a new Databricks inference endpoint service: {service}"
        )
        # Attach the active stack's component flavors and the store type to
        # the analytics event. One client serves both lookups.
        client = Client()
        stack = client.active_stack
        stack_metadata = {
            component_type.value: component.flavor
            for component_type, component in stack.components.items()
        }
        analytics_handler.metadata = {
            "store_type": client.zen_store.type.value,
            **stack_metadata,
        }
        return service
perform_start_model(self, service, timeout=300)
Method to start a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service |
BaseService |
The service to start. |
required |
timeout |
int |
Timeout in seconds to wait for the service to start. |
300 |
Returns:
Type | Description |
---|---|
BaseService |
The started service. |
Source code in zenml/integrations/databricks/model_deployers/databricks_model_deployer.py
def perform_start_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
    """Start a Databricks model server.

    Args:
        service: The deployment service to start.
        timeout: Seconds to wait for the service to start.

    Returns:
        The same service object, after `start` has been called on it.
    """
    started_service = service
    started_service.start(timeout=timeout)
    return started_service
perform_stop_model(self, service, timeout=300, force=False)
Method to stop a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service |
BaseService |
The service to stop. |
required |
timeout |
int |
Timeout in seconds to wait for the service to stop. |
300 |
force |
bool |
If True, force the service to stop. |
False |
Returns:
Type | Description |
---|---|
BaseService |
The stopped service. |
Source code in zenml/integrations/databricks/model_deployers/databricks_model_deployer.py
def perform_stop_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> BaseService:
    """Stop a Databricks model server.

    Args:
        service: The deployment service to stop.
        timeout: Seconds to wait for the service to stop.
        force: Whether to force the service to stop.

    Returns:
        The same service object, after `stop` has been called on it.
    """
    stopped_service = service
    stopped_service.stop(timeout=timeout, force=force)
    return stopped_service
orchestrators
special
Initialization of the Databricks ZenML orchestrator.
databricks_orchestrator
Implementation of the Databricks orchestrator.
DatabricksOrchestrator (WheeledOrchestrator)
Base class for Orchestrator responsible for running pipelines remotely in a VM.
This orchestrator supports scheduled runs; schedules must set a `cron_expression` (optionally with `start_time` and/or `end_time`).
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator.py
class DatabricksOrchestrator(WheeledOrchestrator):
"""Base class for Orchestrator responsible for running pipelines remotely in a VM.
This orchestrator does not support running on a schedule.
"""
# The default instance type to use if none is specified in settings
DEFAULT_INSTANCE_TYPE: Optional[str] = None
@property
def validator(self) -> Optional[StackValidator]:
    """Validates the stack.

    Pipeline steps run remotely on Databricks, so the validator rejects any
    stack in which a component's config reports `is_local`.

    Returns:
        A `StackValidator` instance.
    """

    def _validate_remote_components(
        stack: "Stack",
    ) -> Tuple[bool, str]:
        """Reject stacks that contain a local stack component.

        Args:
            stack: The stack to validate.

        Returns:
            `(True, "")` if no component is local; otherwise `False` and a
            message naming the first local component found.
        """
        for component in stack.components.values():
            if component.config.is_local:
                return False, (
                    f"The Databricks orchestrator runs pipelines remotely, "
                    f"but the '{component.name}' {component.type.value} is "
                    "a local stack component and will not be available in "
                    "the Databricks step.\nPlease ensure that you always "
                    "use non-local stack components with the Databricks "
                    "orchestrator."
                )
        return True, ""

    return StackValidator(
        custom_validation_function=_validate_remote_components,
    )
def _get_databricks_client(
    self,
) -> DatabricksClient:
    """Build a Databricks client from this orchestrator's configuration.

    Returns:
        A `DatabricksClient` configured with the configured host, client id
        and client secret.
    """
    cfg = self.config
    return DatabricksClient(
        host=cfg.host,
        client_id=cfg.client_id,
        client_secret=cfg.client_secret,
    )
@property
def config(self) -> DatabricksOrchestratorConfig:
    """Typed view of this orchestrator's configuration.

    Returns:
        `self._config`, cast to `DatabricksOrchestratorConfig`.
    """
    typed_config = cast(DatabricksOrchestratorConfig, self._config)
    return typed_config
@property
def settings_class(self) -> Type[DatabricksOrchestratorSettings]:
    """The settings model accepted by the Databricks orchestrator.

    Returns:
        The `DatabricksOrchestratorSettings` class.
    """
    settings_model = DatabricksOrchestratorSettings
    return settings_model
def get_orchestrator_run_id(self) -> str:
    """Returns the active orchestrator run id.

    The run id is read from the environment variable named by
    `ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID`.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the environment variable is not set. This happens
            when this method gets called while the orchestrator is not
            running a pipeline.
    """
    try:
        return os.environ[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID]
    except KeyError:
        # The KeyError only repeats the variable name already in the message.
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID}."
        ) from None
@property
def root_directory(self) -> str:
    """Directory under which this orchestrator keeps its local files.

    Returns:
        `<global config dir>/databricks/<orchestrator id>`.
    """
    global_config_dir = io_utils.get_global_config_directory()
    orchestrator_id = str(self.id)
    return os.path.join(global_config_dir, "databricks", orchestrator_id)
@property
def pipeline_directory(self) -> str:
    """Returns the path to the directory in which the Databricks pipeline files are stored.

    Returns:
        The `pipelines` subdirectory of `root_directory`.
    """
    return os.path.join(self.root_directory, "pipelines")
def setup_credentials(self) -> None:
    """Set up credentials for the orchestrator.

    Configures the local client of the service connector linked to this
    orchestrator.

    Raises:
        RuntimeError: If no service connector is linked to this
            orchestrator.
    """
    connector = self.get_connector()
    if connector is None:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this into an AttributeError.
        raise RuntimeError(
            "Unable to set up credentials for the Databricks orchestrator: "
            "no service connector is linked to this stack component."
        )
    connector.configure_local_client()
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
    environment: Dict[str, str],
) -> Any:
    """Creates a wheel and uploads the pipeline to Databricks.

    How it works:
    -------------
    The project repository is copied to a temporary directory, packaged
    as a wheel and copied to a Databricks volume. The nested callable
    `_construct_databricks_pipeline` then builds one Databricks task per
    step. Each task runs `DatabricksEntrypointConfiguration` for its step
    from the uploaded wheel, carries the step's Python requirements, and
    depends on the tasks of the step's upstream steps. Finally, the tasks
    are submitted as a Databricks job (with a cron schedule if one is
    given) and a run of that job is triggered.

    Args:
        deployment: The pipeline deployment to prepare or run.
        stack: The stack the pipeline will run on. Step requirements are
            currently gathered from the client's active stack instead.
        environment: Environment variables to set in the orchestration
            environment.

    Raises:
        ValueError: If a schedule is given without a `cron_expression`, or
            with a `cron_expression` while `schedule_timezone` is not set in
            the orchestrator settings.
    """
    settings = cast(
        DatabricksOrchestratorSettings, self.get_settings(deployment)
    )
    schedule = deployment.schedule
    if schedule:
        if schedule.catchup or schedule.interval_second:
            logger.warning(
                "Databricks orchestrator only uses schedules with the "
                "`cron_expression` property, with optional `start_time` and/or `end_time`. "
                "All other properties are ignored."
            )
        if schedule.cron_expression is None:
            raise ValueError(
                "Property `cron_expression` must be set when passing "
                "schedule to a Databricks orchestrator."
            )
        if schedule.cron_expression and settings.schedule_timezone is None:
            # Each literal ends with a space so the concatenated sentences
            # do not run together.
            raise ValueError(
                "Property `schedule_timezone` must be set when passing "
                "`cron_expression` to a Databricks orchestrator. "
                "Databricks orchestrator requires a Java Timezone ID to run the pipeline on schedule. "
                "Please refer to https://docs.oracle.com/middleware/1221/wcs/tag-ref/MISC/TimeZones.html for more information."
            )

    deployment_id = deployment.id

    def _construct_databricks_pipeline(
        zenml_project_wheel: str, job_cluster_key: str
    ) -> List[DatabricksTask]:
        """Create a Databricks task for each step.

        Each task is named `<deployment_id>_<step_name>`, runs the ZenML
        step entrypoint for its step, and depends on the tasks of the
        step's direct upstream steps.

        Args:
            zenml_project_wheel: The wheel package containing the ZenML
                project.
            job_cluster_key: The ID of the Databricks job cluster.

        Returns:
            A list of Databricks tasks.
        """
        # Loop-invariant: the same builder and stack are used to gather
        # requirements for every step.
        docker_image_builder = PipelineDockerImageBuilder()
        active_stack = Client().active_stack

        tasks = []
        for step_name, step in deployment.step_configurations.items():
            # Arguments that configure the entrypoint to run this step.
            arguments = DatabricksEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name,
                deployment_id=deployment_id,
                wheel_package=self.package_name,
                databricks_job_id=DATABRICKS_JOB_ID_PARAMETER_REFERENCE,
            )
            # This task runs after the tasks of its upstream steps.
            upstream_steps = [
                f"{deployment_id}_{upstream_step_name}"
                for upstream_step_name in step.spec.upstream_steps
            ]
            requirements_files = docker_image_builder.gather_requirements_files(
                docker_settings=step.config.docker_settings,
                stack=active_stack,
                log=False,
            )
            # Index 1 of each entry is treated as file content: split it
            # into lines, then drop empty lines and duplicates.
            requirements = sorted(
                {
                    line
                    for requirements_file in requirements_files
                    for line in requirements_file[1].strip().split("\n")
                    if line
                }
            )
            tasks.append(
                convert_step_to_task(
                    f"{deployment_id}_{step_name}",
                    ZENML_STEP_DEFAULT_ENTRYPOINT_COMMAND,
                    arguments,
                    clean_requirements(requirements),
                    depends_on=upstream_steps,
                    zenml_project_wheel=zenml_project_wheel,
                    job_cluster_key=job_cluster_key,
                )
            )
        return tasks

    orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )

    # NOTE(review): this path is only used in the log message below; this
    # method does not write a workflow file to it.
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory, f"{orchestrator_run_name}.yaml"
    )

    # Package the repository as a wheel and upload it to a Databricks
    # volume. The temporary copy is removed even if this fails.
    repository_temp_dir = self.copy_repository_to_temp_dir_and_add_setup_py()
    try:
        wheel_path = self.create_wheel(temp_dir=repository_temp_dir)
        databricks_client = self._get_databricks_client()
        deployment_name = (
            deployment.pipeline.name if deployment.pipeline else "default"
        )
        databricks_directory = f"{DATABRICKS_WHEELS_DIRECTORY_PREFIX}/{deployment_name}/{orchestrator_run_name}"
        databricks_wheel_path = (
            f"{databricks_directory}/{wheel_path.rsplit('/', 1)[-1]}"
        )
        databricks_client.dbutils.fs.mkdirs(databricks_directory)
        databricks_client.dbutils.fs.cp(
            f"{DATABRICKS_LOCAL_FILESYSTEM_PREFIX}/{wheel_path}",
            databricks_wheel_path,
        )
    finally:
        fileio.rmtree(repository_temp_dir)

    # Environment for the job cluster: the orchestration environment,
    # overridden by user-provided Spark env vars, plus the source root.
    env_vars = environment.copy()
    if settings.spark_env_vars:
        env_vars.update(settings.spark_env_vars)
    env_vars[ENV_ZENML_CUSTOM_SOURCE_ROOT] = (
        DATABRICKS_ZENML_DEFAULT_CUSTOM_REPOSITORY_PATH
    )

    logger.info(
        "Writing Databricks workflow definition to `%s`.",
        pipeline_file_path,
    )

    # Create the Databricks job from the tasks and trigger it.
    job_cluster_key = self.sanitize_name(f"{deployment_id}")
    self._upload_and_run_pipeline(
        pipeline_name=orchestrator_run_name,
        settings=settings,
        tasks=_construct_databricks_pipeline(
            databricks_wheel_path, job_cluster_key
        ),
        env_vars=env_vars,
        job_cluster_key=job_cluster_key,
        schedule=deployment.schedule,
    )
def _upload_and_run_pipeline(
    self,
    pipeline_name: str,
    settings: DatabricksOrchestratorSettings,
    tasks: List[DatabricksTask],
    env_vars: Dict[str, str],
    job_cluster_key: str,
    schedule: Optional["ScheduleResponse"] = None,
) -> None:
    """Creates a Databricks job for the pipeline and triggers a run of it.

    The job uses a single job cluster (`job_cluster_key`) configured from
    the orchestrator settings. If the schedule has a `cron_expression`,
    the job is created with a Quartz cron schedule in
    `settings.schedule_timezone`. `run_now` is called on the created job
    in every case.

    Args:
        pipeline_name: The name of the Databricks job.
        settings: The settings for the Databricks orchestrator.
        tasks: The list of tasks to run.
        env_vars: The environment variables to set on the job cluster.
        job_cluster_key: The key of the Databricks job cluster.
        schedule: The schedule to run the pipeline on, if any.

    Raises:
        ValueError: If no `policy_id` is set in the settings and no cluster
            policy named `Job Compute` exists in the workspace.
        ValueError: If the schedule has a `cron_expression` but
            `schedule_timezone` is not set in the settings.
        ValueError: If Databricks returns no id for the created job.
    """
    databricks_client = self._get_databricks_client()

    # Copy so that the user's settings object is not mutated in place.
    spark_conf = dict(settings.spark_conf or {})
    spark_conf[
        "spark.databricks.driver.dbfsLibraryInstallationAllowed"
    ] = "true"

    # An explicitly configured policy takes precedence. Fall back to the
    # workspace's `Job Compute` policy only when none is configured.
    policy_id = settings.policy_id or None
    if policy_id is None:
        for policy in databricks_client.cluster_policies.list():
            if policy.name == "Job Compute":
                policy_id = policy.policy_id
                break
    if policy_id is None:
        raise ValueError(
            "Could not find the `Job Compute` policy in Databricks."
        )

    job_cluster = JobCluster(
        job_cluster_key=job_cluster_key,
        new_cluster=ClusterSpec(
            spark_version=settings.spark_version
            or DATABRICKS_SPARK_DEFAULT_VERSION,
            num_workers=settings.num_workers,
            node_type_id=settings.node_type_id or "Standard_D4s_v5",
            policy_id=policy_id,
            autoscale=AutoScale(
                min_workers=settings.autoscale[0],
                max_workers=settings.autoscale[1],
            ),
            single_user_name=settings.single_user_name,
            spark_env_vars=env_vars,
            spark_conf=spark_conf,
            workload_type=WorkloadType(
                clients=ClientsTypes(jobs=True, notebooks=False)
            ),
        ),
    )

    databricks_schedule = None
    if schedule and schedule.cron_expression:
        if not settings.schedule_timezone:
            raise ValueError(
                "Property `schedule_timezone` must be set when passing "
                "`cron_expression` to a Databricks orchestrator. "
                "Databricks orchestrator requires a Java Timezone ID to run the pipeline on schedule. "
                "Please refer to https://docs.oracle.com/middleware/1221/wcs/tag-ref/MISC/TimeZones.html for more information."
            )
        databricks_schedule = CronSchedule(
            quartz_cron_expression=schedule.cron_expression,
            timezone_id=settings.schedule_timezone,
        )

    job = databricks_client.jobs.create(
        name=pipeline_name,
        tasks=tasks,
        job_clusters=[job_cluster],
        schedule=databricks_schedule,
    )
    if not job.job_id:
        raise ValueError("An error occurred while getting the job id.")
    databricks_client.jobs.run_now(job_id=job.job_id)
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Return orchestrator metadata for a pipeline run.

    The only entry is the Databricks job URL. It is built from the
    configured host and the orchestrator run id.

    Args:
        run_id: The ID of the pipeline run (not used to build the URL).

    Returns:
        A dictionary mapping `METADATA_ORCHESTRATOR_URL` to the job URL.
    """
    host = self.config.host
    job_id = self.get_orchestrator_run_id()
    run_url = f"{host}/jobs/{job_id}"
    return {METADATA_ORCHESTRATOR_URL: Uri(run_url)}
config: DatabricksOrchestratorConfig
property
readonly
Returns the DatabricksOrchestratorConfig
config.
Returns:
Type | Description |
---|---|
DatabricksOrchestratorConfig |
The configuration. |
pipeline_directory: str
property
readonly
Returns path to a directory in which the Databricks pipeline files are stored.
Returns:
Type | Description |
---|---|
str |
Path to the pipeline directory. |
root_directory: str
property
readonly
Path to the root directory for all files concerning this orchestrator.
Returns:
Type | Description |
---|---|
str |
Path to the root directory. |
settings_class: Type[zenml.integrations.databricks.flavors.databricks_orchestrator_flavor.DatabricksOrchestratorSettings]
property
readonly
Settings class for the Databricks orchestrator.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.databricks.flavors.databricks_orchestrator_flavor.DatabricksOrchestratorSettings] |
The settings class. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates the stack.
In the remote case, checks that the stack contains a container registry, image builder and only remote components.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] |
A `StackValidator` instance. |
get_orchestrator_run_id(self)
Returns the active orchestrator run id.
Exceptions:
Type | Description |
---|---|
RuntimeError |
If no run id exists. This happens when this method gets called while the orchestrator is not running a pipeline. |
Returns:
Type | Description |
---|---|
str |
The orchestrator run id. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the run id cannot be read from the environment. |
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Returns the active orchestrator run id.

    The run id is read from the environment variable named by
    `ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID`. The Databricks step
    entrypoint sets that variable to the Databricks job id before running
    a step.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the environment variable is not set. This happens
            when this method is called while the orchestrator is not
            running a pipeline.
    """
    try:
        return os.environ[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID]
    except KeyError:
        # The message already names the missing variable, so the KeyError
        # context adds nothing for the user.
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID}."
        ) from None
get_pipeline_run_metadata(self, run_id)
Get general component-specific metadata for a pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_id |
UUID |
The ID of the pipeline run. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
A dictionary of metadata. |
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Return orchestrator metadata for a pipeline run.

    The only entry is the Databricks job URL. It is built from the
    configured host and the orchestrator run id.

    Args:
        run_id: The ID of the pipeline run (not used to build the URL).

    Returns:
        A dictionary mapping `METADATA_ORCHESTRATOR_URL` to the job URL.
    """
    host = self.config.host
    job_id = self.get_orchestrator_run_id()
    run_url = f"{host}/jobs/{job_id}"
    return {METADATA_ORCHESTRATOR_URL: Uri(run_url)}
prepare_or_run_pipeline(self, deployment, stack, environment)
Creates a wheel and uploads the pipeline to Databricks.
This builds an intermediary representation of the pipeline (one Databricks task per step), which is then deployed as a Databricks job.
How it works:
Before this method is called the prepare_pipeline_deployment()
method builds a docker image that contains the code for the
pipeline, all steps and the context around these files.
Based on this docker image a callable is created which builds
a task for each step (_construct_databricks_pipeline).
To do this the entrypoint of the docker image is configured to
run the correct step within the docker image. The dependencies
between these tasks are then also configured onto each
task by pointing at the upstream steps.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
deployment |
PipelineDeploymentResponse |
The pipeline deployment to prepare or run. |
required |
stack |
Stack |
The stack the pipeline will run on. |
required |
environment |
Dict[str, str] |
Environment variables to set in the orchestration environment. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If a schedule is given without a cron expression, or with a cron expression but without a schedule timezone in the orchestrator settings. |
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponse",
    stack: "Stack",
    environment: Dict[str, str],
) -> Any:
    """Creates a wheel and uploads the pipeline to Databricks.

    How it works:
    -------------
    The project repository is copied to a temporary directory, packaged
    as a wheel and copied to a Databricks volume. The nested callable
    `_construct_databricks_pipeline` then builds one Databricks task per
    step. Each task runs `DatabricksEntrypointConfiguration` for its step
    from the uploaded wheel, carries the step's Python requirements, and
    depends on the tasks of the step's upstream steps. Finally, the tasks
    are submitted as a Databricks job (with a cron schedule if one is
    given) and a run of that job is triggered.

    Args:
        deployment: The pipeline deployment to prepare or run.
        stack: The stack the pipeline will run on. Step requirements are
            currently gathered from the client's active stack instead.
        environment: Environment variables to set in the orchestration
            environment.

    Raises:
        ValueError: If a schedule is given without a `cron_expression`, or
            with a `cron_expression` while `schedule_timezone` is not set in
            the orchestrator settings.
    """
    settings = cast(
        DatabricksOrchestratorSettings, self.get_settings(deployment)
    )
    schedule = deployment.schedule
    if schedule:
        if schedule.catchup or schedule.interval_second:
            logger.warning(
                "Databricks orchestrator only uses schedules with the "
                "`cron_expression` property, with optional `start_time` and/or `end_time`. "
                "All other properties are ignored."
            )
        if schedule.cron_expression is None:
            raise ValueError(
                "Property `cron_expression` must be set when passing "
                "schedule to a Databricks orchestrator."
            )
        if schedule.cron_expression and settings.schedule_timezone is None:
            # Each literal ends with a space so the concatenated sentences
            # do not run together.
            raise ValueError(
                "Property `schedule_timezone` must be set when passing "
                "`cron_expression` to a Databricks orchestrator. "
                "Databricks orchestrator requires a Java Timezone ID to run the pipeline on schedule. "
                "Please refer to https://docs.oracle.com/middleware/1221/wcs/tag-ref/MISC/TimeZones.html for more information."
            )

    deployment_id = deployment.id

    def _construct_databricks_pipeline(
        zenml_project_wheel: str, job_cluster_key: str
    ) -> List[DatabricksTask]:
        """Create a Databricks task for each step.

        Each task is named `<deployment_id>_<step_name>`, runs the ZenML
        step entrypoint for its step, and depends on the tasks of the
        step's direct upstream steps.

        Args:
            zenml_project_wheel: The wheel package containing the ZenML
                project.
            job_cluster_key: The ID of the Databricks job cluster.

        Returns:
            A list of Databricks tasks.
        """
        # Loop-invariant: the same builder and stack are used to gather
        # requirements for every step.
        docker_image_builder = PipelineDockerImageBuilder()
        active_stack = Client().active_stack

        tasks = []
        for step_name, step in deployment.step_configurations.items():
            # Arguments that configure the entrypoint to run this step.
            arguments = DatabricksEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name,
                deployment_id=deployment_id,
                wheel_package=self.package_name,
                databricks_job_id=DATABRICKS_JOB_ID_PARAMETER_REFERENCE,
            )
            # This task runs after the tasks of its upstream steps.
            upstream_steps = [
                f"{deployment_id}_{upstream_step_name}"
                for upstream_step_name in step.spec.upstream_steps
            ]
            requirements_files = docker_image_builder.gather_requirements_files(
                docker_settings=step.config.docker_settings,
                stack=active_stack,
                log=False,
            )
            # Index 1 of each entry is treated as file content: split it
            # into lines, then drop empty lines and duplicates.
            requirements = sorted(
                {
                    line
                    for requirements_file in requirements_files
                    for line in requirements_file[1].strip().split("\n")
                    if line
                }
            )
            tasks.append(
                convert_step_to_task(
                    f"{deployment_id}_{step_name}",
                    ZENML_STEP_DEFAULT_ENTRYPOINT_COMMAND,
                    arguments,
                    clean_requirements(requirements),
                    depends_on=upstream_steps,
                    zenml_project_wheel=zenml_project_wheel,
                    job_cluster_key=job_cluster_key,
                )
            )
        return tasks

    orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )

    # NOTE(review): this path is only used in the log message below; this
    # method does not write a workflow file to it.
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory, f"{orchestrator_run_name}.yaml"
    )

    # Package the repository as a wheel and upload it to a Databricks
    # volume. The temporary copy is removed even if this fails.
    repository_temp_dir = self.copy_repository_to_temp_dir_and_add_setup_py()
    try:
        wheel_path = self.create_wheel(temp_dir=repository_temp_dir)
        databricks_client = self._get_databricks_client()
        deployment_name = (
            deployment.pipeline.name if deployment.pipeline else "default"
        )
        databricks_directory = f"{DATABRICKS_WHEELS_DIRECTORY_PREFIX}/{deployment_name}/{orchestrator_run_name}"
        databricks_wheel_path = (
            f"{databricks_directory}/{wheel_path.rsplit('/', 1)[-1]}"
        )
        databricks_client.dbutils.fs.mkdirs(databricks_directory)
        databricks_client.dbutils.fs.cp(
            f"{DATABRICKS_LOCAL_FILESYSTEM_PREFIX}/{wheel_path}",
            databricks_wheel_path,
        )
    finally:
        fileio.rmtree(repository_temp_dir)

    # Environment for the job cluster: the orchestration environment,
    # overridden by user-provided Spark env vars, plus the source root.
    env_vars = environment.copy()
    if settings.spark_env_vars:
        env_vars.update(settings.spark_env_vars)
    env_vars[ENV_ZENML_CUSTOM_SOURCE_ROOT] = (
        DATABRICKS_ZENML_DEFAULT_CUSTOM_REPOSITORY_PATH
    )

    logger.info(
        "Writing Databricks workflow definition to `%s`.",
        pipeline_file_path,
    )

    # Create the Databricks job from the tasks and trigger it.
    job_cluster_key = self.sanitize_name(f"{deployment_id}")
    self._upload_and_run_pipeline(
        pipeline_name=orchestrator_run_name,
        settings=settings,
        tasks=_construct_databricks_pipeline(
            databricks_wheel_path, job_cluster_key
        ),
        env_vars=env_vars,
        job_cluster_key=job_cluster_key,
        schedule=deployment.schedule,
    )
setup_credentials(self)
Set up credentials for the orchestrator.
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator.py
def setup_credentials(self) -> None:
    """Set up credentials for the orchestrator.

    Configures the local client of the service connector linked to this
    orchestrator.

    Raises:
        RuntimeError: If no service connector is linked to this
            orchestrator.
    """
    connector = self.get_connector()
    if connector is None:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this into an AttributeError.
        raise RuntimeError(
            "Unable to set up credentials for the Databricks orchestrator: "
            "no service connector is linked to this stack component."
        )
    connector.configure_local_client()
databricks_orchestrator_entrypoint_config
Entrypoint configuration for ZenML Databricks pipeline steps.
DatabricksEntrypointConfiguration (StepEntrypointConfiguration)
Entrypoint configuration for ZenML Databricks pipeline steps.
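In addition to the standard step entrypoint options, this configuration takes the name of the wheel package that contains the project code and the Databricks job id. Before running the step, it adds the project code from the installed wheel to the Python path and exposes the job id as the orchestrator run id through an environment variable.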
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py
class DatabricksEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for ZenML Databricks pipeline steps.

    In addition to the standard step entrypoint options, this configuration
    takes the name of the wheel package that contains the project code and
    the Databricks job id. Before running the step, it adds the project code
    from the installed wheel to `sys.path` and exposes the job id as the
    orchestrator run id through an environment variable.
    """

    @classmethod
    def get_entrypoint_options(cls) -> Set[str]:
        """Gets all options required for running with this configuration.

        Returns:
            The superclass options plus the wheel package option and the
            Databricks job id option.
        """
        return (
            super().get_entrypoint_options()
            | {WHEEL_PACKAGE_OPTION}
            | {DATABRICKS_JOB_ID_OPTION}
        )

    @classmethod
    def get_entrypoint_arguments(
        cls,
        **kwargs: Any,
    ) -> List[str]:
        """Gets all arguments that the entrypoint command should be called with.

        The argument list should be something that
        `argparse.ArgumentParser.parse_args(...)` can handle (e.g.
        `["--some_option", "some_value"]` or `["--some_option=some_value"]`).
        It needs to provide values for all options returned by the
        `get_entrypoint_options()` method of this class.

        Args:
            **kwargs: Kwargs required by the superclass (e.g. the step name),
                plus values for the `WHEEL_PACKAGE_OPTION` and
                `DATABRICKS_JOB_ID_OPTION` keys.

        Returns:
            The superclass arguments followed by the wheel package and
            Databricks job id arguments.
        """
        return super().get_entrypoint_arguments(**kwargs) + [
            f"--{WHEEL_PACKAGE_OPTION}",
            kwargs[WHEEL_PACKAGE_OPTION],
            f"--{DATABRICKS_JOB_ID_OPTION}",
            kwargs[DATABRICKS_JOB_ID_OPTION],
        ]

    def run(self) -> None:
        """Runs the step.

        Locates the installed distribution named by the wheel package
        option and prepends `<distribution location>/<wheel package>` to
        `sys.path`. It then exports the Databricks job id under
        `ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID` and runs the step via
        the superclass.
        """
        wheel_package = self.entrypoint_args[WHEEL_PACKAGE_OPTION]
        distribution = pkg_resources.get_distribution(wheel_package)
        project_root = os.path.join(distribution.location, wheel_package)
        if project_root not in sys.path:
            # A single insert at the front is enough: imports resolve
            # against the first matching `sys.path` entry.
            sys.path.insert(0, project_root)

        # Expose the job id so `get_orchestrator_run_id` can read it.
        databricks_job_id = self.entrypoint_args[DATABRICKS_JOB_ID_OPTION]
        os.environ[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID] = (
            databricks_job_id
        )

        super().run()
get_entrypoint_arguments(**kwargs)
classmethod
Gets all arguments that the entrypoint command should be called with.
The argument list should be something that argparse.ArgumentParser.parse_args(...) can handle (e.g. ["--some_option", "some_value"] or ["--some_option=some_value"]). It needs to provide values for all options returned by the get_entrypoint_options() method of this class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Kwargs, must include the step name. |
{} |
Returns:
Type | Description |
---|---|
List[str] |
The superclass arguments as well as arguments for the wheel package. |
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py
@classmethod
def get_entrypoint_arguments(
    cls,
    **kwargs: Any,
) -> List[str]:
    """Gets all arguments that the entrypoint command should be called with.

    The argument list should be something that
    `argparse.ArgumentParser.parse_args(...)` can handle (e.g.
    `["--some_option", "some_value"]` or `["--some_option=some_value"]`).
    It needs to provide values for all options returned by the
    `get_entrypoint_options()` method of this class.

    Args:
        **kwargs: Kwargs required by the superclass (e.g. the step name),
            plus values for the `WHEEL_PACKAGE_OPTION` and
            `DATABRICKS_JOB_ID_OPTION` keys.

    Returns:
        The superclass arguments followed by the wheel package and
        Databricks job id arguments.
    """
    return super().get_entrypoint_arguments(**kwargs) + [
        f"--{WHEEL_PACKAGE_OPTION}",
        kwargs[WHEEL_PACKAGE_OPTION],
        f"--{DATABRICKS_JOB_ID_OPTION}",
        kwargs[DATABRICKS_JOB_ID_OPTION],
    ]
get_entrypoint_options()
classmethod
Gets all options required for running with this configuration.
Returns:
Type | Description |
---|---|
Set[str] |
The superclass options as well as an option for the wheel package. |
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
    """Return the option names this entrypoint configuration requires.

    Returns:
        The superclass options together with the wheel package option and
        the Databricks job id option.
    """
    extra_options = {WHEEL_PACKAGE_OPTION, DATABRICKS_JOB_ID_OPTION}
    return super().get_entrypoint_options() | extra_options
run(self)
Runs the step.
Source code in zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py
def run(self) -> None:
    """Runs the step.

    Locates the installed distribution named by the wheel package option
    and prepends `<distribution location>/<wheel package>` to `sys.path`.
    It then exports the Databricks job id under
    `ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID` and runs the step via the
    superclass.
    """
    wheel_package = self.entrypoint_args[WHEEL_PACKAGE_OPTION]
    distribution = pkg_resources.get_distribution(wheel_package)
    project_root = os.path.join(distribution.location, wheel_package)
    if project_root not in sys.path:
        # A single insert at the front is enough: imports resolve against
        # the first matching `sys.path` entry.
        sys.path.insert(0, project_root)

    # Expose the job id so `get_orchestrator_run_id` can read it.
    databricks_job_id = self.entrypoint_args[DATABRICKS_JOB_ID_OPTION]
    os.environ[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID] = (
        databricks_job_id
    )

    super().run()
services
special
Initialization of the Databricks Service.
databricks_deployment
Implementation of the Databricks Deployment service.
DatabricksDeploymentConfig (DatabricksBaseConfig, ServiceConfig)
Databricks service configurations.
Source code in zenml/integrations/databricks/services/databricks_deployment.py
class DatabricksDeploymentConfig(DatabricksBaseConfig, ServiceConfig):
    """Service configuration for a Databricks model deployment."""

    model_uri: Optional[str] = Field(
        None,
        description="URI of the model to deploy. This can be a local path or a cloud storage path.",
    )
    host: Optional[str] = Field(
        None, description="Databricks host URL for the deployment."
    )

    def get_databricks_deployment_labels(self) -> Dict[str, str]:
        """Build the deployment labels from this service configuration.

        Only fields with a truthy value produce a label. The result is
        passed through `sanitize_labels` before being returned.

        Returns:
            A dict of label name to label value.
        """
        candidates = (
            ("zenml_pipeline_name", self.pipeline_name),
            ("zenml_pipeline_step_name", self.pipeline_step_name),
            ("zenml_model_name", self.model_name),
            ("zenml_model_uri", self.model_uri),
        )
        labels = {key: value for key, value in candidates if value}
        sanitize_labels(labels)
        return labels
get_databricks_deployment_labels(self)
Generate labels for the Databricks deployment from the service configuration.
These labels are attached to the Databricks deployment resource and may be used as label selectors in lookup operations.
Returns:
Type | Description |
---|---|
Dict[str, str] |
The labels for the Databricks deployment. |
Source code in zenml/integrations/databricks/services/databricks_deployment.py
def get_databricks_deployment_labels(self) -> Dict[str, str]:
    """Build the deployment labels from this service configuration.

    Only fields with a truthy value produce a label. The result is passed
    through `sanitize_labels` before being returned.

    Returns:
        A dict of label name to label value.
    """
    candidates = (
        ("zenml_pipeline_name", self.pipeline_name),
        ("zenml_pipeline_step_name", self.pipeline_step_name),
        ("zenml_model_name", self.model_name),
        ("zenml_model_uri", self.model_uri),
    )
    labels = {key: value for key, value in candidates if value}
    sanitize_labels(labels)
    return labels
DatabricksDeploymentService (BaseDeploymentService)
Databricks model deployment service.
Attributes:
Name | Type | Description |
---|---|---|
SERVICE_TYPE |
ClassVar[zenml.services.service_type.ServiceType] |
a service type descriptor with information describing the Databricks deployment service class |
config |
DatabricksDeploymentConfig |
service configuration |
Source code in zenml/integrations/databricks/services/databricks_deployment.py
class DatabricksDeploymentService(BaseDeploymentService):
"""Databricks model deployment service.
Attributes:
SERVICE_TYPE: a service type descriptor with information describing
the Databricks deployment service class
config: service configuration
"""
SERVICE_TYPE = ServiceType(
name="databricks-deployment",
type="model-serving",
flavor="databricks",
description="Databricks inference endpoint prediction service",
)
config: DatabricksDeploymentConfig
status: DatabricksServiceStatus = Field(
default_factory=lambda: DatabricksServiceStatus()
)
def __init__(self, config: DatabricksDeploymentConfig, **attrs: Any):
    """Create the service from its configuration.

    Args:
        config: The Databricks deployment configuration.
        attrs: Further attributes, forwarded unchanged to the base service.
    """
    super().__init__(config=config, **attrs)
def get_client_id_and_secret(self) -> Tuple[str, str, str]:
    """Get the Databricks host, client id and client secret.

    The values come from the `DatabricksModelDeployer` of the active
    stack. The host is taken from its config. The client id and secret
    come from the ZenML secret named in its `secret_name` config; if no
    secret name is set, they are taken directly from its config. As a side
    effect, the resolved host is stored on `self.config.host`.

    Raises:
        ValueError: If the active stack's model deployer is not a
            `DatabricksModelDeployer`, or if the host, client id or client
            secret cannot be found.

    Returns:
        A `(host, client_id, client_secret)` tuple.
    """
    from zenml.integrations.databricks.model_deployers.databricks_model_deployer import (
        DatabricksModelDeployer,
    )

    client = Client()
    model_deployer = client.active_stack.model_deployer
    if not isinstance(model_deployer, DatabricksModelDeployer):
        raise ValueError(
            "DatabricksModelDeployer is not active in the stack."
        )

    host = model_deployer.config.host
    self.config.host = host

    if model_deployer.config.secret_name:
        secret = client.get_secret(model_deployer.config.secret_name)
        client_id = secret.secret_values["client_id"]
        client_secret = secret.secret_values["client_secret"]
    else:
        client_id = model_deployer.config.client_id
        client_secret = model_deployer.config.client_secret

    if not client_id:
        raise ValueError("Client id not found.")
    if not client_secret:
        raise ValueError("Client secret not found.")
    if not host:
        raise ValueError("Host not found.")
    return host, client_id, client_secret
def _get_databricks_deployment_labels(self) -> Dict[str, str]:
    """Build the deployment labels for this service.

    Extends the config-derived labels with the service UUID and sanitizes
    the result.

    Returns:
        The labels for the Databricks deployment.
    """
    labels = {
        **self.config.get_databricks_deployment_labels(),
        "zenml_service_uuid": str(self.uuid),
    }
    sanitize_labels(labels)
    return labels
@property
def databricks_client(self) -> DatabricksClient:
    """Get a Databricks workspace client for this deployment.

    Returns:
        A `DatabricksClient` built from the host and credentials of the
        active stack's Databricks model deployer.
    """
    # Resolve credentials once. Each call looks up the active stack and may
    # fetch a ZenML secret, so calling it per argument tripled that work.
    host, client_id, client_secret = self.get_client_id_and_secret()
    return DatabricksClient(
        host=host,
        client_id=client_id,
        client_secret=client_secret,
    )
@property
def databricks_endpoint(self) -> ServingEndpointDetailed:
    """Get the deployed Databricks serving endpoint.

    The endpoint is fetched from Databricks on every access, by the name
    returned from `_generate_an_endpoint_name()`.

    Returns:
        The Databricks serving endpoint details.
    """
    return self.databricks_client.serving_endpoints.get(
        name=self._generate_an_endpoint_name(),
    )
@property
def prediction_url(self) -> Optional[str]:
    """The prediction URL exposed by the prediction service.

    The URL is built from `config.host` and the endpoint name. It is
    returned whether or not the endpoint is ready.

    NOTE(review): `config.host` is populated by `get_client_id_and_secret`.
    If that has not run and no host was configured, the URL will start
    with "None" — confirm callers only use this after provisioning.

    Returns:
        The invocation URL of the Databricks serving endpoint.
    """
    return f"{self.config.host}/serving-endpoints/{self._generate_an_endpoint_name()}/invocations"
def provision(self) -> None:
    """Create the Databricks serving endpoint via `create_and_wait`.

    The endpoint serves `config.model_name` at `config.model_version`,
    with the configured scale-to-zero, workload type and workload size. It
    is tagged with this service's deployment labels. The outcome is only
    logged: an info message with the endpoint URL if one is returned, an
    error message otherwise.
    """
    from databricks.sdk.service.serving import (
        ServedModelInputWorkloadSize,
        ServedModelInputWorkloadType,
    )

    endpoint_tags = [
        EndpointTag(key=label_key, value=label_value)
        for label_key, label_value in self._get_databricks_deployment_labels().items()
    ]
    served_model = ServedModelInput(
        model_name=self.config.model_name,
        model_version=self.config.model_version,
        scale_to_zero_enabled=self.config.scale_to_zero_enabled,
        workload_type=ServedModelInputWorkloadType(self.config.workload_type),
        workload_size=ServedModelInputWorkloadSize(self.config.workload_size),
    )
    endpoint_config = EndpointCoreConfigInput(served_models=[served_model])

    created_endpoint = self.databricks_client.serving_endpoints.create_and_wait(
        name=self._generate_an_endpoint_name(),
        config=endpoint_config,
        tags=endpoint_tags,
    )

    if not created_endpoint.endpoint_url:
        logger.error(
            "Failed to start Databricks inference endpoint service: No URL available, please check the Databricks console for more details."
        )
        return
    logger.info(
        f"Databricks inference endpoint successfully deployed and available. Endpoint URL: {created_endpoint.endpoint_url}"
    )
def check_status(self) -> Tuple[ServiceState, str]:
    """Check the current operational state of the Databricks deployment.

    Mapping from the endpoint state reported by Databricks:
    ready == READY -> ACTIVE; config_update == UPDATE_FAILED -> ERROR;
    anything else (including an in-progress update or no state) ->
    PENDING_STARTUP. Any exception while querying the endpoint -> INACTIVE.

    Returns:
        The operational state of the Databricks deployment and a message
        providing additional information about that state (e.g. a
        description of the error, if one is encountered).
    """
    try:
        endpoint_state = self.databricks_endpoint.state or None
        if not endpoint_state:
            return (ServiceState.PENDING_STARTUP, "")
        if (
            endpoint_state.ready
            and endpoint_state.ready == EndpointStateReady.READY
        ):
            return (ServiceState.ACTIVE, "")
        if (
            endpoint_state.config_update
            and endpoint_state.config_update
            == EndpointStateConfigUpdate.UPDATE_FAILED
        ):
            return (
                ServiceState.ERROR,
                "Databricks Inference Endpoint deployment update failed",
            )
        # An update that is still IN_PROGRESS is reported the same way as
        # every other not-yet-ready state.
        return (ServiceState.PENDING_STARTUP, "")
    except Exception as e:
        # Any failure to fetch the endpoint (e.g. it does not exist) is
        # reported as inactive rather than raised.
        return (
            ServiceState.INACTIVE,
            f"Databricks Inference Endpoint deployment is inactive or not found: {e}",
        )
def deprovision(self, force: bool = False) -> None:
    """Deprovision the remote Databricks deployment instance.

    Deletion is best-effort: any failure is logged, never raised.

    Args:
        force: if True, the remote deployment instance will be
            forcefully deprovisioned. Currently unused — deletion is
            attempted the same way regardless of its value.
    """
    try:
        self.databricks_client.serving_endpoints.delete(
            name=self._generate_an_endpoint_name()
        )
    except Exception as e:
        # Keep the best-effort semantics, but surface the cause: credential,
        # network or configuration errors also land here and must be
        # distinguishable from an endpoint that is simply gone.
        logger.error(
            f"Databricks Inference Endpoint is deleted or cannot be found: {e}"
        )
def predict(
    self, request: Union["NDArray[Any]", pd.DataFrame]
) -> "NDArray[Any]":
    """Make a prediction using the service.

    The request is sent as `{"instances": ...}` JSON to the endpoint's
    invocation URL, authenticated with the `token` value of the ZenML
    secret named by `config.endpoint_secret_name`.

    Args:
        request: The input data for the prediction. A DataFrame is sent as
            a list of row records; anything else via its `tolist()`.

    Returns:
        The `predictions` field of the response, as a numpy array.

    Raises:
        Exception: if the service is not running
        ValueError: if no endpoint URL is known, the endpoint secret name
            is not provided, or the secret has no non-empty `token`.
    """
    if not self.is_running:
        raise Exception(
            "Databricks endpoint inference service is not running. "
            "Please start the service before making predictions."
        )
    prediction_url = self.prediction_url
    if prediction_url is None:
        raise ValueError("No endpoint known for prediction.")
    if not self.config.endpoint_secret_name:
        raise ValueError(
            "No endpoint secret name is provided for prediction."
        )
    databricks_token = Client().get_secret(
        self.config.endpoint_secret_name
    )
    # `.get` so a secret without a `token` key raises the intended
    # ValueError rather than a bare KeyError.
    token = databricks_token.secret_values.get("token")
    if not token:
        raise ValueError("No databricks token found.")
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    if isinstance(request, pd.DataFrame):
        instances = request.to_dict("records")
    else:
        instances = request.tolist()
    response = requests.post(  # nosec
        prediction_url,
        json={"instances": instances},
        headers=headers,
    )
    # Raises requests.HTTPError for 4xx/5xx responses.
    response.raise_for_status()
    return np.array(response.json()["predictions"])
def get_logs(
    self, follow: bool = False, tail: Optional[int] = None
) -> Generator[str, bool, None]:
    """Retrieve the service logs.

    Logs are fetched from the Databricks serving-endpoints logs API for the
    served model named `config.model_name`.

    Args:
        follow: if True, keep polling (once per second) and yield new lines
            as they are written.
        tail: only retrieve the last NUM lines of log output on the first
            fetch (`0` or a negative value yields no initial lines).

    Yields:
        Individual log lines.
    """
    logger.info(
        "Databricks Endpoints provides access to the logs of your Endpoints through the UI in the `Logs` tab of your Endpoint"
    )

    def log_generator() -> Generator[str, bool, None]:
        # Number of lines of the *full* log text already accounted for. This
        # is tracked independently of `tail` so that later polls resume after
        # the real end of the log instead of at index `tail`.
        lines_seen = 0
        first_fetch = True
        while True:
            response = self.databricks_client.serving_endpoints.logs(
                name=self._generate_an_endpoint_name(),
                served_model_name=self.config.model_name,
            )
            # `splitlines` avoids a phantom "" element after a trailing
            # newline, which would otherwise shift the count and make the
            # next poll skip a real line. A missing value is treated as empty.
            log_lines = (response.logs or "").splitlines()
            if first_fetch and tail is not None:
                new_lines = log_lines[max(len(log_lines) - tail, 0) :]
            else:
                new_lines = log_lines[lines_seen:]
            first_fetch = False
            yield from new_lines
            # NOTE(review): if the API returns a rolling window of recent
            # lines, counting lines can miss or repeat output once the
            # window rolls — confirm against the Databricks logs API.
            lines_seen = len(log_lines)
            if not follow:
                break
            # Add a small delay to avoid excessive API calls
            time.sleep(1)

    yield from log_generator()
def _generate_an_endpoint_name(self) -> str:
    """Build the Databricks serving endpoint name for this service.

    The name is `<service_name>-<uuid prefix>`, where the prefix is the first
    `UUID_SLICE_LENGTH` characters of the service UUID.

    Returns:
        A unique name for the Databricks Inference Endpoint.
    """
    uuid_prefix = str(self.uuid)[:UUID_SLICE_LENGTH]
    return "-".join((self.config.service_name, uuid_prefix))
databricks_client: databricks.sdk.WorkspaceClient
property
readonly
Get a Databricks workspace client authenticated with the active model deployer's credentials.
Returns:
Type | Description |
---|---|
databricks.sdk.WorkspaceClient |
Databricks workspace client. |
databricks_endpoint: databricks.sdk.service.serving.ServingEndpointDetailed
property
readonly
Get the deployed Databricks inference endpoint.
Returns:
Type | Description |
---|---|
databricks.sdk.service.serving.ServingEndpointDetailed |
Databricks inference endpoint. |
prediction_url: Optional[str]
property
readonly
The prediction URI exposed by the prediction service.
Returns:
Type | Description |
---|---|
Optional[str] |
The prediction URI exposed by the prediction service, or None if the service is not yet ready. |
__init__(self, config, **attrs)
special
Initialize the Databricks deployment service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
DatabricksDeploymentConfig |
service configuration |
required |
attrs |
Any |
additional attributes to set on the service |
{} |
Source code in zenml/integrations/databricks/services/databricks_deployment.py
def __init__(self, config: DatabricksDeploymentConfig, **attrs: Any) -> None:
    """Initialize the Databricks deployment service.

    Args:
        config: service configuration
        attrs: additional attributes to set on the service
    """
    # All initialization is delegated to the parent service class; this
    # class adds no state of its own here.
    super().__init__(config=config, **attrs)
check_status(self)
Check the current operational state of the Databricks deployment.
Returns:
Type | Description |
---|---|
Tuple[zenml.services.service_status.ServiceState, str] |
The operational state of the Databricks deployment and a message providing additional information about that state (e.g. a description of the error, if one is encountered). |
Source code in zenml/integrations/databricks/services/databricks_deployment.py
def check_status(self) -> Tuple[ServiceState, str]:
    """Check the current operational state of the Databricks deployment.

    Mapping from the endpoint state reported by Databricks:
    ready == READY -> ACTIVE; config_update == UPDATE_FAILED -> ERROR;
    anything else (including an in-progress update or no state) ->
    PENDING_STARTUP. Any exception while querying the endpoint -> INACTIVE.

    Returns:
        The operational state of the Databricks deployment and a message
        providing additional information about that state (e.g. a
        description of the error, if one is encountered).
    """
    try:
        endpoint_state = self.databricks_endpoint.state or None
        if not endpoint_state:
            return (ServiceState.PENDING_STARTUP, "")
        if (
            endpoint_state.ready
            and endpoint_state.ready == EndpointStateReady.READY
        ):
            return (ServiceState.ACTIVE, "")
        if (
            endpoint_state.config_update
            and endpoint_state.config_update
            == EndpointStateConfigUpdate.UPDATE_FAILED
        ):
            return (
                ServiceState.ERROR,
                "Databricks Inference Endpoint deployment update failed",
            )
        # An update that is still IN_PROGRESS is reported the same way as
        # every other not-yet-ready state.
        return (ServiceState.PENDING_STARTUP, "")
    except Exception as e:
        # Any failure to fetch the endpoint (e.g. it does not exist) is
        # reported as inactive rather than raised.
        return (
            ServiceState.INACTIVE,
            f"Databricks Inference Endpoint deployment is inactive or not found: {e}",
        )
deprovision(self, force=False)
Deprovision the remote Databricks deployment instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
force |
bool |
if True, the remote deployment instance will be forcefully deprovisioned. |
False |
Source code in zenml/integrations/databricks/services/databricks_deployment.py
def deprovision(self, force: bool = False) -> None:
    """Deprovision the remote Databricks deployment instance.

    Deletion is best-effort: any failure is logged, never raised.

    Args:
        force: if True, the remote deployment instance will be
            forcefully deprovisioned. Currently unused — deletion is
            attempted the same way regardless of its value.
    """
    try:
        self.databricks_client.serving_endpoints.delete(
            name=self._generate_an_endpoint_name()
        )
    except Exception as e:
        # Keep the best-effort semantics, but surface the cause: credential,
        # network or configuration errors also land here and must be
        # distinguishable from an endpoint that is simply gone.
        logger.error(
            f"Databricks Inference Endpoint is deleted or cannot be found: {e}"
        )
get_client_id_and_secret(self)
Get the Databricks host, client id and client secret.
Exceptions:
Type | Description |
---|---|
ValueError |
If the Databricks model deployer is not active in the stack, or the host, client id or client secret is not found. |
Returns:
Type | Description |
---|---|
Tuple[str, str, str] |
Databricks host, client id and client secret. |
Source code in zenml/integrations/databricks/services/databricks_deployment.py
def get_client_id_and_secret(self) -> Tuple[str, str, str]:
    """Get the Databricks host, client id and client secret.

    All values come from the active stack's Databricks model deployer: the
    host from its config, and the client id/secret either from the ZenML
    secret named by `secret_name` (keys `client_id` and `client_secret`) or,
    when no secret name is configured, directly from the deployer config.
    As a side effect the resolved host is stored on `self.config.host`.

    Raises:
        ValueError: If the active model deployer is not a
            `DatabricksModelDeployer`, or if the host, client id or client
            secret is missing or empty.

    Returns:
        A `(host, client_id, client_secret)` tuple.
    """
    client = Client()
    from zenml.integrations.databricks.model_deployers.databricks_model_deployer import (
        DatabricksModelDeployer,
    )

    model_deployer = client.active_stack.model_deployer
    if not isinstance(model_deployer, DatabricksModelDeployer):
        raise ValueError(
            "DatabricksModelDeployer is not active in the stack."
        )
    deployer_config = model_deployer.config
    host = deployer_config.host
    # Remember the host on the service config; `prediction_url` builds the
    # invocation URL from `self.config.host`.
    self.config.host = host
    if deployer_config.secret_name:
        secret_values = client.get_secret(
            deployer_config.secret_name
        ).secret_values
        # `.get` so a secret missing either key surfaces as the ValueError
        # below instead of a bare KeyError.
        client_id = secret_values.get("client_id")
        client_secret = secret_values.get("client_secret")
    else:
        client_id = deployer_config.client_id
        client_secret = deployer_config.client_secret
    if not client_id:
        raise ValueError("Client id not found.")
    if not client_secret:
        raise ValueError("Client secret not found.")
    if not host:
        raise ValueError("Host not found.")
    return host, client_id, client_secret
get_logs(self, follow=False, tail=None)
Retrieve the service logs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
follow |
bool |
if True, the logs will be streamed as they are written |
False |
tail |
Optional[int] |
only retrieve the last NUM lines of log output. |
None |
Yields:
Type | Description |
---|---|
Generator[str, bool, NoneType] |
A generator that can be accessed to get the service logs. |
Source code in zenml/integrations/databricks/services/databricks_deployment.py
def get_logs(
    self, follow: bool = False, tail: Optional[int] = None
) -> Generator[str, bool, None]:
    """Retrieve the service logs.

    Logs are fetched from the Databricks serving-endpoints logs API for the
    served model named `config.model_name`.

    Args:
        follow: if True, keep polling (once per second) and yield new lines
            as they are written.
        tail: only retrieve the last NUM lines of log output on the first
            fetch (`0` or a negative value yields no initial lines).

    Yields:
        Individual log lines.
    """
    logger.info(
        "Databricks Endpoints provides access to the logs of your Endpoints through the UI in the `Logs` tab of your Endpoint"
    )

    def log_generator() -> Generator[str, bool, None]:
        # Number of lines of the *full* log text already accounted for. This
        # is tracked independently of `tail` so that later polls resume after
        # the real end of the log instead of at index `tail`.
        lines_seen = 0
        first_fetch = True
        while True:
            response = self.databricks_client.serving_endpoints.logs(
                name=self._generate_an_endpoint_name(),
                served_model_name=self.config.model_name,
            )
            # `splitlines` avoids a phantom "" element after a trailing
            # newline, which would otherwise shift the count and make the
            # next poll skip a real line. A missing value is treated as empty.
            log_lines = (response.logs or "").splitlines()
            if first_fetch and tail is not None:
                new_lines = log_lines[max(len(log_lines) - tail, 0) :]
            else:
                new_lines = log_lines[lines_seen:]
            first_fetch = False
            yield from new_lines
            # NOTE(review): if the API returns a rolling window of recent
            # lines, counting lines can miss or repeat output once the
            # window rolls — confirm against the Databricks logs API.
            lines_seen = len(log_lines)
            if not follow:
                break
            # Add a small delay to avoid excessive API calls
            time.sleep(1)

    yield from log_generator()
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request |
Union[NDArray[Any], pandas.DataFrame] |
The input data for the prediction. |
required |
Returns:
Type | Description |
---|---|
NDArray[Any] |
The prediction result. |
Exceptions:
Type | Description |
---|---|
Exception |
if the service is not running |
ValueError |
if the endpoint secret name is not provided. |
Source code in zenml/integrations/databricks/services/databricks_deployment.py
def predict(
    self, request: Union["NDArray[Any]", pd.DataFrame]
) -> "NDArray[Any]":
    """Make a prediction using the service.

    The request is sent as `{"instances": ...}` JSON to the endpoint's
    invocation URL, authenticated with the `token` value of the ZenML
    secret named by `config.endpoint_secret_name`.

    Args:
        request: The input data for the prediction. A DataFrame is sent as
            a list of row records; anything else via its `tolist()`.

    Returns:
        The `predictions` field of the response, as a numpy array.

    Raises:
        Exception: if the service is not running
        ValueError: if no endpoint URL is known, the endpoint secret name
            is not provided, or the secret has no non-empty `token`.
    """
    if not self.is_running:
        raise Exception(
            "Databricks endpoint inference service is not running. "
            "Please start the service before making predictions."
        )
    prediction_url = self.prediction_url
    if prediction_url is None:
        raise ValueError("No endpoint known for prediction.")
    if not self.config.endpoint_secret_name:
        raise ValueError(
            "No endpoint secret name is provided for prediction."
        )
    databricks_token = Client().get_secret(
        self.config.endpoint_secret_name
    )
    # `.get` so a secret without a `token` key raises the intended
    # ValueError rather than a bare KeyError.
    token = databricks_token.secret_values.get("token")
    if not token:
        raise ValueError("No databricks token found.")
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    if isinstance(request, pd.DataFrame):
        instances = request.to_dict("records")
    else:
        instances = request.tolist()
    response = requests.post(  # nosec
        prediction_url,
        json={"instances": instances},
        headers=headers,
    )
    # Raises requests.HTTPError for 4xx/5xx responses.
    response.raise_for_status()
    return np.array(response.json()["predictions"])
provision(self)
Provision or update remote Databricks deployment instance.
Source code in zenml/integrations/databricks/services/databricks_deployment.py
def provision(self) -> None:
    """Provision or update remote Databricks deployment instance.

    Builds a single served model from the service configuration, tags it
    with the deployment labels, and calls `create_and_wait` on the
    serving-endpoints API. The outcome is logged; nothing is returned.

    NOTE(review): only `create_and_wait` is called here, so there is no
    separate update path for an already existing endpoint — confirm this
    matches the "or update" wording.
    """
    from databricks.sdk.service.serving import (
        ServedModelInputWorkloadSize,
        ServedModelInputWorkloadType,
    )

    endpoint_tags = [
        EndpointTag(key=tag_key, value=tag_value)
        for tag_key, tag_value in self._get_databricks_deployment_labels().items()
    ]

    cfg = self.config
    # Attempt to create and wait for the inference endpoint
    served_model = ServedModelInput(
        model_name=cfg.model_name,
        model_version=cfg.model_version,
        scale_to_zero_enabled=cfg.scale_to_zero_enabled,
        workload_type=ServedModelInputWorkloadType(cfg.workload_type),
        workload_size=ServedModelInputWorkloadSize(cfg.workload_size),
    )
    endpoint_config = EndpointCoreConfigInput(served_models=[served_model])
    created_endpoint = self.databricks_client.serving_endpoints.create_and_wait(
        name=self._generate_an_endpoint_name(),
        config=endpoint_config,
        tags=endpoint_tags,
    )

    # The endpoint URL being present is taken as the success signal.
    endpoint_url = created_endpoint.endpoint_url
    if not endpoint_url:
        logger.error(
            "Failed to start Databricks inference endpoint service: No URL available, please check the Databricks console for more details."
        )
        return
    logger.info(
        f"Databricks inference endpoint successfully deployed and available. Endpoint URL: {endpoint_url}"
    )
DatabricksServiceStatus (ServiceStatus)
Databricks service status.
Source code in zenml/integrations/databricks/services/databricks_deployment.py
class DatabricksServiceStatus(ServiceStatus):
    """Databricks service status.

    Adds no fields or behavior of its own; everything is inherited from
    `ServiceStatus`.
    """
utils
special
Utilities for Databricks integration.
databricks_utils
Databricks utilities.
convert_step_to_task(task_name, command, arguments, libraries=None, depends_on=None, zenml_project_wheel=None, job_cluster_key=None)
Convert a ZenML step to a Databricks task.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task_name |
str |
Name of the task. |
required |
command |
str |
Command to run. |
required |
arguments |
List[str] |
Arguments to pass to the command. |
required |
libraries |
Optional[List[str]] |
List of libraries to install. |
None |
depends_on |
Optional[List[str]] |
List of tasks to depend on. |
None |
zenml_project_wheel |
Optional[str] |
Path to the ZenML project wheel. |
None |
job_cluster_key |
Optional[str] |
Key of the Databricks job cluster (`job_cluster_key`) the task runs on. |
None |
Returns:
Type | Description |
---|---|
databricks.sdk.service.jobs.Task |
Databricks task. |
Source code in zenml/integrations/databricks/utils/databricks_utils.py
def convert_step_to_task(
    task_name: str,
    command: str,
    arguments: List[str],
    libraries: Optional[List[str]] = None,
    depends_on: Optional[List[str]] = None,
    zenml_project_wheel: Optional[str] = None,
    job_cluster_key: Optional[str] = None,
) -> DatabricksTask:
    """Convert a ZenML step to a Databricks task.

    The task runs `command` as an entry point of the `zenml` Python wheel
    package, with the given PyPI libraries, the optional project wheel and a
    pinned `zenml=={__version__}` installed on the cluster.

    Args:
        task_name: Name of the task.
        command: Entry point (of the `zenml` package) to run.
        arguments: Arguments to pass to the command.
        libraries: PyPI requirement strings to install.
        depends_on: Task keys this task depends on.
        zenml_project_wheel: Path to the ZenML project wheel. When omitted,
            no wheel library is added.
        job_cluster_key: Key of the Databricks job cluster to run on.

    Returns:
        Databricks task.
    """
    db_libraries = [
        Library(pypi=PythonPyPiLibrary(library)) for library in libraries or []
    ]
    # Only reference the project wheel when one was given: `Library(whl=None)`
    # specifies no library at all (presumably rejected by the Jobs API).
    if zenml_project_wheel:
        db_libraries.append(Library(whl=zenml_project_wheel))
    # Pin the cluster's ZenML to the version building the task.
    db_libraries.append(
        Library(pypi=PythonPyPiLibrary(f"zenml=={__version__}"))
    )
    task_dependencies = (
        [TaskDependency(task) for task in depends_on] if depends_on else None
    )
    return DatabricksTask(
        task_key=task_name,
        job_cluster_key=job_cluster_key,
        libraries=db_libraries,
        python_wheel_task=PythonWheelTask(
            package_name="zenml",
            entry_point=command,
            parameters=arguments,
        ),
        depends_on=task_dependencies,
    )
sanitize_labels(labels)
Update the label values to be valid Kubernetes labels.
See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
Dict[str, str] |
the labels to sanitize. |
required |
Source code in zenml/integrations/databricks/utils/databricks_utils.py
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Update the label values to be valid Kubernetes labels.

    See:
    https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Each run of characters outside `[0-9a-zA-Z-_.]` becomes one underscore.
    The result is cut to 63 characters, then leading/trailing `-`, `_` and
    `.` are stripped. Keys are left untouched.

    Args:
        labels: the labels to sanitize; modified in place.
    """
    # Kubernetes labels must be alphanumeric, no longer than 63 characters,
    # and must begin and end with an alphanumeric character ([a-z0-9A-Z])
    disallowed_chars = re.compile(r"[^0-9a-zA-Z-_\.]+")
    labels.update(
        {
            label_key: disallowed_chars.sub("_", raw_value)[:63].strip("-_.")
            for label_key, raw_value in labels.items()
        }
    )