PyTorch
zenml.integrations.pytorch
special
Initialization of the PyTorch integration.
PytorchIntegration (Integration)
Definition of PyTorch integration for ZenML.
Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
    """ZenML integration that adds support for the `torch` package."""

    NAME = PYTORCH
    REQUIREMENTS = ["torch"]

    @classmethod
    def activate(cls) -> None:
        """Activate the PyTorch integration.

        The materializers submodule is imported only for its import-time
        side effects (the bound name is never used, hence ``noqa``).
        """
        import zenml.integrations.pytorch.materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the PyTorch integration.

    The materializers submodule is imported only for its import-time
    side effects (the bound name is never used, hence ``noqa``).
    """
    import zenml.integrations.pytorch.materializers  # noqa
materializers
special
Initialization of the PyTorch Materializer.
pytorch_dataloader_materializer
Implementation of the PyTorch DataLoader materializer.
PyTorchDataLoaderMaterializer (BaseMaterializer)
Materializer to read/write PyTorch dataloaders.
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
class PyTorchDataLoaderMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch dataloaders.

    The whole ``DataLoader`` object is serialized with ``torch.save`` into a
    single file (``DEFAULT_FILENAME``) inside the artifact directory.
    """

    ASSOCIATED_TYPES = (DataLoader,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
        """Reads and returns a PyTorch dataloader.

        Args:
            data_type: The type of the dataloader to load.

        Returns:
            A loaded PyTorch dataloader.
        """
        import inspect

        super().handle_input(data_type)
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
        ) as f:
            # torch>=2.6 defaults to ``weights_only=True``, which refuses to
            # unpickle arbitrary objects such as a whole DataLoader. Request
            # full unpickling explicitly, but only on torch versions that
            # know the parameter (added in torch 1.13).
            # NOTE(review): full unpickling can execute arbitrary code; only
            # load artifacts from a trusted artifact store.
            if "weights_only" in inspect.signature(torch.load).parameters:
                loaded = torch.load(f, weights_only=False)
            else:
                loaded = torch.load(f)
        return cast(DataLoader[Any], loaded)

    def handle_return(self, dataloader: DataLoader[Any]) -> None:
        """Writes a PyTorch dataloader.

        Args:
            dataloader: The ``torch.utils.data.DataLoader`` to save. The whole
                object is written with ``torch.save``.
        """
        super().handle_return(dataloader)
        # Save the entire dataloader object to the artifact directory.
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
        ) as f:
            torch.save(dataloader, f)
handle_input(self, data_type)
Reads and returns a PyTorch dataloader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the dataloader to load. |
required |
Returns:
Type | Description |
---|---|
torch.utils.data.dataloader.DataLoader[Any] |
A loaded PyTorch dataloader. |
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
    """Reads and returns a PyTorch dataloader.

    Args:
        data_type: The type of the dataloader to load.

    Returns:
        A loaded PyTorch dataloader.
    """
    import inspect

    super().handle_input(data_type)
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
    ) as f:
        # torch>=2.6 defaults to ``weights_only=True``, which refuses to
        # unpickle arbitrary objects such as a whole DataLoader. Request
        # full unpickling explicitly, but only on torch versions that know
        # the parameter (added in torch 1.13).
        # NOTE(review): full unpickling can execute arbitrary code; only
        # load artifacts from a trusted artifact store.
        if "weights_only" in inspect.signature(torch.load).parameters:
            loaded = torch.load(f, weights_only=False)
        else:
            loaded = torch.load(f)
    return cast(DataLoader[Any], loaded)
handle_return(self, dataloader)
Writes a PyTorch dataloader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataloader |
torch.utils.data.dataloader.DataLoader[Any] |
The torch.utils.data.DataLoader to save; the whole object is written with torch.save. |
required |
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_return(self, dataloader: DataLoader[Any]) -> None:
    """Writes a PyTorch dataloader.

    Args:
        dataloader: The ``torch.utils.data.DataLoader`` to save. The whole
            object is written with ``torch.save``.
    """
    super().handle_return(dataloader)
    # Save the entire dataloader object to the artifact directory.
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
    ) as f:
        torch.save(dataloader, f)
pytorch_module_materializer
Implementation of the PyTorch Module materializer.
PyTorchModuleMaterializer (BaseMaterializer)
Materializer to read/write PyTorch models.
Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
class PyTorchModuleMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch models.

    Inspired by the guide:
    https://pytorch.org/tutorials/beginner/saving_loading_models.html

    Saving writes the whole object to ``DEFAULT_FILENAME``. For
    ``torch.nn.Module`` instances, it also writes the model's ``state_dict()``
    to ``CHECKPOINT_FILENAME``. Loading reads back only ``DEFAULT_FILENAME``.
    """

    ASSOCIATED_TYPES = (Module,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Module:
        """Reads and returns a PyTorch model.

        Only loads the model, not the checkpoint.

        Args:
            data_type: The type of the model to load.

        Returns:
            A loaded PyTorch model. If a non-``Module`` object (e.g. a dict)
            was saved by ``handle_return``, that object is returned as-is;
            the ``cast`` below does not convert it.
        """
        import inspect

        super().handle_input(data_type)
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
        ) as f:
            # torch>=2.6 defaults to ``weights_only=True``, which refuses to
            # unpickle a whole ``nn.Module``. Request full unpickling
            # explicitly, but only on torch versions that know the parameter
            # (added in torch 1.13).
            # NOTE(review): full unpickling can execute arbitrary code; only
            # load artifacts from a trusted artifact store.
            if "weights_only" in inspect.signature(torch.load).parameters:
                loaded = torch.load(f, weights_only=False)
            else:
                loaded = torch.load(f)
        return cast(Module, loaded)

    def handle_return(self, model: Module) -> None:
        """Writes a PyTorch model, as a model and a checkpoint.

        Args:
            model: The model to save. The whole object is written with
                ``torch.save``. If it is a ``torch.nn.Module``, its
                ``state_dict()`` is also written as a checkpoint. Other
                objects (e.g. a plain dict) are saved whole but get no
                checkpoint.
        """
        super().handle_return(model)
        # Save the entire model object; this is the file handle_input reads
        # back (the usual form during development: training, evaluation).
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
        ) as f:
            torch.save(model, f)

        # Also save only the parameters (state dict) as a checkpoint, the
        # form typically preferred in production (inference).
        if isinstance(model, Module):
            with fileio.open(
                os.path.join(self.artifact.uri, CHECKPOINT_FILENAME), "wb"
            ) as f:
                torch.save(model.state_dict(), f)
handle_input(self, data_type)
Reads and returns a PyTorch model.
Only loads the model, not the checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the model to load. |
required |
Returns:
Type | Description |
---|---|
Module |
A loaded PyTorch model. |
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_input(self, data_type: Type[Any]) -> Module:
    """Reads and returns a PyTorch model.

    Only loads the model, not the checkpoint.

    Args:
        data_type: The type of the model to load.

    Returns:
        A loaded PyTorch model. If a non-``Module`` object (e.g. a dict) was
        saved, that object is returned as-is; the ``cast`` below does not
        convert it.
    """
    import inspect

    super().handle_input(data_type)
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
    ) as f:
        # torch>=2.6 defaults to ``weights_only=True``, which refuses to
        # unpickle a whole ``nn.Module``. Request full unpickling explicitly,
        # but only on torch versions that know the parameter (added in
        # torch 1.13).
        # NOTE(review): full unpickling can execute arbitrary code; only
        # load artifacts from a trusted artifact store.
        if "weights_only" in inspect.signature(torch.load).parameters:
            loaded = torch.load(f, weights_only=False)
        else:
            loaded = torch.load(f)
    return cast(Module, loaded)
handle_return(self, model)
Writes a PyTorch model, as a model and a checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
Module |
The model to save with torch.save; if it is a torch.nn.Module, its state_dict() is also saved as a checkpoint. |
required |
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_return(self, model: Module) -> None:
    """Writes a PyTorch model, as a model and a checkpoint.

    Args:
        model: The model to save. The whole object is written with
            ``torch.save``. If it is a ``torch.nn.Module``, its
            ``state_dict()`` is also written as a checkpoint. Other objects
            (e.g. a plain dict) are saved whole but get no checkpoint.
    """
    super().handle_return(model)
    # Save the entire model object to the artifact directory (the usual
    # form during development: training, evaluation).
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
    ) as f:
        torch.save(model, f)

    # Also save only the parameters (state dict) as a checkpoint, the form
    # typically preferred in production (inference).
    if isinstance(model, Module):
        with fileio.open(
            os.path.join(self.artifact.uri, CHECKPOINT_FILENAME), "wb"
        ) as f:
            torch.save(model.state_dict(), f)