PyTorch Lightning
zenml.integrations.pytorch_lightning
special
Initialization of the PyTorch Lightning integration.
PytorchLightningIntegration (Integration)
Definition of PyTorch Lightning integration for ZenML.
Source code in zenml/integrations/pytorch_lightning/__init__.py
class PytorchLightningIntegration(Integration):
    """ZenML integration that adds PyTorch Lightning support.

    The integration is identified by ``PYTORCH_L`` and requires the
    ``pytorch_lightning`` package. Activating it imports the integration's
    ``materializers`` subpackage.
    """

    NAME = PYTORCH_L
    REQUIREMENTS = ["pytorch_lightning"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers."""
        # Only the import itself matters (presumably it registers the
        # materializers as a side effect); the bound name is never used.
        from zenml.integrations.pytorch_lightning import (  # noqa: F401
            materializers as _materializers,
        )
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch_lightning/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers."""
    # Only the import itself matters (presumably it registers the
    # materializers as a side effect); the bound name is never used.
    from zenml.integrations.pytorch_lightning import (  # noqa: F401
        materializers as _materializers,
    )
materializers
special
Initialization of the PyTorch Lightning Materializer.
pytorch_lightning_materializer
Implementation of the PyTorch Lightning Materializer.
PyTorchLightningMaterializer (BaseMaterializer)
Materializer to read/write PyTorch models.
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
class PyTorchLightningMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch (Lightning) models.

    The whole ``Module`` object is serialized with ``torch.save`` into a
    single ``CHECKPOINT_NAME`` file under the artifact URI. It is restored
    with ``torch.load`` from that same file.
    """

    ASSOCIATED_TYPES = (Module,)
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.MODEL

    def load(self, data_type: Type[Any]) -> Module:
        """Reads and returns a PyTorch Lightning model.

        Args:
            data_type: The type of the model to load.

        Returns:
            A PyTorch Lightning model object.
        """
        import inspect

        super().load(data_type)

        # ``save`` pickles a complete ``Module`` object, not just a state
        # dict. Recent torch releases default ``torch.load`` to
        # ``weights_only=True``, which rejects such objects, so opt out
        # explicitly when the installed torch exposes the flag. Older torch
        # versions lack the parameter and already unpickle fully.
        load_kwargs = {}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False

        # SECURITY: this unpickles the checkpoint, which can execute
        # arbitrary code; only load artifacts from a trusted artifact store.
        with fileio.open(os.path.join(self.uri, CHECKPOINT_NAME), "rb") as f:
            return cast(Module, torch.load(f, **load_kwargs))

    def save(self, model: Module) -> None:
        """Writes a PyTorch Lightning model.

        Args:
            model: The PyTorch Lightning model to save.
        """
        super().save(model)
        # Persist the entire module object (architecture + weights) so that
        # ``load`` can return a ready-to-use ``Module``.
        with fileio.open(os.path.join(self.uri, CHECKPOINT_NAME), "wb") as f:
            torch.save(model, f)
load(self, data_type)
Reads and returns a PyTorch Lightning model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_type | Type[Any] | The type of the model to load. | required |

Returns:

| Type | Description |
|---|---|
| Module | A PyTorch Lightning model object. |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def load(self, data_type: Type[Any]) -> Module:
    """Reads and returns a PyTorch Lightning model.

    Args:
        data_type: The type of the model to load.

    Returns:
        A PyTorch Lightning model object.
    """
    import inspect

    super().load(data_type)

    # The checkpoint holds a complete pickled ``Module`` object, not just a
    # state dict. Recent torch releases default ``torch.load`` to
    # ``weights_only=True``, which rejects such objects, so opt out
    # explicitly when the installed torch exposes the flag. Older torch
    # versions lack the parameter and already unpickle fully.
    load_kwargs = {}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    # SECURITY: this unpickles the checkpoint, which can execute arbitrary
    # code; only load artifacts from a trusted artifact store.
    with fileio.open(os.path.join(self.uri, CHECKPOINT_NAME), "rb") as f:
        return cast(Module, torch.load(f, **load_kwargs))
save(self, model)
Writes a PyTorch Lightning model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The PyTorch Lightning model to save. | required |
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def save(self, model: Module) -> None:
    """Serialize ``model`` into this materializer's artifact location.

    Args:
        model: The PyTorch Lightning model to save.
    """
    super().save(model)
    # The whole module object is written to one checkpoint file under the
    # artifact URI.
    checkpoint_path = os.path.join(self.uri, CHECKPOINT_NAME)
    with fileio.open(checkpoint_path, "wb") as checkpoint_file:
        torch.save(model, checkpoint_file)