PyTorch
        zenml.integrations.pytorch
  
      special
  
    Initialization of the PyTorch integration.
        
PytorchIntegration            (Integration)
        
    Definition of PyTorch integration for ZenML.
Source code in zenml/integrations/pytorch/__init__.py
          class PytorchIntegration(Integration):
              """Definition of the PyTorch integration for ZenML.

              Declares the integration's name and pip requirements; activating it
              imports the PyTorch materializers package.
              """

              NAME = PYTORCH
              REQUIREMENTS = ["torch"]

              @classmethod
              def activate(cls) -> None:
                  """Activate the integration by importing its materializers.

                  The import is done for its side effects only (presumably
                  registering the materializers); the bound name is unused.
                  """
                  import zenml.integrations.pytorch.materializers  # noqa
activate()
  
      classmethod
  
    Activates the integration.
Source code in zenml/integrations/pytorch/__init__.py
          @classmethod
          def activate(cls) -> None:
              """Activate the integration by importing its materializers.

              The imported module is never referenced; the import happens only for
              its side effects (presumably registering the materializers).
              """
              import zenml.integrations.pytorch.materializers  # noqa
        materializers
  
      special
  
    Initialization of the PyTorch Materializer.
        base_pytorch_materializer
    Implementation of the base PyTorch materializer.
        
BasePyTorchMaterializer            (BaseMaterializer)
        
    Base class for PyTorch materializers.
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
          class BasePyTorchMaterializer(BaseMaterializer):
              """Base class for PyTorch materializers.

              Objects are written with `torch.save` (using `cloudpickle` as the pickle
              module) to a single file, `FILENAME`, inside the artifact URI, and are
              read back with `torch.load`.
              """

              FILENAME: ClassVar[str] = DEFAULT_FILENAME
              SKIP_REGISTRATION: ClassVar[bool] = True

              def load(self, data_type: Type[Any]) -> Any:
                  """Uses `torch.load` to load a PyTorch object.

                  Args:
                      data_type: The type of the object to load (unused; the stored
                          object is returned as-is).

                  Returns:
                      The loaded PyTorch object.
                  """
                  with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f:
                      # The file holds a fully pickled object (e.g. a module or a
                      # dataloader), not just tensors. Since PyTorch 2.6 `torch.load`
                      # defaults to `weights_only=True`, which rejects such objects,
                      # so full unpickling is requested explicitly.
                      # NOTE(review): `weights_only` needs PyTorch >= 1.13, and full
                      # unpickling can execute arbitrary code -- only load artifacts
                      # from a trusted artifact store.
                      return torch.load(f, weights_only=False)

              def save(self, obj: Any) -> None:
                  """Uses `torch.save` to save a PyTorch object.

                  Args:
                      obj: The PyTorch object to save.
                  """
                  with fileio.open(os.path.join(self.uri, self.FILENAME), "wb") as f:
                      # `cloudpickle` is used as the pickle module for serialization.
                      torch.save(obj, f, pickle_module=cloudpickle)
load(self, data_type)
    Uses torch.load to load a PyTorch object.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| data_type | Type[Any] | The type of the object to load. | required | 
Returns:
| Type | Description | 
|---|---|
| Any | The loaded PyTorch object. | 
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
          def load(self, data_type: Type[Any]) -> Any:
              """Uses `torch.load` to load a PyTorch object.

              Args:
                  data_type: The type of the object to load (unused; the stored
                      object is returned as-is).

              Returns:
                  The loaded PyTorch object.
              """
              with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f:
                  # The file holds a fully pickled object (e.g. a module or a
                  # dataloader), not just tensors. Since PyTorch 2.6 `torch.load`
                  # defaults to `weights_only=True`, which rejects such objects,
                  # so full unpickling is requested explicitly.
                  # NOTE(review): `weights_only` needs PyTorch >= 1.13, and full
                  # unpickling can execute arbitrary code -- only load artifacts
                  # from a trusted artifact store.
                  return torch.load(f, weights_only=False)
save(self, obj)
    Uses torch.save to save a PyTorch object.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| obj | Any | The PyTorch object to save. | required | 
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
          def save(self, obj: Any) -> None:
              """Serialize a PyTorch object with `torch.save`.

              The object is written to `FILENAME` inside the artifact URI, using
              `cloudpickle` as the pickle module.

              Args:
                  obj: The PyTorch object to save.
              """
              target_path = os.path.join(self.uri, self.FILENAME)
              with fileio.open(target_path, "wb") as handle:
                  torch.save(obj, handle, pickle_module=cloudpickle)
        
BasePyTorchMaterliazer            (BaseMaterializer)
        
    Base class for PyTorch materializers.
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
          class BasePyTorchMaterializer(BaseMaterializer):
              """Base class for PyTorch materializers.

              Objects are written with `torch.save` (using `cloudpickle` as the pickle
              module) to a single file, `FILENAME`, inside the artifact URI, and are
              read back with `torch.load`.
              """

              FILENAME: ClassVar[str] = DEFAULT_FILENAME
              SKIP_REGISTRATION: ClassVar[bool] = True

              def load(self, data_type: Type[Any]) -> Any:
                  """Uses `torch.load` to load a PyTorch object.

                  Args:
                      data_type: The type of the object to load (unused; the stored
                          object is returned as-is).

                  Returns:
                      The loaded PyTorch object.
                  """
                  with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f:
                      # The file holds a fully pickled object (e.g. a module or a
                      # dataloader), not just tensors. Since PyTorch 2.6 `torch.load`
                      # defaults to `weights_only=True`, which rejects such objects,
                      # so full unpickling is requested explicitly.
                      # NOTE(review): `weights_only` needs PyTorch >= 1.13, and full
                      # unpickling can execute arbitrary code -- only load artifacts
                      # from a trusted artifact store.
                      return torch.load(f, weights_only=False)

              def save(self, obj: Any) -> None:
                  """Uses `torch.save` to save a PyTorch object.

                  Args:
                      obj: The PyTorch object to save.
                  """
                  with fileio.open(os.path.join(self.uri, self.FILENAME), "wb") as f:
                      # `cloudpickle` is used as the pickle module for serialization.
                      torch.save(obj, f, pickle_module=cloudpickle)
load(self, data_type)
    Uses torch.load to load a PyTorch object.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| data_type | Type[Any] | The type of the object to load. | required | 
Returns:
| Type | Description | 
|---|---|
| Any | The loaded PyTorch object. | 
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
          def load(self, data_type: Type[Any]) -> Any:
              """Uses `torch.load` to load a PyTorch object.

              Args:
                  data_type: The type of the object to load (unused; the stored
                      object is returned as-is).

              Returns:
                  The loaded PyTorch object.
              """
              with fileio.open(os.path.join(self.uri, self.FILENAME), "rb") as f:
                  # The file holds a fully pickled object (e.g. a module or a
                  # dataloader), not just tensors. Since PyTorch 2.6 `torch.load`
                  # defaults to `weights_only=True`, which rejects such objects,
                  # so full unpickling is requested explicitly.
                  # NOTE(review): `weights_only` needs PyTorch >= 1.13, and full
                  # unpickling can execute arbitrary code -- only load artifacts
                  # from a trusted artifact store.
                  return torch.load(f, weights_only=False)
save(self, obj)
    Uses torch.save to save a PyTorch object.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| obj | Any | The PyTorch object to save. | required | 
Source code in zenml/integrations/pytorch/materializers/base_pytorch_materializer.py
          def save(self, obj: Any) -> None:
              """Serialize a PyTorch object with `torch.save`.

              The object is written to `FILENAME` inside the artifact URI, using
              `cloudpickle` as the pickle module.

              Args:
                  obj: The PyTorch object to save.
              """
              target_path = os.path.join(self.uri, self.FILENAME)
              with fileio.open(target_path, "wb") as handle:
                  torch.save(obj, handle, pickle_module=cloudpickle)
        pytorch_dataloader_materializer
    Implementation of the PyTorch DataLoader materializer.
        
PyTorchDataLoaderMaterializer            (BasePyTorchMaterializer)
        
    Materializer to read/write PyTorch dataloaders and datasets.
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
          class PyTorchDataLoaderMaterializer(BasePyTorchMaterializer):
              """Materializer to read/write PyTorch dataloaders and datasets."""

              ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (DataLoader, Dataset)
              ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
              FILENAME: ClassVar[str] = DEFAULT_FILENAME

              def extract_metadata(self, dataloader: Any) -> Dict[str, "MetadataType"]:
                  """Extract metadata from the given dataloader or dataset.

                  Records the number of samples (when the dataset has a length), the
                  batch size (when set), and the number of batches (when the
                  dataloader's length can be computed).

                  Args:
                      dataloader: The dataloader or dataset to extract metadata from.

                  Returns:
                      The extracted metadata as a dictionary. Objects that are neither
                      a `DataLoader` nor a `Dataset` yield an empty dictionary.
                  """
                  metadata: Dict[str, "MetadataType"] = {}
                  if isinstance(dataloader, DataLoader):
                      if hasattr(dataloader.dataset, "__len__"):
                          metadata["num_samples"] = len(dataloader.dataset)
                      if dataloader.batch_size:
                          metadata["batch_size"] = dataloader.batch_size
                      try:
                          # `len(dataloader)` raises TypeError when the underlying
                          # dataset has no length (e.g. iterable-style datasets);
                          # metadata extraction is best-effort, so skip the key.
                          metadata["num_batches"] = len(dataloader)
                      except TypeError:
                          pass
                  elif isinstance(dataloader, Dataset):
                      if hasattr(dataloader, "__len__"):
                          metadata["num_samples"] = len(dataloader)
                  return metadata
extract_metadata(self, dataloader)
    Extract metadata from the given dataloader or dataset.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| dataloader | Any | The dataloader or dataset to extract metadata from. | required | 
Returns:
| Type | Description | 
|---|---|
| Dict[str, MetadataType] | The extracted metadata as a dictionary. | 
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
          def extract_metadata(self, dataloader: Any) -> Dict[str, "MetadataType"]:
              """Extract metadata from the given dataloader or dataset.

              Records the number of samples (when the dataset has a length), the
              batch size (when set), and the number of batches (when the
              dataloader's length can be computed).

              Args:
                  dataloader: The dataloader or dataset to extract metadata from.

              Returns:
                  The extracted metadata as a dictionary. Objects that are neither
                  a `DataLoader` nor a `Dataset` yield an empty dictionary.
              """
              metadata: Dict[str, "MetadataType"] = {}
              if isinstance(dataloader, DataLoader):
                  if hasattr(dataloader.dataset, "__len__"):
                      metadata["num_samples"] = len(dataloader.dataset)
                  if dataloader.batch_size:
                      metadata["batch_size"] = dataloader.batch_size
                  try:
                      # `len(dataloader)` raises TypeError when the underlying
                      # dataset has no length (e.g. iterable-style datasets);
                      # metadata extraction is best-effort, so skip the key.
                      metadata["num_batches"] = len(dataloader)
                  except TypeError:
                      pass
              elif isinstance(dataloader, Dataset):
                  if hasattr(dataloader, "__len__"):
                      metadata["num_samples"] = len(dataloader)
              return metadata
        pytorch_module_materializer
    Implementation of the PyTorch Module materializer.
        
PyTorchModuleMaterializer            (BasePyTorchMaterializer)
        
    Materializer to read/write PyTorch models.
Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
          class PyTorchModuleMaterializer(BasePyTorchMaterializer):
              """Materializer to read/write PyTorch models.

              Inspired by the guide:
              https://pytorch.org/tutorials/beginner/saving_loading_models.html
              """

              ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Module,)
              ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL
              FILENAME: ClassVar[str] = DEFAULT_FILENAME

              def save(self, model: Module) -> None:
                  """Writes a PyTorch model, as a model and a checkpoint.

                  The whole object is saved via the base class. For
                  `torch.nn.Module` instances, the `state_dict()` is additionally
                  saved to `CHECKPOINT_FILENAME` inside the artifact URI.

                  Args:
                      model: The object to save, normally a `torch.nn.Module`; it is
                          passed to `torch.save`.
                  """
                  # Save the entire model to the artifact directory; this is the usual
                  # way to reload a model during development (training, evaluation).
                  super().save(model)
                  # Also save a state-dict checkpoint to the artifact directory; this
                  # is the usual way to reload a model for inference in production.
                  if isinstance(model, Module):
                      checkpoint_path = os.path.join(self.uri, CHECKPOINT_FILENAME)
                      with fileio.open(checkpoint_path, "wb") as f:
                          torch.save(model.state_dict(), f, pickle_module=cloudpickle)

              def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]:
                  """Extract metadata from the given `Model` object.

                  Args:
                      model: The `Model` object to extract metadata from.

                  Returns:
                      The extracted metadata as a dictionary (parameter counts).
                  """
                  # Build a new dict rather than returning the helper's result
                  # directly; presumably this lets type checkers accept the
                  # Dict[str, int] result as Dict[str, MetadataType] (Dict is invariant).
                  return {**count_module_params(model)}
extract_metadata(self, model)
    Extract metadata from the given Model object.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| model | torch.nn.Module | The Model object to extract metadata from. | required | 
Returns:
| Type | Description | 
|---|---|
| Dict[str, MetadataType] | The extracted metadata as a dictionary. | 
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
          def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]:
              """Collect parameter-count metadata for a PyTorch model.

              Args:
                  model: The `Model` object to extract metadata from.

              Returns:
                  A fresh dictionary with the model's parameter counts.
              """
              param_counts = count_module_params(model)
              # Copy into a new dict typed as Dict[str, MetadataType].
              return {key: value for key, value in param_counts.items()}
save(self, model)
    Writes a PyTorch model, as a model and a checkpoint.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| model | torch.nn.Module | The object to save, normally a torch.nn.Module; it is passed to torch.save. | required | 
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
          def save(self, model: Module) -> None:
              """Writes a PyTorch model, as a model and a checkpoint.

              The whole object is saved via the base class. For `torch.nn.Module`
              instances, the `state_dict()` is additionally saved to
              `CHECKPOINT_FILENAME` inside the artifact URI.

              Args:
                  model: The object to save, normally a `torch.nn.Module`; it is
                      passed to `torch.save`.
              """
              # Save the entire model to the artifact directory; this is the usual
              # way to reload a model during development (training, evaluation).
              super().save(model)
              # Also save a state-dict checkpoint to the artifact directory; this is
              # the usual way to reload a model for inference in production.
              if isinstance(model, Module):
                  checkpoint_path = os.path.join(self.uri, CHECKPOINT_FILENAME)
                  with fileio.open(checkpoint_path, "wb") as f:
                      torch.save(model.state_dict(), f, pickle_module=cloudpickle)
        utils
    PyTorch utils.
count_module_params(module)
    Get the total and trainable parameters of a module.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| module | torch.nn.Module | The module to get the parameters of. | required | 
Returns:
| Type | Description | 
|---|---|
| Dict[str, int] | A dictionary with the total and trainable parameters. | 
Source code in zenml/integrations/pytorch/utils.py
          def count_module_params(module: torch.nn.Module) -> Dict[str, int]:
              """Get the total and trainable parameters of a module.

              Both counts are numbers of scalar elements (`numel()`), summed over the
              parameters yielded by `module.parameters()`.

              Args:
                  module: The module to get the parameters of.

              Returns:
                  A dictionary with the total (`num_params`) and trainable
                  (`num_trainable_params`, i.e. `requires_grad=True`) parameter counts.
              """
              total_params = 0
              trainable_params = 0
              # Single pass over the parameters instead of building two throwaway
              # lists and iterating `module.parameters()` twice.
              for param in module.parameters():
                  count = param.numel()
                  total_params += count
                  if param.requires_grad:
                      trainable_params += count
              return {
                  "num_params": total_params,
                  "num_trainable_params": trainable_params,
              }