Skip to content

Integration

zenml.integrations.integration

Base and meta classes for ZenML integrations.

Integration

Base class for integrations in ZenML.

Source code in zenml/integrations/integration.py
class Integration(metaclass=IntegrationMeta):
    """Base class for integrations in ZenML.

    Subclasses declare their package requirements through the class-level
    lists below and may override the hook classmethods. Any class created by
    `IntegrationMeta` under a name other than "Integration" is registered with
    the integration registry at class-creation time.
    """

    NAME = "base_integration"

    # Requirement strings checked by `check_installation` and returned by
    # `get_requirements`.
    REQUIREMENTS: List[str] = []
    # apt package names declared by the integration (not used in this class).
    APT_PACKAGES: List[str] = []
    # Prefixes of requirements that `get_uninstall_requirements` leaves out.
    REQUIREMENTS_IGNORED_ON_UNINSTALL: List[str] = []

    @classmethod
    def check_installation(cls) -> bool:
        """Method to check whether the required packages are installed.

        For every requirement returned by `get_requirements()`, this checks
        that the distribution is installed and satisfies its specifier. It then
        checks that each of the distribution's dependencies is installed,
        including the dependencies of any extras named in the requirement.
        Failures are logged at debug level.

        Returns:
            True if all required packages are installed, False otherwise.
        """
        for r in cls.get_requirements():
            try:
                # First check if the base package is installed
                dist = pkg_resources.get_distribution(r)

                # Next, check if the dependencies (including extras) are
                # installed
                deps: List[Requirement] = []

                _, extras = parse_requirement(r)
                # NOTE(review): `extras` is assumed to be a bracketed,
                # comma-separated string such as "[a,b]" (inferred from the
                # slicing below; confirm against `parse_requirement`).
                # Whitespace around each name is stripped and empty names are
                # dropped, so "[a, b]" yields ["a", "b"] rather than " b".
                extra_list: List[str] = (
                    [
                        name.strip()
                        for name in extras[1:-1].split(",")
                        if name.strip()
                    ]
                    if extras
                    else []
                )
                if extra_list:
                    for extra in extra_list:
                        try:
                            requirements = dist.requires(extras=[extra])  # type: ignore[arg-type]
                        except pkg_resources.UnknownExtra as e:
                            logger.debug(f"Unknown extra: {str(e)}")
                            return False
                        deps.extend(requirements)
                else:
                    deps = dist.requires()

                for ri in deps:
                    try:
                        # Remove the `; extra == "<name>"` marker from the
                        # requirement string before resolving it. Extra names
                        # may contain "-" and "." as well as word characters
                        # (PEP 508), so `\w+` alone is not enough.
                        cleaned_req = re.sub(
                            r'; extra == "[\w.-]+"', "", str(ri)
                        )
                        pkg_resources.get_distribution(cleaned_req)
                    except pkg_resources.DistributionNotFound as e:
                        logger.debug(
                            f"Unable to find required dependency "
                            f"'{e.req}' for requirement '{r}' "
                            f"necessary for integration '{cls.NAME}'."
                        )
                        return False
                    except pkg_resources.VersionConflict as e:
                        logger.debug(
                            f"Package version '{e.dist}' does not match "
                            f"version '{e.req}' required by '{r}' "
                            f"necessary for integration '{cls.NAME}'."
                        )
                        return False

            except pkg_resources.DistributionNotFound as e:
                logger.debug(
                    f"Unable to find required package '{e.req}' for "
                    f"integration {cls.NAME}."
                )
                return False
            except pkg_resources.VersionConflict as e:
                logger.debug(
                    f"Package version '{e.dist}' does not match version "
                    f"'{e.req}' necessary for integration {cls.NAME}."
                )
                return False

        logger.debug(
            f"Integration {cls.NAME} is installed correctly with "
            f"requirements {cls.get_requirements()}."
        )
        return True

    @classmethod
    def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
        """Method to get the requirements for the integration.

        Args:
            target_os: The target operating system to get the requirements for.
                Ignored by this base implementation; subclasses may use it.

        Returns:
            A list of requirements (the class's `REQUIREMENTS` list itself).
        """
        return cls.REQUIREMENTS

    @classmethod
    def get_uninstall_requirements(
        cls, target_os: Optional[str] = None
    ) -> List[str]:
        """Method to get the uninstall requirements for the integration.

        A requirement is left out when it starts with any entry of
        `REQUIREMENTS_IGNORED_ON_UNINSTALL`. The match is a plain string prefix
        match, so "foo" excludes both "foo>=1.0" and "foobar".

        Args:
            target_os: The target operating system to get the requirements for.

        Returns:
            A new list of requirements, in the order `get_requirements`
            returns them.
        """
        # `str.startswith` accepts a tuple; an empty tuple never matches.
        ignored_prefixes = tuple(cls.REQUIREMENTS_IGNORED_ON_UNINSTALL)
        return [
            requirement
            for requirement in cls.get_requirements(target_os=target_os)
            if not requirement.startswith(ignored_prefixes)
        ]

    @classmethod
    def activate(cls) -> None:
        """Abstract method to activate the integration.

        The base implementation does nothing.
        """

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Abstract method to declare new stack component flavors.

        Returns:
            A list of new stack component flavors (empty in the base class).
        """
        return []

    @classmethod
    def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]:
        """Abstract method to declare new plugin flavors.

        Returns:
            A list of new plugin flavors (empty in the base class).
        """
        return []

activate() classmethod

Abstract method to activate the integration.

Source code in zenml/integrations/integration.py
@classmethod
def activate(cls) -> None:
    """Hook invoked to activate the integration.

    The base implementation is a no-op; subclasses may override it.
    """
    return None

check_installation() classmethod

Method to check whether the required packages are installed.

Returns:

| Type | Description |
| --- | --- |
| `bool` | True if all required packages are installed, False otherwise. |

Source code in zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    For every requirement returned by `get_requirements()`, this checks that
    the distribution is installed and satisfies its specifier. It then checks
    that each of the distribution's dependencies is installed, including the
    dependencies of any extras named in the requirement. Failures are logged
    at debug level.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    for r in cls.get_requirements():
        try:
            # First check if the base package is installed
            dist = pkg_resources.get_distribution(r)

            # Next, check if the dependencies (including extras) are
            # installed
            deps: List[Requirement] = []

            _, extras = parse_requirement(r)
            # NOTE(review): `extras` is assumed to be a bracketed,
            # comma-separated string such as "[a,b]" (inferred from the
            # slicing below; confirm against `parse_requirement`).
            # Whitespace around each name is stripped and empty names are
            # dropped, so "[a, b]" yields ["a", "b"] rather than " b".
            extra_list: List[str] = (
                [
                    name.strip()
                    for name in extras[1:-1].split(",")
                    if name.strip()
                ]
                if extras
                else []
            )
            if extra_list:
                for extra in extra_list:
                    try:
                        requirements = dist.requires(extras=[extra])  # type: ignore[arg-type]
                    except pkg_resources.UnknownExtra as e:
                        logger.debug(f"Unknown extra: {str(e)}")
                        return False
                    deps.extend(requirements)
            else:
                deps = dist.requires()

            for ri in deps:
                try:
                    # Remove the `; extra == "<name>"` marker from the
                    # requirement string before resolving it. Extra names may
                    # contain "-" and "." as well as word characters
                    # (PEP 508), so `\w+` alone is not enough.
                    cleaned_req = re.sub(
                        r'; extra == "[\w.-]+"', "", str(ri)
                    )
                    pkg_resources.get_distribution(cleaned_req)
                except pkg_resources.DistributionNotFound as e:
                    logger.debug(
                        f"Unable to find required dependency "
                        f"'{e.req}' for requirement '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False
                except pkg_resources.VersionConflict as e:
                    logger.debug(
                        f"Package version '{e.dist}' does not match "
                        f"version '{e.req}' required by '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False

        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"Package version '{e.dist}' does not match version "
                f"'{e.req}' necessary for integration {cls.NAME}."
            )
            return False

    logger.debug(
        f"Integration {cls.NAME} is installed correctly with "
        f"requirements {cls.get_requirements()}."
    )
    return True

flavors() classmethod

Abstract method to declare new stack component flavors.

Returns:

| Type | Description |
| --- | --- |
| `List[Type[zenml.stack.flavor.Flavor]]` | A list of new stack component flavors. |

Source code in zenml/integrations/integration.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors contributed by the integration.

    Returns:
        A list of new stack component flavors; the base class contributes
        none, so a fresh empty list is returned.
    """
    declared: List[Type[Flavor]] = list()
    return declared

get_requirements(target_os=None) classmethod

Method to get the requirements for the integration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `target_os` | `Optional[str]` | The target operating system to get the requirements for. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `List[str]` | A list of requirements. |

Source code in zenml/integrations/integration.py
@classmethod
def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
    """Return the package requirements declared by the integration.

    Args:
        target_os: The target operating system to get the requirements for.
            This base implementation does not use it.

    Returns:
        The class's `REQUIREMENTS` list (the same list object, not a copy).
    """
    declared_requirements = cls.REQUIREMENTS
    return declared_requirements

get_uninstall_requirements(target_os=None) classmethod

Method to get the uninstall requirements for the integration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `target_os` | `Optional[str]` | The target operating system to get the requirements for. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `List[str]` | A list of requirements. |

Source code in zenml/integrations/integration.py
@classmethod
def get_uninstall_requirements(
    cls, target_os: Optional[str] = None
) -> List[str]:
    """Return the requirements that should be removed on uninstall.

    Any requirement whose string starts with an entry of
    `REQUIREMENTS_IGNORED_ON_UNINSTALL` is skipped (plain prefix match).

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A new list holding the remaining requirements in their original order.
    """
    skip_prefixes = cls.REQUIREMENTS_IGNORED_ON_UNINSTALL
    return [
        requirement
        for requirement in cls.get_requirements(target_os=target_os)
        if not any(requirement.startswith(prefix) for prefix in skip_prefixes)
    ]

plugin_flavors() classmethod

Abstract method to declare new plugin flavors.

Returns:

| Type | Description |
| --- | --- |
| `List[Type[BasePluginFlavor]]` | A list of new plugin flavors. |

Source code in zenml/integrations/integration.py
@classmethod
def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]:
    """Declare the plugin flavors contributed by the integration.

    Returns:
        A list of new plugin flavors; the base class declares none and
        returns a fresh empty list.
    """
    return list()

IntegrationMeta (type)

Metaclass responsible for registering different Integration subclasses.

Source code in zenml/integrations/integration.py
class IntegrationMeta(type):
    """Metaclass that registers Integration subclasses as they are created."""

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "IntegrationMeta":
        """Create the class and register it unless it is the base class.

        Args:
            name: The name of the class being created.
            bases: The base classes of the class being created.
            dct: The dictionary of attributes of the class being created.

        Returns:
            The newly created class.
        """
        created = super().__new__(mcs, name, bases, dct)
        new_cls = cast(Type["Integration"], created)
        if name == "Integration":
            # The base `Integration` class itself is never registered.
            return new_cls
        integration_registry.register_integration(new_cls.NAME, new_cls)
        return new_cls

__new__(mcs, name, bases, dct) special staticmethod

Hook into creation of an Integration class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | The name of the class being created. | required |
| `bases` | `Tuple[Type[Any], ...]` | The base classes of the class being created. | required |
| `dct` | `Dict[str, Any]` | The dictionary of attributes of the class being created. | required |

Returns:

| Type | Description |
| --- | --- |
| `IntegrationMeta` | The newly created class. |

Source code in zenml/integrations/integration.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "IntegrationMeta":
    """Create the class and register it unless it is the base class.

    Args:
        name: The name of the class being created.
        bases: The base classes of the class being created.
        dct: The dictionary of attributes of the class being created.

    Returns:
        The newly created class.
    """
    created = super().__new__(mcs, name, bases, dct)
    new_cls = cast(Type["Integration"], created)
    if name == "Integration":
        # The base `Integration` class itself is never registered.
        return new_cls
    integration_registry.register_integration(new_cls.NAME, new_cls)
    return new_cls