Skip to content

Tensorflow

zenml.integrations.tensorflow

Initialization for TensorFlow integration.

Attributes

TENSORFLOW = 'tensorflow' module-attribute

logger = get_logger(__name__) module-attribute

Classes

Integration

Base class for integrations in ZenML.

Functions
activate() -> None classmethod

Abstract method to activate the integration.

Source code in src/zenml/integrations/integration.py
175
176
177
@classmethod
def activate(cls) -> None:
    """Hook invoked to activate the integration.

    The base implementation intentionally does nothing; concrete
    integrations override it with their own activation steps.
    """
    return None
check_installation() -> bool classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in src/zenml/integrations/integration.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    Every requirement returned by `get_requirements()` is looked up with
    `pkg_resources.get_distribution`. Each requirement's declared
    dependencies are then looked up too. If the requirement names extras,
    only the dependencies for those extras are checked. This method does
    not recurse beyond that level itself. The check stops at the first
    missing package, version conflict or unknown extra, and logs the reason
    at debug level.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    for r in cls.get_requirements():
        try:
            # First check if the base package is installed
            dist = pkg_resources.get_distribution(r)

            # Next, check if the dependencies (including extras) are
            # installed
            deps: List[Requirement] = []

            # NOTE(review): `extras` is presumably a bracketed string such
            # as "[a,b]" (the `[1:-1]` slice below strips the first and
            # last character) and falsy when no extras are given. Confirm
            # against `parse_requirement`.
            _, extras = parse_requirement(r)
            if extras:
                # NOTE(review): items are not stripped, so "pkg[a, b]"
                # would look up the extra " b". Confirm the input format.
                extra_list = extras[1:-1].split(",")
                for extra in extra_list:
                    try:
                        requirements = dist.requires(extras=[extra])  # type: ignore[arg-type]
                    except pkg_resources.UnknownExtra as e:
                        # An extra that the installed distribution does not
                        # declare is treated as "not installed".
                        logger.debug(f"Unknown extra: {str(e)}")
                        return False
                    deps.extend(requirements)
            else:
                deps = dist.requires()

            for ri in deps:
                try:
                    # Remove the "extra == ..." part from the requirement string
                    # NOTE(review): only the exact `; extra == "<word>"`
                    # form is removed. Any other environment marker stays in
                    # the string passed to `get_distribution`. Verify how
                    # marker-gated dependencies (e.g. python_version
                    # conditions) behave here; errors other than the two
                    # caught below propagate to the caller.
                    cleaned_req = re.sub(
                        r"; extra == \"\w+\"", "", str(ri)
                    )
                    pkg_resources.get_distribution(cleaned_req)
                except pkg_resources.DistributionNotFound as e:
                    logger.debug(
                        f"Unable to find required dependency "
                        f"'{e.req}' for requirement '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False
                except pkg_resources.VersionConflict as e:
                    logger.debug(
                        f"Package version '{e.dist}' does not match "
                        f"version '{e.req}' required by '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False

        # These two handlers cover the lookup of the top-level requirement
        # `r` itself; dependency lookups are handled by the inner try above.
        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"Package version '{e.dist}' does not match version "
                f"'{e.req}' necessary for integration {cls.NAME}."
            )
            return False

    logger.debug(
        f"Integration {cls.NAME} is installed correctly with "
        f"requirements {cls.get_requirements()}."
    )
    return True
flavors() -> List[Type[Flavor]] classmethod

Abstract method to declare new stack component flavors.

Returns:

Type Description
List[Type[Flavor]]

A list of new stack component flavors.

Source code in src/zenml/integrations/integration.py
179
180
181
182
183
184
185
186
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors contributed by the integration.

    The base class contributes none; subclasses may override this.

    Returns:
        A list of new stack component flavors (empty by default).
    """
    no_flavors: List[Type[Flavor]] = list()
    return no_flavors
get_requirements(target_os: Optional[str] = None, python_version: Optional[str] = None) -> List[str] classmethod

Method to get the requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None
python_version Optional[str]

The Python version to use for the requirements.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@classmethod
def get_requirements(
    cls,
    target_os: Optional[str] = None,
    python_version: Optional[str] = None,
) -> List[str]:
    """Method to get the requirements for the integration.

    The base implementation ignores both arguments and returns the
    class-level ``REQUIREMENTS``. Subclasses can override it to compute
    platform-specific requirements.

    Args:
        target_os: The target operating system to get the requirements for.
            Unused by this base implementation.
        python_version: The Python version to use for the requirements.
            Unused by this base implementation.

    Returns:
        A new list with the integration's requirements. A copy is returned
        so callers can modify it without mutating ``cls.REQUIREMENTS``.
    """
    return list(cls.REQUIREMENTS)
get_uninstall_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Method to get the uninstall requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
@classmethod
def get_uninstall_requirements(
    cls, target_os: Optional[str] = None
) -> List[str]:
    """Method to get the uninstall requirements for the integration.

    A requirement is dropped when its string starts with any prefix listed
    in ``REQUIREMENTS_IGNORED_ON_UNINSTALL``. The remaining requirements
    keep their original order.

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A list of requirements.
    """
    return [
        requirement
        for requirement in cls.get_requirements(target_os=target_os)
        if not any(
            requirement.startswith(prefix)
            for prefix in cls.REQUIREMENTS_IGNORED_ON_UNINSTALL
        )
    ]
plugin_flavors() -> List[Type[BasePluginFlavor]] classmethod

Abstract method to declare new plugin flavors.

Returns:

Type Description
List[Type[BasePluginFlavor]]

A list of new plugin flavors.

Source code in src/zenml/integrations/integration.py
188
189
190
191
192
193
194
195
@classmethod
def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]:
    """Declare the plugin flavors contributed by the integration.

    The base class contributes none; subclasses may override this.

    Returns:
        A list of new plugin flavors (empty by default).
    """
    return list()

TensorflowIntegration

Bases: Integration

Definition of Tensorflow integration for ZenML.

Functions
activate() -> None classmethod

Activates the integration.

Source code in src/zenml/integrations/tensorflow/__init__.py
33
34
35
36
37
38
39
40
41
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    Imports ``tensorflow_io`` unless running on Apple Silicon macOS, where
    the integration's requirements do not include it. It then imports the
    integration's materializers.
    """
    running_on_apple_silicon = (
        platform.system() == "Darwin" and platform.machine() == "arm64"
    )
    if not running_on_apple_silicon:
        # need to import this explicitly to load the Tensorflow file IO
        # support for S3 and other file systems
        import tensorflow_io  # type: ignore

    from zenml.integrations.tensorflow import materializers  # noqa
get_requirements(target_os: Optional[str] = None, python_version: Optional[str] = None) -> List[str] classmethod

Defines platform specific requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system.

None
python_version Optional[str]

The Python version to use for the requirements.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/tensorflow/__init__.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
@classmethod
def get_requirements(
    cls,
    target_os: Optional[str] = None,
    python_version: Optional[str] = None,
) -> List[str]:
    """Defines platform specific requirements for the integration.

    Apple Silicon macOS gets ``tensorflow-macos``. Every other platform gets
    ``tensorflow`` plus ``tensorflow_io``. The CPU architecture is always
    read from the current machine, even when ``target_os`` is given.

    Args:
        target_os: The target operating system. Defaults to the operating
            system of the current machine.
        python_version: The Python version to use for the requirements.
            Not used by this implementation.

    Returns:
        A list of requirements.
    """
    resolved_os = platform.system() if not target_os else target_os
    if resolved_os != "Darwin" or platform.machine() != "arm64":
        return [
            "tensorflow>=2.12,<2.15",
            "tensorflow_io>=0.24.0",
        ]
    return ["tensorflow-macos>=2.12,<2.15"]

Functions

get_logger(logger_name: str) -> logging.Logger

Main function to get a logger by name.

Parameters:

Name Type Description Default
logger_name str

Name of logger to initialize.

required

Returns:

Type Description
Logger

A logger object.

Source code in src/zenml/logger.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def get_logger(logger_name: str) -> logging.Logger:
    """Main function to get a logger by name.

    The logger's level is set from `get_logging_level()`, and propagation to
    ancestor loggers is disabled. A console handler from
    `get_console_handler()` is attached only the first time this function
    configures a given logger. Repeated calls with the same name therefore
    do not produce duplicate log lines.

    Args:
        logger_name: Name of logger to initialize.

    Returns:
        A logger object.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(get_logging_level().value)

    # `logging.getLogger` returns the same object for the same name. Adding
    # a handler unconditionally would stack another console handler on
    # every call and emit each record multiple times.
    already_configured = any(
        getattr(handler, "_zenml_console_handler", False)
        for handler in logger.handlers
    )
    if not already_configured:
        console_handler = get_console_handler()
        setattr(console_handler, "_zenml_console_handler", True)
        logger.addHandler(console_handler)

    logger.propagate = False
    return logger

Modules

materializers

Initialization for the TensorFlow materializers.

Classes
Modules
keras_materializer

Implementation of the TensorFlow Keras materializer.

Classes
KerasMaterializer(uri: str, artifact_store: Optional[BaseArtifactStore] = None)

Bases: BaseMaterializer

Materializer to read/write Keras models.

Source code in src/zenml/materializers/base_materializer.py
125
126
127
128
129
130
131
132
133
134
135
def __init__(
    self, uri: str, artifact_store: Optional[BaseArtifactStore] = None
):
    """Create a materializer bound to an artifact location.

    Only stores the two arguments on the instance; no I/O happens here.

    Args:
        uri: The URI where the artifact data will be stored.
        artifact_store: The artifact store used to store this artifact.
    """
    # Location of the artifact's data.
    self.uri = uri
    # Optional explicit store; may be None.
    self._artifact_store = artifact_store
Functions
extract_metadata(model: tf_keras.Model) -> Dict[str, MetadataType]

Extract metadata from the given Model object.

Parameters:

Name Type Description Default
model Model

The Model object to extract metadata from.

required

Returns:

Type Description
Dict[str, MetadataType]

The extracted metadata as a dictionary.

Source code in src/zenml/integrations/tensorflow/materializers/keras_materializer.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def extract_metadata(
    self, model: tf_keras.Model
) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given `Model` object.

    Args:
        model: The `Model` object to extract metadata from.

    Returns:
        A dictionary with three counts: layers (`num_layers`), all
        parameters (`num_params`) and trainable parameters
        (`num_trainable_params`).
    """
    layer_count = len(model.layers)
    total_params = count_params(model.weights)
    trainable_params = count_params(model.trainable_weights)
    return {
        "num_layers": layer_count,
        "num_params": total_params,
        "num_trainable_params": trainable_params,
    }
load(data_type: Type[Any]) -> tf_keras.Model

Reads and returns a Keras model after copying it to a temporary path.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Model

A keras.Model model.

Source code in src/zenml/integrations/tensorflow/materializers/keras_materializer.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def load(self, data_type: Type[Any]) -> tf_keras.Model:
    """Reads and returns a Keras model after copying it to a temporary path.

    The artifact directory at `self.uri` is copied into a temporary local
    directory, requested with `delete_at_exit=True`. The model file inside
    it is then loaded with `tf.keras.models.load_model`.

    Args:
        data_type: The type of the data to read. Not used here.

    Returns:
        A keras.Model model.
    """
    with self.get_temporary_directory(delete_at_exit=True) as local_dir:
        model_path = os.path.join(local_dir, self.MODEL_FILE_NAME)
        # Mirror the artifact store contents locally so the model can be
        # loaded from a local file path.
        io_utils.copy_dir(self.uri, local_dir)
        return tf.keras.models.load_model(model_path)
save(model: tf_keras.Model) -> None

Writes a keras model to the artifact store.

Parameters:

Name Type Description Default
model Model

A keras.Model model.

required
Source code in src/zenml/integrations/tensorflow/materializers/keras_materializer.py
60
61
62
63
64
65
66
67
68
69
def save(self, model: tf_keras.Model) -> None:
    """Writes a keras model to the artifact store.

    The model is first saved into a temporary local directory. That whole
    directory is then copied to `self.uri`.

    Args:
        model: A keras.Model model.
    """
    with self.get_temporary_directory(delete_at_exit=True) as staging_dir:
        model.save(os.path.join(staging_dir, self.MODEL_FILE_NAME))
        io_utils.copy_dir(staging_dir, self.uri)
Modules
tf_dataset_materializer

Implementation of the TensorFlow dataset materializer.

Classes
TensorflowDatasetMaterializer(uri: str, artifact_store: Optional[BaseArtifactStore] = None)

Bases: BaseMaterializer

Materializer to read data to and from tf.data.Dataset.

Source code in src/zenml/materializers/base_materializer.py
125
126
127
128
129
130
131
132
133
134
135
def __init__(
    self, uri: str, artifact_store: Optional[BaseArtifactStore] = None
):
    """Create a materializer bound to an artifact location.

    Only stores the two arguments on the instance; no I/O happens here.

    Args:
        uri: The URI where the artifact data will be stored.
        artifact_store: The artifact store used to store this artifact.
    """
    # Location of the artifact's data.
    self.uri = uri
    # Optional explicit store; may be None.
    self._artifact_store = artifact_store
Functions
extract_metadata(dataset: tf.data.Dataset) -> Dict[str, MetadataType]

Extract metadata from the given Dataset object.

Parameters:

Name Type Description Default
dataset Dataset

The Dataset object to extract metadata from.

required

Returns:

Type Description
Dict[str, MetadataType]

The extracted metadata as a dictionary.

Source code in src/zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
65
66
67
68
69
70
71
72
73
74
75
76
def extract_metadata(
    self, dataset: tf.data.Dataset
) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given `Dataset` object.

    Args:
        dataset: The `Dataset` object to extract metadata from.

    Returns:
        The extracted metadata as a dictionary. The `length` entry is only
        included when the dataset has a known, finite number of elements.
    """
    try:
        length = len(dataset)
    except TypeError:
        # `len()` on a `tf.data.Dataset` raises `TypeError` when its
        # cardinality is infinite (e.g. after `.repeat()`) or unknown
        # (e.g. after `.filter()`). Such datasets are still valid
        # artifacts, so skip the entry instead of failing the save.
        return {}
    return {"length": length}
load(data_type: Type[Any]) -> Any

Reads data into tf.data.Dataset.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Any

A tf.data.Dataset object.

Source code in src/zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def load(self, data_type: Type[Any]) -> Any:
    """Reads data into tf.data.Dataset.

    Args:
        data_type: The type of the data to read. Not used here.

    Returns:
        A tf.data.Dataset object.
    """
    # The temporary copy is requested with `delete_at_exit=False`,
    # presumably because `tf.data.Dataset.load` reads the files lazily when
    # the dataset is iterated. Confirm before changing this.
    with self.get_temporary_directory(delete_at_exit=False) as local_dir:
        io_utils.copy_dir(self.uri, local_dir)
        dataset_path = os.path.join(local_dir, DEFAULT_FILENAME)
        return tf.data.Dataset.load(dataset_path)
save(dataset: tf.data.Dataset) -> None

Persists a tf.data.Dataset object.

Parameters:

Name Type Description Default
dataset Dataset

The dataset to persist.

required
Source code in src/zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
52
53
54
55
56
57
58
59
60
61
62
63
def save(self, dataset: tf.data.Dataset) -> None:
    """Persists a tf.data.Dataset object.

    The dataset is written into a temporary local directory, uncompressed
    and without a custom shard function. That directory is then copied to
    `self.uri`.

    Args:
        dataset: The dataset to persist.
    """
    with self.get_temporary_directory(delete_at_exit=True) as staging_dir:
        tf.data.Dataset.save(
            dataset,
            os.path.join(staging_dir, DEFAULT_FILENAME),
            compression=None,
            shard_func=None,
        )
        io_utils.copy_dir(staging_dir, self.uri)
Modules