Skip to content

Pytorch

zenml.integrations.pytorch

Initialization of the PyTorch integration.

Attributes

PYTORCH = 'pytorch' module-attribute

Classes

Integration

Base class for integrations in ZenML.

Functions
activate() -> None classmethod

Abstract method to activate the integration.

Source code in src/zenml/integrations/integration.py
140
141
142
@classmethod
def activate(cls) -> None:
    """Abstract method to activate the integration.

    The base implementation is intentionally empty: it performs no work and
    implicitly returns ``None``. Integration subclasses override it to run
    their own activation logic.
    """
check_installation() -> bool classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in src/zenml/integrations/integration.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    Every requirement returned by ``cls.get_requirements()`` is parsed and
    checked, followed by each of the dependencies reported for it. The check
    stops at the first requirement or dependency that is not satisfied.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    # Shared by the requirement and the dependency branch so both emit the
    # exact same message.
    not_installed_msg = (
        "Requirement '%s' for integration '%s' is not installed "
        "or installed with the wrong version."
    )
    # Resolve the requirements once and reuse them for the final log line.
    requirements = cls.get_requirements()

    for requirement in requirements:
        parsed_requirement = Requirement(requirement)

        if not requirement_installed(parsed_requirement):
            logger.debug(not_installed_msg, requirement, cls.NAME)
            return False

        # Dependencies are only looked up once the requirement itself has
        # been confirmed as installed.
        for dependency in get_dependencies(parsed_requirement):
            if not requirement_installed(dependency):
                logger.debug(not_installed_msg, dependency, cls.NAME)
                return False

    # Lazy %-style arguments, consistent with the other log calls above.
    logger.debug(
        "Integration '%s' is installed correctly with requirements %s.",
        cls.NAME,
        requirements,
    )
    return True
flavors() -> List[Type[Flavor]] classmethod

Abstract method to declare new stack component flavors.

Returns:

Type Description
List[Type[Flavor]]

A list of new stack component flavors.

Source code in src/zenml/integrations/integration.py
144
145
146
147
148
149
150
151
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors contributed by the integration.

    The base implementation contributes none; subclasses may override it.

    Returns:
        A new, empty list of stack component flavors.
    """
    no_flavors: List[Type[Flavor]] = []
    return no_flavors
get_requirements(target_os: Optional[str] = None, python_version: Optional[str] = None) -> List[str] classmethod

Method to get the requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None
python_version Optional[str]

The Python version to use for the requirements.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@classmethod
def get_requirements(
    cls,
    target_os: Optional[str] = None,
    python_version: Optional[str] = None,
) -> List[str]:
    """Return the requirements declared for the integration.

    This implementation does not use either argument; it hands back the
    class-level ``REQUIREMENTS`` attribute as-is.

    Args:
        target_os: The target operating system to get the requirements for.
        python_version: The Python version to use for the requirements.

    Returns:
        The list stored in ``cls.REQUIREMENTS`` (the same object, not a copy).
    """
    declared_requirements: List[str] = cls.REQUIREMENTS
    return declared_requirements
get_uninstall_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Method to get the uninstall requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
@classmethod
def get_uninstall_requirements(
    cls, target_os: Optional[str] = None
) -> List[str]:
    """Return the requirements that should be removed on uninstall.

    Starts from ``cls.get_requirements(target_os=target_os)`` and drops every
    entry that begins with one of the prefixes listed in
    ``cls.REQUIREMENTS_IGNORED_ON_UNINSTALL``. The original order is kept.

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A list of requirements.
    """
    return [
        requirement
        for requirement in cls.get_requirements(target_os=target_os)
        if not any(
            requirement.startswith(prefix)
            for prefix in cls.REQUIREMENTS_IGNORED_ON_UNINSTALL
        )
    ]
plugin_flavors() -> List[Type[BasePluginFlavor]] classmethod

Abstract method to declare new plugin flavors.

Returns:

Type Description
List[Type[BasePluginFlavor]]

A list of new plugin flavors.

Source code in src/zenml/integrations/integration.py
153
154
155
156
157
158
159
160
@classmethod
def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]:
    """Declare the plugin flavors contributed by the integration.

    The base implementation contributes none; subclasses may override it.

    Returns:
        A new, empty list of plugin flavors.
    """
    no_plugin_flavors: List[Type["BasePluginFlavor"]] = []
    return no_plugin_flavors

PytorchIntegration

Bases: Integration

Definition of PyTorch integration for ZenML.

Functions
activate() -> None classmethod

Activates the integration.

Source code in src/zenml/integrations/pytorch/__init__.py
28
29
30
31
@classmethod
def activate(cls) -> None:
    """Activates the integration.

    The only action taken is importing the PyTorch ``materializers``
    subpackage; nothing is returned.
    """
    # The imported name is never used here, so the import is done purely for
    # its side effects (presumably registering the materializers - confirm).
    from zenml.integrations.pytorch import materializers  # noqa

Modules

materializers

Initialization of the PyTorch Materializer.

Classes
Modules
base_pytorch_materializer

Implementation of the base PyTorch materializer.

Classes
BasePyTorchMaterializer(uri: str, artifact_store: Optional[BaseArtifactStore] = None)

Bases: BaseMaterializer

Base class for PyTorch materializers.

Source code in src/zenml/materializers/base_materializer.py
125
126
127
128
129
130
131
132
133
134
135
def __init__(
    self, uri: str, artifact_store: Optional[BaseArtifactStore] = None
):
    """Create a materializer bound to the given URI.

    Args:
        uri: The URI where the artifact data will be stored.
        artifact_store: The artifact store used to store this artifact.
    """
    # Both values are stored as-is; no validation happens here.
    self.uri, self._artifact_store = uri, artifact_store
Functions
load(data_type: Type[Any]) -> Any

Uses torch.load to load a PyTorch object.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the object to load.

required

Returns:

Type Description
Any

The loaded PyTorch object.

Source code in src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def load(self, data_type: Type[Any]) -> Any:
    """Read a PyTorch object back from the artifact URI via `torch.load`.

    The object is read from the file named ``self.FILENAME`` inside
    ``self.uri``.

    Args:
        data_type: The type of the object to load.

    Returns:
        The loaded PyTorch object.
    """
    artifact_path = os.path.join(self.uri, self.FILENAME)
    with fileio.open(artifact_path, "rb") as artifact_file:
        # NOTE (security): `torch.load` unpickles the file with `pickle` by
        # default, and `weights_only=False` allows arbitrary objects to be
        # restored, which is NOT secure. Only use this materializer with
        # trusted data sources.
        return torch.load(artifact_file, weights_only=False)  # nosec
save(obj: Any) -> None

Uses torch.save to save a PyTorch object.

Parameters:

Name Type Description Default
obj Any

The PyTorch object to save.

required
Source code in src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
49
50
51
52
53
54
55
56
57
58
59
def save(self, obj: Any) -> None:
    """Write a PyTorch object to the artifact URI via `torch.save`.

    The object is written to the file named ``self.FILENAME`` inside
    ``self.uri``.

    Args:
        obj: The PyTorch object to save.
    """
    artifact_path = os.path.join(self.uri, self.FILENAME)
    with fileio.open(artifact_path, "wb") as artifact_file:
        # NOTE (security): the object is serialized with `cloudpickle`,
        # passed explicitly as `pickle_module`. Pickle-based files can run
        # arbitrary code when they are loaded again, so this materializer
        # is intended for use with trusted data sources only.
        torch.save(obj, artifact_file, pickle_module=cloudpickle)  # nosec
Modules
pytorch_dataloader_materializer

Implementation of the PyTorch DataLoader materializer.

Classes
PyTorchDataLoaderMaterializer(uri: str, artifact_store: Optional[BaseArtifactStore] = None)

Bases: BasePyTorchMaterializer

Materializer to read/write PyTorch dataloaders and datasets.

Source code in src/zenml/materializers/base_materializer.py
125
126
127
128
129
130
131
132
133
134
135
def __init__(
    self, uri: str, artifact_store: Optional[BaseArtifactStore] = None
):
    """Create a materializer bound to the given URI.

    Args:
        uri: The URI where the artifact data will be stored.
        artifact_store: The artifact store used to store this artifact.
    """
    # Both values are stored as-is; no validation happens here.
    self.uri, self._artifact_store = uri, artifact_store
Functions
extract_metadata(dataloader: Any) -> Dict[str, MetadataType]

Extract metadata from the given dataloader or dataset.

Parameters:

Name Type Description Default
dataloader Any

The dataloader or dataset to extract metadata from.

required

Returns:

Type Description
Dict[str, MetadataType]

The extracted metadata as a dictionary.

Source code in src/zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def extract_metadata(self, dataloader: Any) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given dataloader or dataset.

    For a ``DataLoader`` the sample count (when the dataset is sized), the
    batch size (when set) and the batch count (when it can be determined)
    are reported. For a plain ``Dataset`` only the sample count is reported,
    and only when the dataset is sized. Any other object yields no metadata.

    Args:
        dataloader: The dataloader or dataset to extract metadata from.

    Returns:
        The extracted metadata as a dictionary.
    """
    metadata: Dict[str, "MetadataType"] = {}
    if isinstance(dataloader, DataLoader):
        if hasattr(dataloader.dataset, "__len__"):
            metadata["num_samples"] = len(dataloader.dataset)
        if dataloader.batch_size:
            metadata["batch_size"] = dataloader.batch_size
        try:
            metadata["num_batches"] = len(dataloader)
        except TypeError:
            # `len(dataloader)` delegates to the dataset/sampler and raises
            # TypeError when no length is defined (e.g. an iterable-style
            # dataset without `__len__`). The batch count is optional
            # metadata, so it is skipped instead of failing the extraction.
            pass
    elif isinstance(dataloader, Dataset):
        if hasattr(dataloader, "__len__"):
            metadata["num_samples"] = len(dataloader)
    return metadata
pytorch_module_materializer

Implementation of the PyTorch Module materializer.

Classes
PyTorchModuleMaterializer(uri: str, artifact_store: Optional[BaseArtifactStore] = None)

Bases: BasePyTorchMaterializer

Materializer to read/write PyTorch models.

Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html

Source code in src/zenml/materializers/base_materializer.py
125
126
127
128
129
130
131
132
133
134
135
def __init__(
    self, uri: str, artifact_store: Optional[BaseArtifactStore] = None
):
    """Create a materializer bound to the given URI.

    Args:
        uri: The URI where the artifact data will be stored.
        artifact_store: The artifact store used to store this artifact.
    """
    # Both values are stored as-is; no validation happens here.
    self.uri, self._artifact_store = uri, artifact_store
Functions
extract_metadata(model: Module) -> Dict[str, MetadataType]

Extract metadata from the given Model object.

Parameters:

Name Type Description Default
model Module

The Model object to extract metadata from.

required

Returns:

Type Description
Dict[str, MetadataType]

The extracted metadata as a dictionary.

Source code in src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
69
70
71
72
73
74
75
76
77
78
def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]:
    """Collect metadata describing the given `Module`.

    The metadata consists of the entries produced by
    `count_module_params` for the model, copied into a new dictionary.

    Args:
        model: The `Module` object to extract metadata from.

    Returns:
        The extracted metadata as a dictionary.
    """
    param_counts = count_module_params(model)
    return dict(param_counts)
save(model: Module) -> None

Writes a PyTorch model, as a model and a checkpoint.

Parameters:

Name Type Description Default
model Module

A torch.nn.Module or a dict to save. The additional state-dict checkpoint is only written when the object is a torch.nn.Module.

required
Source code in src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def save(self, model: Module) -> None:
    """Persist a PyTorch model both as a full object and as a checkpoint.

    The full object is always written through the parent materializer. The
    state-dict checkpoint is written only when `model` is a
    `torch.nn.Module`.

    Args:
        model: A torch.nn.Module or a dict to pass into model.save
    """
    # Full-object artifact: the default way to reload the model during
    # development (training, evaluation).
    super().save(model)

    # Anything that is not a Module has no state dict to checkpoint.
    if not isinstance(model, Module):
        return

    # Checkpoint artifact: the default way to reload the model in
    # production (inference).
    checkpoint_path = os.path.join(self.uri, CHECKPOINT_FILENAME)
    with fileio.open(checkpoint_path, "wb") as checkpoint_file:
        # NOTE (security): the state dict is serialized with `cloudpickle`,
        # passed explicitly as `pickle_module`. Pickle-based files can run
        # arbitrary code when they are loaded again, so this materializer
        # is intended for use with trusted data sources only.
        torch.save(
            model.state_dict(), checkpoint_file, pickle_module=cloudpickle
        )  # nosec
Functions Modules

utils

PyTorch utils.

Functions
count_module_params(module: torch.nn.Module) -> Dict[str, int]

Get the total and trainable parameters of a module.

Parameters:

Name Type Description Default
module Module

The module to get the parameters of.

required

Returns:

Type Description
Dict[str, int]

A dictionary with the total and trainable parameters.

Source code in src/zenml/integrations/pytorch/utils.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def count_module_params(module: torch.nn.Module) -> Dict[str, int]:
    """Count the total and trainable parameters of a module.

    Every tensor yielded by `module.parameters()` contributes its element
    count to the total; tensors with `requires_grad` set also count towards
    the trainable figure.

    Args:
        module: The module to get the parameters of.

    Returns:
        A dictionary with the total and trainable parameters.
    """
    total_count = 0
    trainable_count = 0
    for tensor in module.parameters():
        element_count = tensor.numel()
        total_count += element_count
        if tensor.requires_grad:
            trainable_count += element_count
    return {
        "num_params": total_count,
        "num_trainable_params": trainable_count,
    }