Skip to content

PyTorch Lightning

zenml.integrations.pytorch_lightning special

Initialization of the PyTorch Lightning integration.

PytorchLightningIntegration (Integration)

Definition of PyTorch Lightning integration for ZenML.

Source code in zenml/integrations/pytorch_lightning/__init__.py
class PytorchLightningIntegration(Integration):
    """ZenML integration definition for PyTorch Lightning.

    Declares the pip requirement of the integration and, on activation,
    imports the integration's materializers module.
    """

    NAME = PYTORCH_L
    REQUIREMENTS = ["pytorch_lightning"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module.

        The import is performed only for its import-time effects; the
        imported name itself is not used afterwards.
        """
        import zenml.integrations.pytorch_lightning.materializers  # noqa

activate() classmethod

Activates the integration.

Source code in zenml/integrations/pytorch_lightning/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers module.

    The import is performed only for its import-time effects; the imported
    name itself is not used afterwards.
    """
    import zenml.integrations.pytorch_lightning.materializers  # noqa

materializers special

Initialization of the PyTorch Lightning Materializer.

pytorch_lightning_materializer

Implementation of the PyTorch Lightning Materializer.

PyTorchLightningMaterializer (BaseMaterializer)

Materializer to read/write PyTorch models.

Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
class PyTorchLightningMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch models.

    The whole model object is pickled with ``torch.save`` into a single
    ``CHECKPOINT_NAME`` file inside the artifact URI, and restored with
    ``torch.load``.
    """

    ASSOCIATED_TYPES = (Module,)
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.MODEL

    def load(self, data_type: Type[Any]) -> Module:
        """Reads and returns a PyTorch Lightning model.

        Args:
            data_type: The type of the model to load.

        Returns:
            A PyTorch Lightning model object.
        """
        import inspect

        super().load(data_type)
        # `save` pickles the entire module object, not just a state dict.
        # Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`,
        # which refuses to unpickle arbitrary classes such as an `nn.Module`.
        # The restriction therefore has to be lifted explicitly. Older
        # releases without the parameter already perform a full unpickle.
        supports_weights_only = (
            "weights_only" in inspect.signature(torch.load).parameters
        )
        with fileio.open(os.path.join(self.uri, CHECKPOINT_NAME), "rb") as f:
            # NOTE(review): this unpickles arbitrary Python objects; only
            # load artifacts that come from a trusted artifact store.
            if supports_weights_only:
                model = torch.load(f, weights_only=False)
            else:
                model = torch.load(f)
        return cast(Module, model)

    def save(self, model: Module) -> None:
        """Writes a PyTorch Lightning model.

        The full module object (not only its state dict) is passed to
        ``torch.save``.

        Args:
            model: The PyTorch Lightning model to save.
        """
        super().save(model)
        with fileio.open(os.path.join(self.uri, CHECKPOINT_NAME), "wb") as f:
            torch.save(model, f)
load(self, data_type)

Reads and returns a PyTorch Lightning model.

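Parameters:

| Name      | Type      | Description                     | Default  |
|-----------|-----------|---------------------------------|----------|
| data_type | Type[Any] | The type of the model to load.  | required |

Returns:

| Type   | Description                        |
|--------|------------------------------------|
| Module | A PyTorch Lightning model object.  |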

Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def load(self, data_type: Type[Any]) -> Module:
    """Reads and returns a PyTorch Lightning model.

    Args:
        data_type: The type of the model to load.

    Returns:
        A PyTorch Lightning model object.
    """
    import inspect

    super().load(data_type)
    # The checkpoint holds a pickled module object, not just a state dict.
    # Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which
    # refuses to unpickle arbitrary classes such as an `nn.Module`. The
    # restriction therefore has to be lifted explicitly. Older releases
    # without the parameter already perform a full unpickle.
    supports_weights_only = (
        "weights_only" in inspect.signature(torch.load).parameters
    )
    with fileio.open(os.path.join(self.uri, CHECKPOINT_NAME), "rb") as f:
        # NOTE(review): this unpickles arbitrary Python objects; only load
        # artifacts that come from a trusted artifact store.
        if supports_weights_only:
            model = torch.load(f, weights_only=False)
        else:
            model = torch.load(f)
    return cast(Module, model)
save(self, model)

Writes a PyTorch Lightning model.

Parameters:

| Name  | Type   | Description                           | Default  |
|-------|--------|---------------------------------------|----------|
| model | Module | The PyTorch Lightning model to save.  | required |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def save(self, model: Module) -> None:
    """Persist a PyTorch Lightning model into the artifact URI.

    The entire module object (rather than only its state dict) is handed to
    ``torch.save`` and written to the ``CHECKPOINT_NAME`` file under
    ``self.uri``.

    Args:
        model: The PyTorch Lightning model to save.
    """
    super().save(model)
    checkpoint_path = os.path.join(self.uri, CHECKPOINT_NAME)
    with fileio.open(checkpoint_path, "wb") as checkpoint_file:
        torch.save(model, checkpoint_file)