Metadata Stores
zenml.metadata_stores
special
Initialization of ZenML's metadata stores.
The configuration of each pipeline, step, and backend, as well as the produced
artifacts, is tracked within the metadata store. The metadata store is an SQL
database and can be either SQLite or MySQL.
Metadata are the pieces of information tracked about the pipelines, experiments and configurations that you are running with ZenML. Metadata are stored inside the metadata store.
base_metadata_store
Base implementation of a metadata store.
BaseMetadataStore (StackComponent, ABC)
pydantic-model
Base class for all ZenML metadata stores.
Source code in zenml/metadata_stores/base_metadata_store.py
class BaseMetadataStore(StackComponent, ABC):
    """Base class for all ZenML metadata stores.

    Subclasses only have to implement `get_tfx_metadata_config`; every query
    method in this class goes through the lazily created `store` property.
    """

    # Class Configuration
    TYPE: ClassVar[StackComponentType] = StackComponentType.METADATA_STORE
    upgrade_migration_enabled: bool = True
    _store: Optional[metadata_store.MetadataStore] = None

    @property
    def store(self) -> metadata_store.MetadataStore:
        """General property that hooks into TFX metadata store.

        The underlying store is created on first access and cached on the
        instance afterwards.

        Returns:
            metadata_store.MetadataStore: TFX metadata store.
        """
        if self._store is None:
            config = self.get_tfx_metadata_config()
            self._store = metadata_store.MetadataStore(
                config,
                # upgrade migration is only requested for direct connection
                # configs, never for client configs
                enable_upgrade_migration=self.upgrade_migration_enabled
                and isinstance(config, metadata_store_pb2.ConnectionConfig),
            )
        return self._store

    @abstractmethod
    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config.

        Returns:
            tfx metadata config.
        """
        raise NotImplementedError

    @property
    def step_type_mapping(self) -> Dict[int, str]:
        """Maps type_ids to step names.

        Returns:
            Dict[int, str]: a mapping from type_ids to step names.
        """
        return {
            type_.id: type_.name for type_ in self.store.get_execution_types()
        }

    def _check_if_executions_belong_to_pipeline(
        self,
        executions: List[proto.Execution],
        pipeline: PipelineView,
    ) -> bool:
        """Returns `True` if the executions are associated with the pipeline context.

        Args:
            executions: List of executions.
            pipeline: Pipeline to check.

        Returns:
            `True` if at least one context of at least one execution is the
            pipeline's context, `False` otherwise (including for an empty
            list of executions).
        """
        return any(
            context.id == pipeline._id  # noqa
            for execution in executions
            for context in self.store.get_contexts_by_execution(execution.id)
        )

    def _get_step_view_from_execution(
        self, execution: proto.Execution
    ) -> StepView:
        """Get original StepView from an execution.

        Args:
            execution: proto.Execution object from mlmd store.

        Returns:
            Original `StepView` derived from the proto.Execution.

        Raises:
            KeyError: If the execution is not associated with a step.
        """
        impl_name = self.step_type_mapping[execution.type_id].split(".")[-1]

        step_name_property = execution.custom_properties.get(
            INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME,
            None,
        )
        if step_name_property:
            step_name = json.loads(step_name_property.string_value)
        else:
            raise KeyError(
                f"Step name missing for execution with ID {execution.id}. "
                f"This error probably occurs because you're using ZenML "
                f"version 0.5.4 or newer but your metadata store contains "
                f"data from previous versions."
            )

        step_parameters = {}
        for k, v in execution.custom_properties.items():
            if not k.startswith(INTERNAL_EXECUTION_PARAMETER_PREFIX):
                try:
                    step_parameters[k] = json.loads(v.string_value)
                except JSONDecodeError:
                    # this means there is a property in there that is neither
                    # an internal one or one created by zenml. Therefore, we can
                    # ignore it
                    pass

        # TODO [ENG-222]: This is a lot of querying to the metadata store. We
        #  should refactor and make it nicer. Probably it makes more sense
        #  to first get `executions_ids_for_current_run` and then filter on
        #  `event.execution_id in execution_ids_for_current_run`.
        # Core logic here is that we get the event of this particular execution
        # id that gives us the artifacts of this execution. We then go through
        # all `input` artifacts of this execution and get all events related to
        # that artifact. This in turn gives us other events for which this
        # artifact was an `output` artifact. Then we simply need to sort by
        # time to get the most recent execution (i.e. step) that produced that
        # particular artifact.
        events_for_execution = self.store.get_events_by_execution_ids(
            [execution.id]
        )

        parents_step_ids = set()
        for current_event in events_for_execution:
            if current_event.type != current_event.INPUT:
                continue

            # this means the artifact is an input artifact
            events_for_input_artifact = [
                e
                for e in self.store.get_events_by_artifact_ids(
                    [current_event.artifact_id]
                )
                # should be output type and should NOT be the same id as
                # the execution we are querying and it should be BEFORE
                # the time of the current event.
                if e.type == e.OUTPUT
                and e.execution_id != current_event.execution_id
                and e.milliseconds_since_epoch
                < current_event.milliseconds_since_epoch
            ]
            if not events_for_input_artifact:
                # No earlier producer event is recorded for this input, so
                # there is no parent to add. Indexing [-1] on the empty list
                # would raise an IndexError.
                continue

            # sort by time
            events_for_input_artifact.sort(
                key=lambda x: x.milliseconds_since_epoch  # type: ignore[no-any-return] # noqa
            )
            # take the latest one and add execution to the parents.
            parents_step_ids.add(events_for_input_artifact[-1].execution_id)

        return StepView(
            id_=execution.id,
            parents_step_ids=list(parents_step_ids),
            entrypoint_name=impl_name,
            name=step_name,
            parameters=step_parameters,
            metadata_store=self,
        )

    def get_pipelines(self) -> List[PipelineView]:
        """Returns a list of all pipelines stored in this metadata store.

        Returns:
            List[PipelineView]: a list of all pipelines stored in this metadata store.
        """
        pipelines = [
            PipelineView(
                id_=pipeline_context.id,
                name=pipeline_context.name,
                metadata_store=self,
            )
            for pipeline_context in self.store.get_contexts_by_type(
                PIPELINE_CONTEXT_TYPE_NAME
            )
        ]
        logger.debug("Fetched %d pipelines.", len(pipelines))
        return pipelines

    def get_pipeline(self, pipeline_name: str) -> Optional[PipelineView]:
        """Returns a pipeline for the given name.

        Args:
            pipeline_name: Name of the pipeline.

        Returns:
            PipelineView if found, None otherwise.
        """
        pipeline_context = self.store.get_context_by_type_and_name(
            PIPELINE_CONTEXT_TYPE_NAME, pipeline_name
        )
        if pipeline_context:
            logger.debug("Fetched pipeline with name '%s'", pipeline_name)
            return PipelineView(
                id_=pipeline_context.id,
                name=pipeline_context.name,
                metadata_store=self,
            )
        else:
            logger.info("No pipelines found for name '%s'", pipeline_name)
            return None

    def get_pipeline_runs(
        self, pipeline: PipelineView
    ) -> Dict[str, PipelineRunView]:
        """Gets all runs for the given pipeline.

        Args:
            pipeline: a Pipeline object for which you want the runs.

        Returns:
            A dictionary of pipeline run names to PipelineRunView.
        """
        all_pipeline_runs = self.store.get_contexts_by_type(
            PIPELINE_RUN_CONTEXT_TYPE_NAME
        )
        runs: Dict[str, PipelineRunView] = OrderedDict()
        for run in all_pipeline_runs:
            executions = self.store.get_executions_by_context(run.id)
            if self._check_if_executions_belong_to_pipeline(
                executions, pipeline
            ):
                run_view = PipelineRunView(
                    id_=run.id,
                    name=run.name,
                    executions=executions,
                    metadata_store=self,
                )
                runs[run.name] = run_view
        logger.debug(
            "Fetched %d pipeline runs for pipeline named '%s'.",
            len(runs),
            pipeline.name,
        )
        return runs

    def get_pipeline_run(
        self, pipeline: PipelineView, run_name: str
    ) -> Optional[PipelineRunView]:
        """Gets a specific run for the given pipeline.

        Args:
            pipeline: The pipeline for which to get the run.
            run_name: The name of the run to get.

        Returns:
            The pipeline run with the given name, or `None` if no run context
            with that name exists or its executions do not belong to the
            given pipeline.
        """
        run = self.store.get_context_by_type_and_name(
            PIPELINE_RUN_CONTEXT_TYPE_NAME, run_name
        )
        if not run:
            # No context found for the given run name
            return None

        executions = self.store.get_executions_by_context(run.id)
        if self._check_if_executions_belong_to_pipeline(executions, pipeline):
            logger.debug("Fetched pipeline run with name '%s'", run_name)
            return PipelineRunView(
                id_=run.id,
                name=run.name,
                executions=executions,
                metadata_store=self,
            )

        logger.info("No pipeline run found for name '%s'", run_name)
        return None

    def get_pipeline_run_steps(
        self, pipeline_run: PipelineRunView
    ) -> Dict[str, StepView]:
        """Gets all steps for the given pipeline run.

        Args:
            pipeline_run: The pipeline run to get the steps for.

        Returns:
            A dictionary of step names to step views.
        """
        steps: Dict[str, StepView] = OrderedDict()
        # reverse the executions as they get returned in reverse chronological
        # order from the metadata store
        for execution in reversed(pipeline_run._executions):  # noqa
            step = self._get_step_view_from_execution(execution)
            steps[step.name] = step
        logger.debug(
            "Fetched %d steps for pipeline run '%s'.",
            len(steps),
            pipeline_run.name,
        )
        return steps

    def get_step_by_id(self, step_id: int) -> StepView:
        """Gets a `StepView` by its ID.

        Args:
            step_id (int): The ID of the step to get.

        Returns:
            StepView: The `StepView` with the given ID.
        """
        execution = self.store.get_executions_by_id([step_id])[0]
        return self._get_step_view_from_execution(execution)

    def get_step_status(self, step: StepView) -> ExecutionStatus:
        """Gets the execution status of a single step.

        Any last known state other than complete, running or cached is
        reported as failed.

        Args:
            step (StepView): The step to get the status for.

        Returns:
            ExecutionStatus: The status of the step.
        """
        # Named `execution` rather than `proto` so the local does not shadow
        # the `proto` module referenced by this class's type annotations.
        execution = self.store.get_executions_by_id([step._id])[0]  # noqa
        state = execution.last_known_state

        if state == execution.COMPLETE:
            return ExecutionStatus.COMPLETED
        elif state == execution.RUNNING:
            return ExecutionStatus.RUNNING
        elif state == execution.CACHED:
            return ExecutionStatus.CACHED
        else:
            return ExecutionStatus.FAILED

    def get_step_artifacts(
        self, step: StepView
    ) -> Tuple[Dict[str, ArtifactView], Dict[str, ArtifactView]]:
        """Returns input and output artifacts for the given step.

        Args:
            step: The step for which to get the artifacts.

        Returns:
            A tuple (inputs, outputs) where inputs and outputs
            are both Dicts mapping artifact names
            to the input and output artifacts respectively.
        """
        # maps artifact types to their string representation
        artifact_type_mapping = {
            type_.id: type_.name for type_ in self.store.get_artifact_types()
        }

        events = self.store.get_events_by_execution_ids([step._id])  # noqa
        # Index the artifacts by id instead of zipping two sorted lists: the
        # zip silently pairs an event with the wrong artifact as soon as the
        # two lists differ in length (repeated or missing artifact ids).
        artifacts_by_id = {
            artifact.id: artifact
            for artifact in self.store.get_artifacts_by_id(
                [event.artifact_id for event in events]
            )
        }

        inputs: Dict[str, ArtifactView] = {}
        outputs: Dict[str, ArtifactView] = {}

        # iterate in artifact id order to keep the resulting dict ordering
        # deterministic
        for event_proto in sorted(events, key=lambda x: x.artifact_id):
            artifact_proto = artifacts_by_id.get(event_proto.artifact_id)
            if artifact_proto is None:
                logger.warning(
                    "Artifact with ID %d referenced by step '%s' was not "
                    "returned by the metadata store; skipping it.",
                    event_proto.artifact_id,
                    step.entrypoint_name,
                )
                continue

            artifact_type = artifact_type_mapping[artifact_proto.type_id]
            artifact_name = event_proto.path.steps[0].key

            materializer = artifact_proto.properties[
                MATERIALIZER_PROPERTY_KEY
            ].string_value

            data_type = artifact_proto.properties[
                DATATYPE_PROPERTY_KEY
            ].string_value

            parent_step_id = step.id
            if event_proto.type == event_proto.INPUT:
                # In the case that this is an input event, we actually need
                # to resolve it via its parents outputs.
                for parent in step.parent_steps:
                    for a in parent.outputs.values():
                        if artifact_proto.id == a.id:
                            parent_step_id = parent.id

            artifact = ArtifactView(
                id_=event_proto.artifact_id,
                type_=artifact_type,
                uri=artifact_proto.uri,
                materializer=materializer,
                data_type=data_type,
                metadata_store=self,
                parent_step_id=parent_step_id,
            )

            if event_proto.type == event_proto.INPUT:
                inputs[artifact_name] = artifact
            elif event_proto.type == event_proto.OUTPUT:
                outputs[artifact_name] = artifact

        logger.debug(
            "Fetched %d inputs and %d outputs for step '%s'.",
            len(inputs),
            len(outputs),
            step.entrypoint_name,
        )

        return inputs, outputs

    def get_producer_step_from_artifact(
        self, artifact: ArtifactView
    ) -> StepView:
        """Returns original StepView from an ArtifactView.

        Args:
            artifact: ArtifactView to be queried.

        Returns:
            Original StepView that produced the artifact.
        """
        # only OUTPUT events point at an execution that produced the artifact
        executions_ids = {
            event.execution_id
            for event in self.store.get_events_by_artifact_ids([artifact.id])
            if event.type == event.OUTPUT
        }
        execution = self.store.get_executions_by_id(executions_ids)[0]
        return self._get_step_view_from_execution(execution)
step_type_mapping: Dict[int, str]
property
readonly
Maps type_ids to step names.
Returns:
Type | Description |
---|---|
Dict[int, str] |
a mapping from type_ids to step names. |
store: MetadataStore
property
readonly
General property that hooks into TFX metadata store.
Returns:
Type | Description |
---|---|
metadata_store.MetadataStore |
TFX metadata store. |
get_pipeline(self, pipeline_name)
Returns a pipeline for the given name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
Name of the pipeline. |
required |
Returns:
Type | Description |
---|---|
Optional[zenml.post_execution.pipeline.PipelineView] |
PipelineView if found, None otherwise. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipeline(self, pipeline_name: str) -> Optional[PipelineView]:
    """Look up a single pipeline by its name.

    Args:
        pipeline_name: Name of the pipeline context to fetch.

    Returns:
        A `PipelineView` for the matching pipeline context, or `None` when
        the store has no pipeline context with that name.
    """
    context = self.store.get_context_by_type_and_name(
        PIPELINE_CONTEXT_TYPE_NAME, pipeline_name
    )

    # Guard clause: nothing stored under this name.
    if not context:
        logger.info("No pipelines found for name '%s'", pipeline_name)
        return None

    logger.debug("Fetched pipeline with name '%s'", pipeline_name)
    view = PipelineView(
        id_=context.id,
        name=context.name,
        metadata_store=self,
    )
    return view
get_pipeline_run(self, pipeline, run_name)
Gets a specific run for the given pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
PipelineView |
The pipeline for which to get the run. |
required |
run_name |
str |
The name of the run to get. |
required |
Returns:
Type | Description |
---|---|
Optional[zenml.post_execution.pipeline_run.PipelineRunView] |
The pipeline run with the given name. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipeline_run(
    self, pipeline: PipelineView, run_name: str
) -> Optional[PipelineRunView]:
    """Fetch one named run of the given pipeline.

    Args:
        pipeline: The pipeline the run has to belong to.
        run_name: Name of the run context to look up.

    Returns:
        A `PipelineRunView` for the run, or `None` when no run context with
        that name exists or its executions are not associated with the
        given pipeline.
    """
    run_context = self.store.get_context_by_type_and_name(
        PIPELINE_RUN_CONTEXT_TYPE_NAME, run_name
    )
    if not run_context:
        # The store knows no run context with this name.
        return None

    run_executions = self.store.get_executions_by_context(run_context.id)
    belongs_to_pipeline = self._check_if_executions_belong_to_pipeline(
        run_executions, pipeline
    )
    if not belongs_to_pipeline:
        logger.info("No pipeline run found for name '%s'", run_name)
        return None

    logger.debug("Fetched pipeline run with name '%s'", run_name)
    return PipelineRunView(
        id_=run_context.id,
        name=run_context.name,
        executions=run_executions,
        metadata_store=self,
    )
get_pipeline_run_steps(self, pipeline_run)
Gets all steps for the given pipeline run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_run |
PipelineRunView |
The pipeline run to get the steps for. |
required |
Returns:
Type | Description |
---|---|
Dict[str, zenml.post_execution.step.StepView] |
A dictionary of step names to step views. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipeline_run_steps(
    self, pipeline_run: PipelineRunView
) -> Dict[str, StepView]:
    """Collect the step views of every execution in a pipeline run.

    Args:
        pipeline_run: The pipeline run whose steps should be returned.

    Returns:
        An ordered dictionary mapping each step name to its `StepView`.
    """
    # The metadata store hands executions back newest-first, so walk them
    # in reverse to get chronological order.
    step_views = (
        self._get_step_view_from_execution(execution)
        for execution in reversed(pipeline_run._executions)  # noqa
    )
    steps: Dict[str, StepView] = OrderedDict(
        (view.name, view) for view in step_views
    )

    logger.debug(
        "Fetched %d steps for pipeline run '%s'.",
        len(steps),
        pipeline_run.name,
    )
    return steps
get_pipeline_runs(self, pipeline)
Gets all runs for the given pipeline.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline |
PipelineView |
a Pipeline object for which you want the runs. |
required |
Returns:
Type | Description |
---|---|
Dict[str, zenml.post_execution.pipeline_run.PipelineRunView] |
A dictionary of pipeline run names to PipelineRunView. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipeline_runs(
    self, pipeline: PipelineView
) -> Dict[str, PipelineRunView]:
    """Collect every stored run that belongs to the given pipeline.

    Args:
        pipeline: The pipeline whose runs should be returned.

    Returns:
        An ordered dictionary mapping run names to `PipelineRunView` objects.
    """
    runs: Dict[str, PipelineRunView] = OrderedDict()

    for run_context in self.store.get_contexts_by_type(
        PIPELINE_RUN_CONTEXT_TYPE_NAME
    ):
        run_executions = self.store.get_executions_by_context(run_context.id)

        # Run contexts of all pipelines share one type, so skip those whose
        # executions are not tied to this pipeline's context.
        if not self._check_if_executions_belong_to_pipeline(
            run_executions, pipeline
        ):
            continue

        runs[run_context.name] = PipelineRunView(
            id_=run_context.id,
            name=run_context.name,
            executions=run_executions,
            metadata_store=self,
        )

    logger.debug(
        "Fetched %d pipeline runs for pipeline named '%s'.",
        len(runs),
        pipeline.name,
    )
    return runs
get_pipelines(self)
Returns a list of all pipelines stored in this metadata store.
Returns:
Type | Description |
---|---|
List[PipelineView] |
a list of all pipelines stored in this metadata store. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipelines(self) -> List[PipelineView]:
    """List every pipeline known to this metadata store.

    Returns:
        One `PipelineView` per pipeline context found in the store.
    """
    contexts = self.store.get_contexts_by_type(PIPELINE_CONTEXT_TYPE_NAME)
    pipelines = [
        PipelineView(id_=context.id, name=context.name, metadata_store=self)
        for context in contexts
    ]
    logger.debug("Fetched %d pipelines.", len(pipelines))
    return pipelines
get_producer_step_from_artifact(self, artifact)
Returns original StepView from an ArtifactView.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact |
ArtifactView |
ArtifactView to be queried. |
required |
Returns:
Type | Description |
---|---|
StepView |
Original StepView that produced the artifact. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_producer_step_from_artifact(
    self, artifact: ArtifactView
) -> StepView:
    """Find the step that produced the given artifact.

    Args:
        artifact: The artifact whose producer should be looked up.

    Returns:
        The `StepView` built from the first execution returned for the
        artifact's OUTPUT events.
    """
    artifact_events = self.store.get_events_by_artifact_ids([artifact.id])

    # Only OUTPUT events identify an execution that wrote the artifact.
    executions_ids = set()
    for event in artifact_events:
        if event.type == event.OUTPUT:
            executions_ids.add(event.execution_id)

    producers = self.store.get_executions_by_id(executions_ids)
    return self._get_step_view_from_execution(producers[0])
get_step_artifacts(self, step)
Returns input and output artifacts for the given step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
StepView |
The step for which to get the artifacts. |
required |
Returns:
Type | Description |
---|---|
Tuple[Dict[str, zenml.post_execution.artifact.ArtifactView], Dict[str, zenml.post_execution.artifact.ArtifactView]] |
A tuple (inputs, outputs) where inputs and outputs are both Dicts mapping artifact names to the input and output artifacts respectively. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_step_artifacts(
    self, step: StepView
) -> Tuple[Dict[str, ArtifactView], Dict[str, ArtifactView]]:
    """Returns input and output artifacts for the given step.

    Args:
        step: The step for which to get the artifacts.

    Returns:
        A tuple (inputs, outputs) where inputs and outputs
        are both Dicts mapping artifact names
        to the input and output artifacts respectively.
    """
    # maps artifact types to their string representation
    artifact_type_mapping = {
        type_.id: type_.name for type_ in self.store.get_artifact_types()
    }

    events = self.store.get_events_by_execution_ids([step._id])  # noqa
    # Index the artifacts by id instead of zipping two sorted lists: the zip
    # silently pairs an event with the wrong artifact as soon as the two
    # lists differ in length (repeated or missing artifact ids).
    artifacts_by_id = {
        artifact.id: artifact
        for artifact in self.store.get_artifacts_by_id(
            [event.artifact_id for event in events]
        )
    }

    inputs: Dict[str, ArtifactView] = {}
    outputs: Dict[str, ArtifactView] = {}

    # iterate in artifact id order to keep the resulting dict ordering
    # deterministic
    for event_proto in sorted(events, key=lambda x: x.artifact_id):
        artifact_proto = artifacts_by_id.get(event_proto.artifact_id)
        if artifact_proto is None:
            logger.warning(
                "Artifact with ID %d referenced by step '%s' was not "
                "returned by the metadata store; skipping it.",
                event_proto.artifact_id,
                step.entrypoint_name,
            )
            continue

        artifact_type = artifact_type_mapping[artifact_proto.type_id]
        artifact_name = event_proto.path.steps[0].key

        materializer = artifact_proto.properties[
            MATERIALIZER_PROPERTY_KEY
        ].string_value

        data_type = artifact_proto.properties[
            DATATYPE_PROPERTY_KEY
        ].string_value

        parent_step_id = step.id
        if event_proto.type == event_proto.INPUT:
            # In the case that this is an input event, we actually need
            # to resolve it via its parents outputs.
            for parent in step.parent_steps:
                for a in parent.outputs.values():
                    if artifact_proto.id == a.id:
                        parent_step_id = parent.id

        artifact = ArtifactView(
            id_=event_proto.artifact_id,
            type_=artifact_type,
            uri=artifact_proto.uri,
            materializer=materializer,
            data_type=data_type,
            metadata_store=self,
            parent_step_id=parent_step_id,
        )

        if event_proto.type == event_proto.INPUT:
            inputs[artifact_name] = artifact
        elif event_proto.type == event_proto.OUTPUT:
            outputs[artifact_name] = artifact

    logger.debug(
        "Fetched %d inputs and %d outputs for step '%s'.",
        len(inputs),
        len(outputs),
        step.entrypoint_name,
    )

    return inputs, outputs
get_step_by_id(self, step_id)
Gets a `StepView` by its ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_id |
int |
The ID of the step to get. |
required |
Returns:
Type | Description |
---|---|
StepView |
The `StepView` with the given ID. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_step_by_id(self, step_id: int) -> StepView:
    """Build the `StepView` for the execution with the given ID.

    Args:
        step_id: ID of the step (i.e. of its execution in the store).

    Returns:
        The `StepView` derived from the first execution the store returns
        for that ID.
    """
    matching_executions = self.store.get_executions_by_id([step_id])
    first_match = matching_executions[0]
    return self._get_step_view_from_execution(first_match)
get_step_status(self, step)
Gets the execution status of a single step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step |
StepView |
The step to get the status for. |
required |
Returns:
Type | Description |
---|---|
ExecutionStatus |
The status of the step. |
Source code in zenml/metadata_stores/base_metadata_store.py
def get_step_status(self, step: StepView) -> ExecutionStatus:
    """Translate a step's last known execution state into an `ExecutionStatus`.

    Args:
        step: The step whose status is requested.

    Returns:
        `COMPLETED`, `RUNNING` or `CACHED` for the corresponding execution
        states, and `FAILED` for every other state.
    """
    execution = self.store.get_executions_by_id([step._id])[0]  # noqa

    # Dispatch table from execution state to ZenML status; anything not
    # listed falls through to FAILED.
    status_by_state = {
        execution.COMPLETE: ExecutionStatus.COMPLETED,
        execution.RUNNING: ExecutionStatus.RUNNING,
        execution.CACHED: ExecutionStatus.CACHED,
    }
    return status_by_state.get(
        execution.last_known_state, ExecutionStatus.FAILED
    )
get_tfx_metadata_config(self)
Return tfx metadata config.
Returns:
Type | Description |
---|---|
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig] |
tfx metadata config. |
Source code in zenml/metadata_stores/base_metadata_store.py
@abstractmethod
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Build the configuration used to create the TFX metadata store.

    Concrete metadata store flavors must override this method.

    Returns:
        Either a connection config or a client config for the store.
    """
    raise NotImplementedError()
mysql_metadata_store
Implementation of a MySQL metadata store.
MySQLMetadataStore (BaseMetadataStore)
pydantic-model
MySQL backend for ZenML metadata store.
Attributes:
Name | Type | Description |
---|---|---|
port |
int |
TCP port where the MySQL server can be accessed. |
host |
str |
MySQL server hostname. |
database |
str |
MySQL database name to use for the metadata store. If not already present on the server, it will be created automatically on first access. |
secret |
Optional[str] |
The name of a ZenML secret that holds credentials. |
username |
Optional[str] |
The database username. It can be configured here, or in the referenced ZenML secret (recommended). |
password |
Optional[str] |
The database password. It can be configured here, or in the referenced ZenML secret (recommended). |
Source code in zenml/metadata_stores/mysql_metadata_store.py
class MySQLMetadataStore(BaseMetadataStore):
    """MySQL backend for ZenML metadata store.

    Attributes:
        port: TCP port where the MySQL server can be accessed.
        host: MySQL server hostname.
        database: MySQL database name to use for the metadata store. If not
            already present on the server, it will be created automatically
            on first access.
        secret: The name of a ZenML secret that holds credentials.
        username: The database username. It can be configured here, or in the
            referenced ZenML secret (recommended).
        password: The database password. It can be configured here, or in the
            referenced ZenML secret (recommended).
    """

    port: int = 3306
    host: str
    database: str
    secret: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = "mysql"

    @validator("database")
    def _ensure_valid_database_name(
        cls,
        database: str,
    ) -> str:
        """Ensures that the database name is valid.

        Args:
            database: The database name value to validate.

        Returns:
            The database name if it is valid.

        Raises:
            ValueError: If the database name is not valid.
        """
        regexp = r"^[^\\/?%*:|\"<>.-]{1,64}$"
        match = re.match(regexp, database)
        if not match:
            raise ValueError(
                f"The database name does not conform to the required format "
                f"rules ({regexp}): {database}"
            )

        return database

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for MySQL metadata store.

        Username and password may each come from this component or from the
        referenced secret, but not from both. If the secret carries SSL
        material, it is written to files below the global config directory
        and referenced from the returned config.

        Returns:
            The tfx metadata config.

        Raises:
            RuntimeError: If you have configured your metadata store incorrectly.
        """
        config = MySQLDatabaseConfig(
            host=self.host,
            port=self.port,
            database=self.database,
        )

        secret = self._get_mysql_secret()

        # Set the user
        if self.username:
            if secret and secret.user:
                raise RuntimeError(
                    f"Both the metadata store {self.name} and the secret "
                    f"{self.secret} within your secrets manager define "
                    f"a username `{self.username}` and `{secret.user}`. Please "
                    f"make sure that you only use one."
                )
            config.user = self.username
        elif secret and secret.user:
            config.user = secret.user
        else:
            raise RuntimeError(
                "Your metadata store does not have a username. Please "
                "provide it either by defining it upon registration or "
                "through a MySQL secret."
            )

        # Set the password
        if self.password:
            if secret and secret.password:
                raise RuntimeError(
                    f"Both the metadata store {self.name} and the secret "
                    f"{self.secret} within your secrets manager define "
                    f"a password. Please make sure that you only use one."
                )
            config.password = self.password
        elif secret and secret.password:
            config.password = secret.password

        # Set the SSL configuration if there is one
        if secret:
            # local import: only needed when SSL files have to be written
            import os

            secret_folder = Path(
                GlobalConfiguration().config_directory,
                "mysql-metadata",
                str(self.uuid),
            )

            ssl_kwargs = {}
            prefix = "ssl_"
            # Handle the files
            for key in ["ssl_key", "ssl_ca", "ssl_cert"]:
                content = getattr(secret, key)
                if content:
                    fileio.makedirs(str(secret_folder))
                    file_path = Path(secret_folder, f"{key}.pem")
                    # Slice off the prefix explicitly: `str.lstrip("ssl_")`
                    # strips any leading 's', 'l' or '_' characters, not the
                    # literal prefix.
                    ssl_kwargs[key[len(prefix) :]] = str(file_path)
                    # Create the file owner-only from the start so the key
                    # material is never readable by other users.
                    fd = os.open(
                        str(file_path),
                        os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                        0o600,
                    )
                    with os.fdopen(fd, "w") as f:
                        f.write(content)
                    # also tighten the mode of a file that already existed
                    file_path.chmod(0o600)

            # Handle additional params
            ssl_kwargs["verify_server_cert"] = secret.ssl_verify_server_cert

            ssl_options = MySQLDatabaseConfig.SSLOptions(**ssl_kwargs)
            config.ssl_options.CopyFrom(ssl_options)

        return metadata_store_pb2.ConnectionConfig(mysql=config)

    def _get_mysql_secret(self) -> Any:
        """Method which returns a MySQL secret from the secrets manager.

        Returns:
            Any: The MySQL secret, or `None` if this metadata store does not
                reference a secret.

        Raises:
            RuntimeError: If you don't have a secrets manager as part of your
                stack, if the secret does not exist, or if it does not use
                the MySQL secret schema.
        """
        if not self.secret:
            return None

        active_stack = Repository(skip_repository_check=True).active_stack  # type: ignore[call-arg]
        secret_manager = active_stack.secrets_manager
        if secret_manager is None:
            raise RuntimeError(
                f"The metadata store `{self.name}` that you are using "
                f"requires a secret. However, your stack "
                f"`{active_stack.name}` does not have a secrets manager."
            )

        try:
            secret = secret_manager.get_secret(self.secret)
        except KeyError as e:
            # keep the original lookup failure attached as the cause
            raise RuntimeError(
                f"The secret `{self.secret}` used for your MySQL metadata "
                f"store `{self.name}` does not exist in your secrets "
                f"manager `{secret_manager.name}`."
            ) from e

        from zenml.metadata_stores.mysql_secret_schema import (
            MYSQLSecretSchema,
        )

        if not isinstance(secret, MYSQLSecretSchema):
            raise RuntimeError(
                f"If you are using a secret with a MySQL Metadata "
                f"Store, please make sure to use the schema: "
                f"{MYSQLSecretSchema.TYPE}"
            )
        return secret
get_tfx_metadata_config(self)
Return tfx metadata config for MySQL metadata store.
Returns:
Type | Description |
---|---|
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig] |
The tfx metadata config. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If you have configured your metadata store incorrectly. |
Source code in zenml/metadata_stores/mysql_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Return tfx metadata config for MySQL metadata store.

    Username and password may each come from this component or from the
    referenced secret, but not from both. If the secret carries SSL
    material, it is written to files below the global config directory and
    referenced from the returned config.

    Returns:
        The tfx metadata config.

    Raises:
        RuntimeError: If you have configured your metadata store incorrectly.
    """
    config = MySQLDatabaseConfig(
        host=self.host,
        port=self.port,
        database=self.database,
    )

    secret = self._get_mysql_secret()

    # Set the user
    if self.username:
        if secret and secret.user:
            raise RuntimeError(
                f"Both the metadata store {self.name} and the secret "
                f"{self.secret} within your secrets manager define "
                f"a username `{self.username}` and `{secret.user}`. Please "
                f"make sure that you only use one."
            )
        config.user = self.username
    elif secret and secret.user:
        config.user = secret.user
    else:
        raise RuntimeError(
            "Your metadata store does not have a username. Please "
            "provide it either by defining it upon registration or "
            "through a MySQL secret."
        )

    # Set the password
    if self.password:
        if secret and secret.password:
            raise RuntimeError(
                f"Both the metadata store {self.name} and the secret "
                f"{self.secret} within your secrets manager define "
                f"a password. Please make sure that you only use one."
            )
        config.password = self.password
    elif secret and secret.password:
        config.password = secret.password

    # Set the SSL configuration if there is one
    if secret:
        # local import: only needed when SSL files have to be written
        import os

        secret_folder = Path(
            GlobalConfiguration().config_directory,
            "mysql-metadata",
            str(self.uuid),
        )

        ssl_kwargs = {}
        prefix = "ssl_"
        # Handle the files
        for key in ["ssl_key", "ssl_ca", "ssl_cert"]:
            content = getattr(secret, key)
            if content:
                fileio.makedirs(str(secret_folder))
                file_path = Path(secret_folder, f"{key}.pem")
                # Slice off the prefix explicitly: `str.lstrip("ssl_")`
                # strips any leading 's', 'l' or '_' characters, not the
                # literal prefix.
                ssl_kwargs[key[len(prefix) :]] = str(file_path)
                # Create the file owner-only from the start so the key
                # material is never readable by other users.
                fd = os.open(
                    str(file_path),
                    os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                    0o600,
                )
                with os.fdopen(fd, "w") as f:
                    f.write(content)
                # also tighten the mode of a file that already existed
                file_path.chmod(0o600)

        # Handle additional params
        ssl_kwargs["verify_server_cert"] = secret.ssl_verify_server_cert

        ssl_options = MySQLDatabaseConfig.SSLOptions(**ssl_kwargs)
        config.ssl_options.CopyFrom(ssl_options)

    return metadata_store_pb2.ConnectionConfig(mysql=config)
mysql_secret_schema
Secret schema for MySQL metadata store.
MYSQLSecretSchema (BaseSecretSchema)
pydantic-model
MySQL secret schema.
Attributes:
Name | Type | Description |
---|---|---|
user |
Optional[str] |
database username |
password |
Optional[str] |
database password |
ssl_ca |
Optional[str] |
certificate authority certificate contents. Required for SSL enabled authentication if the CA certificate is not part of the certificates shipped by the operating system. |
ssl_cert |
Optional[str] |
client certificate contents. Required for SSL enabled authentication if client certificates are used. |
ssl_key |
Optional[str] |
client certificate private key contents. Required for SSL enabled authentication if client certificates are used. |
ssl_verify_server_cert |
Optional[bool] |
set to verify the identity of the server against the provided server certificate. |
Source code in zenml/metadata_stores/mysql_secret_schema.py
class MYSQLSecretSchema(BaseSecretSchema):
"""MySQL secret schema.
Declares the credential and SSL fields a secret for the MySQL metadata
store may carry. Every field is optional.
Attributes:
user: database username
password: database password
ssl_ca: certificate authority certificate contents. Required for SSL
enabled authentication if the CA certificate is not part of the
certificates shipped by the operating system.
ssl_cert: client certificate contents. Required for SSL enabled
authentication if client certificates are used.
ssl_key: client certificate private key contents. Required for SSL
enabled authentication if client certificates are used.
ssl_verify_server_cert: set to verify the identity of the server
against the provided server certificate. Defaults to `False`.
"""
# Schema type identifier for this secret schema.
TYPE: ClassVar[str] = MYSQL_METADATA_STORE_SCHEMA_TYPE
user: Optional[str]
password: Optional[str]
# The three fields below hold the certificate/key contents themselves.
ssl_ca: Optional[str]
ssl_cert: Optional[str]
ssl_key: Optional[str]
ssl_verify_server_cert: Optional[bool] = False
sqlite_metadata_store
Metadata store for SQLite.
SQLiteMetadataStore (BaseMetadataStore)
pydantic-model
SQLite backend for ZenML metadata store.
Source code in zenml/metadata_stores/sqlite_metadata_store.py
class SQLiteMetadataStore(BaseMetadataStore):
    """Metadata store flavor backed by a local SQLite database."""

    uri: str

    # Class Configuration
    FLAVOR: ClassVar[str] = "sqlite"

    @property
    def local_path(self) -> str:
        """Local directory that contains the SQLite DB.

        Returns:
            The parent directory of `uri`, as a string.
        """
        database_location = Path(self.uri)
        return str(database_location.parent)

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Build the tfx metadata config for this SQLite store.

        Returns:
            The SQLite connection config created for `uri`.
        """
        connection_config = metadata.sqlite_metadata_connection_config(
            self.uri
        )
        return connection_config

    @validator("uri")
    def ensure_uri_is_local(cls, uri: str) -> str:
        """Reject remote uris for the SQLite metadata store.

        Args:
            uri: The metadata store uri to validate.

        Returns:
            The unchanged uri when it is local.

        Raises:
            ValueError: If the uri is remote.
        """
        if not io_utils.is_remote(uri):
            return uri

        raise ValueError(
            f"Uri '{uri}' specified for SQLiteMetadataStore is not a "
            f"local uri."
        )
local_path: str
property
readonly
Path to the local directory where the SQLite DB is stored.
Returns:
Type | Description |
---|---|
str |
The path to the local directory where the SQLite DB is stored. |
ensure_uri_is_local(uri)
classmethod
Ensures that the metadata store uri is local.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The metadata store uri. |
required |
Returns:
Type | Description |
---|---|
str |
The metadata store uri. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the uri is not local. |
Source code in zenml/metadata_stores/sqlite_metadata_store.py
@validator("uri")
def ensure_uri_is_local(cls, uri: str) -> str:
    """Reject remote uris for the SQLite metadata store.

    Args:
        uri: The metadata store uri to validate.

    Returns:
        The unchanged uri when it is local.

    Raises:
        ValueError: If the uri is remote.
    """
    if not io_utils.is_remote(uri):
        return uri

    raise ValueError(
        f"Uri '{uri}' specified for SQLiteMetadataStore is not a "
        f"local uri."
    )
get_tfx_metadata_config(self)
Return tfx metadata config for sqlite metadata store.
Returns:
Type | Description |
---|---|
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig] |
The tfx metadata config. |
Source code in zenml/metadata_stores/sqlite_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Build the tfx metadata config for this SQLite store.

    Returns:
        The SQLite connection config created for `uri`.
    """
    connection_config = metadata.sqlite_metadata_connection_config(self.uri)
    return connection_config