Skip to content

Tensorflow

zenml.integrations.tensorflow special

Initialization for TensorFlow integration.

TensorflowIntegration (Integration)

Definition of Tensorflow integration for ZenML.

Source code in zenml/integrations/tensorflow/__init__.py
class TensorflowIntegration(Integration):
    """ZenML integration definition for TensorFlow.

    Activation imports the TensorFlow materializers. The pip requirements
    differ between Apple Silicon macOS (``tensorflow-macos``) and all other
    platforms (``tensorflow`` plus ``tensorflow_io``).
    """

    NAME = TENSORFLOW
    REQUIREMENTS = []
    # Requirements that are skipped when this integration is uninstalled.
    REQUIREMENTS_IGNORED_ON_UNINSTALL = ["typing-extensions"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration by importing its materializers.

        On every platform except Apple Silicon macOS, ``tensorflow_io`` is
        imported as well. The import exists only for its side effect of
        loading TensorFlow file IO support for S3 and other file systems.
        """
        on_apple_silicon = (
            platform.system() == "Darwin" and platform.machine() == "arm64"
        )
        if not on_apple_silicon:
            import tensorflow_io  # type: ignore  # noqa: F401

        from zenml.integrations.tensorflow import materializers  # noqa

    @classmethod
    def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
        """Defines platform specific requirements for the integration.

        Args:
            target_os: The target operating system. Falls back to the host
                OS (``platform.system()``) when not given or empty. The CPU
                architecture is always taken from the host machine.

        Returns:
            A fresh list of requirements.
        """
        resolved_os = target_os or platform.system()
        is_apple_silicon = (
            resolved_os == "Darwin" and platform.machine() == "arm64"
        )
        if is_apple_silicon:
            reqs = ["tensorflow-macos>=2.12,<2.15"]
        else:
            reqs = ["tensorflow>=2.12,<2.15", "tensorflow_io>=0.24.0"]

        # Python 3.8 additionally gets a typing-extensions lower bound.
        if sys.version_info.minor == 8:
            reqs.append("typing-extensions>=4.6.1")
        return reqs

activate() classmethod

Activates the integration.

Source code in zenml/integrations/tensorflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration by importing its materializers.

    On every platform except Apple Silicon macOS, ``tensorflow_io`` is
    imported as well. The import exists only for its side effect of loading
    TensorFlow file IO support for S3 and other file systems.
    """
    on_apple_silicon = (
        platform.system() == "Darwin" and platform.machine() == "arm64"
    )
    if not on_apple_silicon:
        import tensorflow_io  # type: ignore  # noqa: F401

    from zenml.integrations.tensorflow import materializers  # noqa

get_requirements(target_os=None) classmethod

Defines platform specific requirements for the integration.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `target_os` | `Optional[str]` | The target operating system. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `List[str]` | A list of requirements. |

Source code in zenml/integrations/tensorflow/__init__.py
@classmethod
def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
    """Defines platform specific requirements for the integration.

    Args:
        target_os: The target operating system. Falls back to the host OS
            (``platform.system()``) when not given or empty. The CPU
            architecture is always taken from the host machine.

    Returns:
        A fresh list of requirements.
    """
    resolved_os = target_os or platform.system()
    is_apple_silicon = (
        resolved_os == "Darwin" and platform.machine() == "arm64"
    )
    if is_apple_silicon:
        reqs = ["tensorflow-macos>=2.12,<2.15"]
    else:
        reqs = ["tensorflow>=2.12,<2.15", "tensorflow_io>=0.24.0"]

    # Python 3.8 additionally gets a typing-extensions lower bound.
    if sys.version_info.minor == 8:
        reqs.append("typing-extensions>=4.6.1")
    return reqs

materializers special

Initialization for the TensorFlow materializers.

keras_materializer

Implementation of the TensorFlow Keras materializer.

KerasMaterializer (BaseMaterializer)

Materializer to read/write Keras models.

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
class KerasMaterializer(BaseMaterializer):
    """Materializer that reads/writes Keras models as a `.keras` file."""

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (
        tf.keras.Model,
        tf_keras.Model,
    )
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL
    # Name of the model file stored inside the artifact directory.
    MODEL_FILE_NAME = "model.keras"

    def load(self, data_type: Type[Any]) -> tf_keras.Model:
        """Load a Keras model from the artifact store.

        The artifact directory is copied into a temporary directory (requested
        with ``delete_at_exit=True``). The model file inside it is then loaded
        with ``tf.keras.models.load_model``.

        Args:
            data_type: The type of the data to read (not used here).

        Returns:
            A keras.Model model.
        """
        with self.get_temporary_directory(delete_at_exit=True) as local_dir:
            model_path = os.path.join(local_dir, self.MODEL_FILE_NAME)
            # The model is loaded from a local copy of the artifact directory.
            io_utils.copy_dir(self.uri, local_dir)
            loaded = tf.keras.models.load_model(model_path)
        return loaded

    def save(self, model: tf_keras.Model) -> None:
        """Write a Keras model into the artifact store.

        Args:
            model: A keras.Model model.
        """
        with self.get_temporary_directory(delete_at_exit=True) as local_dir:
            # Save to a local file first, then copy the whole directory
            # into the artifact store at `self.uri`.
            model.save(os.path.join(local_dir, self.MODEL_FILE_NAME))
            io_utils.copy_dir(local_dir, self.uri)

    def extract_metadata(
        self, model: tf_keras.Model
    ) -> Dict[str, "MetadataType"]:
        """Collect layer and parameter counts from a Keras model.

        Args:
            model: The `Model` object to extract metadata from.

        Returns:
            A dictionary with the number of layers, the total parameter
            count and the trainable parameter count.
        """
        metadata: Dict[str, "MetadataType"] = {}
        metadata["num_layers"] = len(model.layers)
        metadata["num_params"] = count_params(model.weights)
        metadata["num_trainable_params"] = count_params(
            model.trainable_weights
        )
        return metadata
extract_metadata(self, model)

Extract metadata from the given Model object.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `tensorflow.python.keras.Model` | The Model object to extract metadata from. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Dict[str, MetadataType]` | The extracted metadata as a dictionary. |

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def extract_metadata(
    self, model: tf_keras.Model
) -> Dict[str, "MetadataType"]:
    """Collect layer and parameter counts from a Keras model.

    Args:
        model: The `Model` object to extract metadata from.

    Returns:
        A dictionary with the number of layers, the total parameter count
        and the trainable parameter count.
    """
    metadata: Dict[str, "MetadataType"] = {}
    metadata["num_layers"] = len(model.layers)
    metadata["num_params"] = count_params(model.weights)
    metadata["num_trainable_params"] = count_params(model.trainable_weights)
    return metadata
load(self, data_type)

Reads and returns a Keras model after copying it to temporary path.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data_type` | `Type[Any]` | The type of the data to read. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `tensorflow.python.keras.Model` | A keras.Model model. |

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def load(self, data_type: Type[Any]) -> tf_keras.Model:
    """Load a Keras model from the artifact store.

    The artifact directory is copied into a temporary directory (requested
    with ``delete_at_exit=True``). The model file inside it is then loaded
    with ``tf.keras.models.load_model``.

    Args:
        data_type: The type of the data to read (not used here).

    Returns:
        A keras.Model model.
    """
    with self.get_temporary_directory(delete_at_exit=True) as local_dir:
        model_path = os.path.join(local_dir, self.MODEL_FILE_NAME)
        # The model is loaded from a local copy of the artifact directory.
        io_utils.copy_dir(self.uri, local_dir)
        loaded = tf.keras.models.load_model(model_path)
    return loaded
save(self, model)

Writes a keras model to the artifact store.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `tensorflow.python.keras.Model` | A keras.Model model. | *required* |
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def save(self, model: tf_keras.Model) -> None:
    """Write a Keras model into the artifact store.

    Args:
        model: A keras.Model model.
    """
    with self.get_temporary_directory(delete_at_exit=True) as local_dir:
        # Save to a local file first, then copy the whole directory into
        # the artifact store at `self.uri`.
        model.save(os.path.join(local_dir, self.MODEL_FILE_NAME))
        io_utils.copy_dir(local_dir, self.uri)

tf_dataset_materializer

Implementation of the TensorFlow dataset materializer.

TensorflowDatasetMaterializer (BaseMaterializer)

Materializer to read data to and from tf.data.Dataset.

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
class TensorflowDatasetMaterializer(BaseMaterializer):
    """Materializer to read data to and from tf.data.Dataset."""

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (tf.data.Dataset,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA

    def load(self, data_type: Type[Any]) -> Any:
        """Reads data into tf.data.Dataset.

        Args:
            data_type: The type of the data to read (not used here).

        Returns:
            A tf.data.Dataset object.
        """
        # NOTE: delete_at_exit=False appears deliberate. The returned dataset
        # is backed by the files under `path`, which are presumably read
        # lazily on iteration, so the local copy must outlive this method.
        with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
            io_utils.copy_dir(self.uri, temp_dir)
            path = os.path.join(temp_dir, DEFAULT_FILENAME)
            dataset = tf.data.Dataset.load(path)
            return dataset

    def save(self, dataset: tf.data.Dataset) -> None:
        """Persists a tf.data.Dataset object.

        The dataset is saved uncompressed into a temporary directory, which
        is then copied to the artifact store at `self.uri`.

        Args:
            dataset: The dataset to persist.
        """
        with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
            path = os.path.join(temp_dir, DEFAULT_FILENAME)
            tf.data.Dataset.save(
                dataset, path, compression=None, shard_func=None
            )
            io_utils.copy_dir(temp_dir, self.uri)

    def extract_metadata(
        self, dataset: tf.data.Dataset
    ) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given `Dataset` object.

        Only the number of elements is recorded. ``len()`` on a
        ``tf.data.Dataset`` raises ``TypeError`` when the cardinality is
        infinite (e.g. after ``repeat()``) or cannot be determined
        statically (e.g. after ``filter()``). In that case the length is
        omitted rather than raising.

        Args:
            dataset: The `Dataset` object to extract metadata from.

        Returns:
            The extracted metadata as a dictionary. It contains a
            ``"length"`` entry only when the dataset length is known and
            finite.
        """
        try:
            length = len(dataset)
        except TypeError:
            # Infinite or unknown cardinality: no meaningful length exists.
            return {}
        return {"length": length}
extract_metadata(self, dataset)

Extract metadata from the given Dataset object.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `dataset` | `tensorflow.data.Dataset` | The Dataset object to extract metadata from. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Dict[str, MetadataType]` | The extracted metadata as a dictionary. |

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def extract_metadata(
    self, dataset: tf.data.Dataset
) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given `Dataset` object.

    Only the number of elements is recorded. ``len()`` on a
    ``tf.data.Dataset`` raises ``TypeError`` when the cardinality is infinite
    (e.g. after ``repeat()``) or cannot be determined statically (e.g. after
    ``filter()``). In that case the length is omitted rather than raising.

    Args:
        dataset: The `Dataset` object to extract metadata from.

    Returns:
        The extracted metadata as a dictionary. It contains a ``"length"``
        entry only when the dataset length is known and finite.
    """
    try:
        length = len(dataset)
    except TypeError:
        # Infinite or unknown cardinality: no meaningful length exists.
        return {}
    return {"length": length}
load(self, data_type)

Reads data into tf.data.Dataset.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data_type` | `Type[Any]` | The type of the data to read. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Any` | A tf.data.Dataset object. |

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def load(self, data_type: Type[Any]) -> Any:
    """Load the stored `tf.data.Dataset` from the artifact store.

    Args:
        data_type: The type of the data to read (not used here).

    Returns:
        A tf.data.Dataset object.
    """
    # NOTE: delete_at_exit=False appears deliberate. The returned dataset is
    # backed by the files under `dataset_path`, which are presumably read
    # lazily on iteration, so the local copy must outlive this method.
    with self.get_temporary_directory(delete_at_exit=False) as local_dir:
        io_utils.copy_dir(self.uri, local_dir)
        dataset_path = os.path.join(local_dir, DEFAULT_FILENAME)
        return tf.data.Dataset.load(dataset_path)
save(self, dataset)

Persists a tf.data.Dataset object.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `dataset` | `tensorflow.data.Dataset` | The dataset to persist. | *required* |
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def save(self, dataset: tf.data.Dataset) -> None:
    """Persist a `tf.data.Dataset` into the artifact store.

    The dataset is written uncompressed and without a custom shard function
    into a temporary directory. That directory is then copied to `self.uri`.

    Args:
        dataset: The dataset to persist.
    """
    with self.get_temporary_directory(delete_at_exit=True) as local_dir:
        dataset_path = os.path.join(local_dir, DEFAULT_FILENAME)
        tf.data.Dataset.save(
            dataset,
            dataset_path,
            compression=None,
            shard_func=None,
        )
        io_utils.copy_dir(local_dir, self.uri)