Pytorch
zenml.integrations.pytorch
special
Initialization of the PyTorch integration.
PytorchIntegration (Integration)
Definition of PyTorch integration for ZenML.
Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
    """ZenML integration definition that adds PyTorch support."""

    # Identifier of the integration and the packages it depends on.
    NAME = PYTORCH
    REQUIREMENTS = ["torch"]

    @classmethod
    def activate(cls) -> None:
        """Activate the integration by importing its materializers module."""
        import zenml.integrations.pytorch.materializers  # noqa
activate()
classmethod
Activates the integration.
Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the integration by importing its materializers module."""
    import zenml.integrations.pytorch.materializers  # noqa
materializers
special
Initialization of the PyTorch Materializer.
base_pytorch_materializer
Implementation of the base PyTorch materializer.
BasePyTorchMaterializer (BaseMaterializer)
Base class for PyTorch materializers.
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
class BasePyTorchMaterializer(BaseMaterializer):
    """Base class for PyTorch materializers.

    Objects are written with ``torch.save`` (using ``cloudpickle`` as the
    pickle module) to a single file, ``FILENAME``, inside the artifact
    directory, and read back from that file with ``torch.load``.
    """

    FILENAME: ClassVar[str] = DEFAULT_FILENAME
    # Opt this base class out of registration; concrete subclasses declare
    # their own associated types.
    SKIP_REGISTRATION: ClassVar[bool] = True

    def load(self, data_type: Type[Any]) -> Any:
        """Uses `torch.load` to load a PyTorch object.

        Args:
            data_type: The type of the object to load. Not used; the object
                is restored from the saved file as-is.

        Returns:
            The loaded PyTorch object.
        """
        with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f:
            # `save` pickles arbitrary objects (whole modules, dataloaders),
            # not only tensors/state dicts. Since torch 2.6, `weights_only`
            # defaults to True and rejects such objects, so full unpickling
            # is requested explicitly (the flag exists since torch 1.13).
            # NOTE: full unpickling can execute code; only load artifacts
            # from a trusted artifact store.
            return torch.load(f, weights_only=False)

    def save(self, obj: Any) -> None:
        """Uses `torch.save` to save a PyTorch object.

        Args:
            obj: The PyTorch object to save.
        """
        with fileio.open(os.path.join(self.uri, self.FILENAME), "wb") as f:
            torch.save(obj, f, pickle_module=cloudpickle)
load(self, data_type)
Uses `torch.load` to load a PyTorch object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the object to load. |
required |
Returns:
Type | Description |
---|---|
Any |
The loaded PyTorch object. |
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
def load(self, data_type: Type[Any]) -> Any:
    """Uses `torch.load` to load a PyTorch object.

    Args:
        data_type: The type of the object to load. Not used; the object is
            restored from the saved file as-is.

    Returns:
        The loaded PyTorch object.
    """
    with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f:
        # `save` pickles arbitrary objects (whole modules, dataloaders), not
        # only tensors/state dicts. Since torch 2.6, `weights_only` defaults
        # to True and rejects such objects, so full unpickling is requested
        # explicitly (the flag exists since torch 1.13).
        # NOTE: full unpickling can execute code; only load artifacts from a
        # trusted artifact store.
        return torch.load(f, weights_only=False)
save(self, obj)
Uses `torch.save` to save a PyTorch object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any |
The PyTorch object to save. |
required |
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
def save(self, obj: Any) -> None:
    """Serialize `obj` into the artifact directory using `torch.save`.

    Args:
        obj: The PyTorch object to save.
    """
    target_path = os.path.join(self.uri, self.FILENAME)
    # cloudpickle is passed as the pickle module -- presumably so objects
    # that plain pickle rejects (lambdas, local classes) can be stored.
    with fileio.open(target_path, "wb") as handle:
        torch.save(obj, handle, pickle_module=cloudpickle)
BasePyTorchMaterliazer (BaseMaterializer)
Misspelled name for `BasePyTorchMaterializer`; the source below is that same class. Base class for PyTorch materializers.
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
class BasePyTorchMaterializer(BaseMaterializer):
    """Base class for PyTorch materializers.

    Objects are written with ``torch.save`` (using ``cloudpickle`` as the
    pickle module) to a single file, ``FILENAME``, inside the artifact
    directory, and read back from that file with ``torch.load``.
    """

    FILENAME: ClassVar[str] = DEFAULT_FILENAME
    # Opt this base class out of registration; concrete subclasses declare
    # their own associated types.
    SKIP_REGISTRATION: ClassVar[bool] = True

    def load(self, data_type: Type[Any]) -> Any:
        """Uses `torch.load` to load a PyTorch object.

        Args:
            data_type: The type of the object to load. Not used; the object
                is restored from the saved file as-is.

        Returns:
            The loaded PyTorch object.
        """
        with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f:
            # `save` pickles arbitrary objects (whole modules, dataloaders),
            # not only tensors/state dicts. Since torch 2.6, `weights_only`
            # defaults to True and rejects such objects, so full unpickling
            # is requested explicitly (the flag exists since torch 1.13).
            # NOTE: full unpickling can execute code; only load artifacts
            # from a trusted artifact store.
            return torch.load(f, weights_only=False)

    def save(self, obj: Any) -> None:
        """Uses `torch.save` to save a PyTorch object.

        Args:
            obj: The PyTorch object to save.
        """
        with fileio.open(os.path.join(self.uri, self.FILENAME), "wb") as f:
            torch.save(obj, f, pickle_module=cloudpickle)
load(self, data_type)
Uses `torch.load` to load a PyTorch object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_type |
Type[Any] |
The type of the object to load. |
required |
Returns:
Type | Description |
---|---|
Any |
The loaded PyTorch object. |
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
def load(self, data_type: Type[Any]) -> Any:
    """Uses `torch.load` to load a PyTorch object.

    Args:
        data_type: The type of the object to load. Not used; the object is
            restored from the saved file as-is.

    Returns:
        The loaded PyTorch object.
    """
    with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f:
        # `save` pickles arbitrary objects (whole modules, dataloaders), not
        # only tensors/state dicts. Since torch 2.6, `weights_only` defaults
        # to True and rejects such objects, so full unpickling is requested
        # explicitly (the flag exists since torch 1.13).
        # NOTE: full unpickling can execute code; only load artifacts from a
        # trusted artifact store.
        return torch.load(f, weights_only=False)
save(self, obj)
Uses `torch.save` to save a PyTorch object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any |
The PyTorch object to save. |
required |
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
def save(self, obj: Any) -> None:
    """Serialize `obj` into the artifact directory using `torch.save`.

    Args:
        obj: The PyTorch object to save.
    """
    target_path = os.path.join(self.uri, self.FILENAME)
    # cloudpickle is passed as the pickle module -- presumably so objects
    # that plain pickle rejects (lambdas, local classes) can be stored.
    with fileio.open(target_path, "wb") as handle:
        torch.save(obj, handle, pickle_module=cloudpickle)
pytorch_dataloader_materializer
Implementation of the PyTorch DataLoader materializer.
PyTorchDataLoaderMaterializer (BasePyTorchMaterializer)
Materializer to read/write PyTorch dataloaders and datasets.
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
class PyTorchDataLoaderMaterializer(BasePyTorchMaterializer):
    """Materializer to read/write PyTorch dataloaders and datasets."""

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (DataLoader, Dataset)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
    FILENAME: ClassVar[str] = DEFAULT_FILENAME

    def extract_metadata(self, dataloader: Any) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given dataloader or dataset.

        For a `DataLoader`, records `num_samples` (if the wrapped dataset has
        a length), `batch_size` (if set) and `num_batches` (if the loader's
        length is defined). For a `Dataset`, records `num_samples` if it has
        a length. Values that cannot be determined are omitted.

        Args:
            dataloader: The dataloader or dataset to extract metadata from.

        Returns:
            The extracted metadata as a dictionary.
        """
        metadata: Dict[str, "MetadataType"] = {}
        if isinstance(dataloader, DataLoader):
            if hasattr(dataloader.dataset, "__len__"):
                metadata["num_samples"] = len(dataloader.dataset)
            if dataloader.batch_size:
                metadata["batch_size"] = dataloader.batch_size
                # `len(DataLoader)` raises TypeError when the dataset (e.g. an
                # IterableDataset without `__len__`) or the sampler has no
                # length. The batch count is then unknown, so omit it instead
                # of failing the whole metadata extraction.
                try:
                    metadata["num_batches"] = len(dataloader)
                except TypeError:
                    pass
        elif isinstance(dataloader, Dataset):
            if hasattr(dataloader, "__len__"):
                metadata["num_samples"] = len(dataloader)
        return metadata
extract_metadata(self, dataloader)
Extract metadata from the given dataloader or dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataloader |
Any |
The dataloader or dataset to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The extracted metadata as a dictionary. |
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def extract_metadata(self, dataloader: Any) -> Dict[str, "MetadataType"]:
    """Extract metadata from the given dataloader or dataset.

    For a `DataLoader`, records `num_samples` (if the wrapped dataset has a
    length), `batch_size` (if set) and `num_batches` (if the loader's length
    is defined). For a `Dataset`, records `num_samples` if it has a length.
    Values that cannot be determined are omitted.

    Args:
        dataloader: The dataloader or dataset to extract metadata from.

    Returns:
        The extracted metadata as a dictionary.
    """
    metadata: Dict[str, "MetadataType"] = {}
    if isinstance(dataloader, DataLoader):
        if hasattr(dataloader.dataset, "__len__"):
            metadata["num_samples"] = len(dataloader.dataset)
        if dataloader.batch_size:
            metadata["batch_size"] = dataloader.batch_size
            # `len(DataLoader)` raises TypeError when the dataset (e.g. an
            # IterableDataset without `__len__`) or the sampler has no
            # length. The batch count is then unknown, so omit it instead of
            # failing the whole metadata extraction.
            try:
                metadata["num_batches"] = len(dataloader)
            except TypeError:
                pass
    elif isinstance(dataloader, Dataset):
        if hasattr(dataloader, "__len__"):
            metadata["num_samples"] = len(dataloader)
    return metadata
pytorch_module_materializer
Implementation of the PyTorch Module materializer.
PyTorchModuleMaterializer (BasePyTorchMaterializer)
Materializer to read/write Pytorch models.
Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
class PyTorchModuleMaterializer(BasePyTorchMaterializer):
    """Materializer that reads and writes PyTorch models (`torch.nn.Module`).

    Follows the approach in the guide:
    https://pytorch.org/tutorials/beginner/saving_loading_models.html
    """

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Module,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL
    FILENAME: ClassVar[str] = DEFAULT_FILENAME

    def save(self, model: Module) -> None:
        """Persist the whole model and, for modules, a `state_dict` checkpoint.

        Args:
            model: The `torch.nn.Module` (or other object) to save; only
                `torch.nn.Module` instances also get a checkpoint file.
        """
        # Step 1: pickle the entire object via the base class, the form used
        # when loading the model during development (training, evaluation).
        super().save(model)

        # Step 2: modules additionally get a `state_dict` checkpoint, the
        # form used when loading the model for production inference.
        if not isinstance(model, Module):
            return
        checkpoint_path = os.path.join(self.uri, CHECKPOINT_FILENAME)
        with fileio.open(checkpoint_path, "wb") as checkpoint_file:
            torch.save(
                model.state_dict(), checkpoint_file, pickle_module=cloudpickle
            )

    def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]:
        """Collect parameter-count metadata for a `torch.nn.Module`.

        Args:
            model: The module to extract metadata from.

        Returns:
            A new dictionary with the module's parameter counts.
        """
        param_counts = count_module_params(model)
        return dict(param_counts)
extract_metadata(self, model)
Extract metadata from the given model (`torch.nn.Module`) object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
torch.nn.Module |
The model (`torch.nn.Module`) object to extract metadata from. |
required |
Returns:
Type | Description |
---|---|
Dict[str, MetadataType] |
The extracted metadata as a dictionary. |
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]:
    """Collect parameter-count metadata for a `torch.nn.Module`.

    Args:
        model: The module to extract metadata from.

    Returns:
        A new dictionary with the module's parameter counts.
    """
    param_counts = count_module_params(model)
    return dict(param_counts)
save(self, model)
Writes a PyTorch model, as a model and a checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
torch.nn.Module |
The `torch.nn.Module` (or other object) to save with `torch.save`. |
required |
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def save(self, model: Module) -> None:
    """Persist the whole model and, for modules, a `state_dict` checkpoint.

    Args:
        model: The `torch.nn.Module` (or other object) to save; only
            `torch.nn.Module` instances also get a checkpoint file.
    """
    # Step 1: pickle the entire object via the base class, the form used when
    # loading the model during development (training, evaluation).
    super().save(model)

    # Step 2: modules additionally get a `state_dict` checkpoint, the form
    # used when loading the model for production inference.
    if not isinstance(model, Module):
        return
    checkpoint_path = os.path.join(self.uri, CHECKPOINT_FILENAME)
    with fileio.open(checkpoint_path, "wb") as checkpoint_file:
        torch.save(model.state_dict(), checkpoint_file, pickle_module=cloudpickle)
utils
PyTorch utils.
count_module_params(module)
Get the total and trainable parameters of a module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module |
torch.nn.Module |
The module to get the parameters of. |
required |
Returns:
Type | Description |
---|---|
Dict[str, int] |
A dictionary with the total and trainable parameters. |
Source code in zenml/integrations/pytorch/utils.py
def count_module_params(module: torch.nn.Module) -> Dict[str, int]:
    """Get the total and trainable parameters of a module.

    Args:
        module: The module to get the parameters of.

    Returns:
        A dictionary with the total (`num_params`) and trainable
        (`num_trainable_params`) parameter counts. A parameter counts as
        trainable when its `requires_grad` flag is set.
    """
    total_params = 0
    trainable_params = 0
    # Single pass over the parameters, with no intermediate lists.
    for param in module.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    return {
        "num_params": total_params,
        "num_trainable_params": trainable_params,
    }