Skip to content

OpenAI

zenml.integrations.openai

Initialization of the OpenAI integration.

Attributes

OPEN_AI = 'openai' module-attribute

Classes

Integration

Base class for integration in ZenML.

Functions
activate() -> None classmethod

Abstract method to activate the integration.

Source code in src/zenml/integrations/integration.py
175
176
177
@classmethod
def activate(cls) -> None:
    """Activate the integration.

    The base implementation is a deliberate no-op: it performs no work and
    returns ``None``.
    """
    # Nothing to activate at the base level.
    return None
check_installation() -> bool classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in src/zenml/integrations/integration.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    Each requirement returned by `cls.get_requirements()` is resolved with
    `pkg_resources.get_distribution`. The dependencies of the resolved
    distribution are then resolved the same way: when the requirement
    string carries extras, the dependencies are collected per extra via
    `dist.requires(extras=[...])`; otherwise `dist.requires()` is used.

    The check stops at the first problem. A missing distribution, a
    version conflict or an unknown extra is logged at debug level and
    makes the method return False.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    for r in cls.get_requirements():
        try:
            # First check if the base package is installed
            dist = pkg_resources.get_distribution(r)

            # Next, check if the dependencies (including extras) are
            # installed
            deps: List[Requirement] = []

            # Only the extras part of the parsed requirement is needed here.
            _, extras = parse_requirement(r)
            if extras:
                # The slice drops the first and last character, so `extras`
                # is presumably the bracketed form, e.g. "[a,b]" -- TODO
                # confirm against `parse_requirement`.
                # NOTE(review): entries are not whitespace-stripped, so
                # "[a, b]" would yield " b" -- confirm this cannot occur.
                extra_list = extras[1:-1].split(",")
                for extra in extra_list:
                    try:
                        requirements = dist.requires(extras=[extra])  # type: ignore[arg-type]
                    except pkg_resources.UnknownExtra as e:
                        # The installed distribution does not declare this
                        # extra, so the requirement is treated as unmet.
                        logger.debug(f"Unknown extra: {str(e)}")
                        return False
                    deps.extend(requirements)
            else:
                deps = dist.requires()

            for ri in deps:
                try:
                    # Remove the "extra == ..." part from the requirement string
                    # NOTE(review): `\w+` matches word characters only, so
                    # an extra name containing "-" or "." is left in place
                    # -- confirm whether such names need handling.
                    cleaned_req = re.sub(
                        r"; extra == \"\w+\"", "", str(ri)
                    )
                    pkg_resources.get_distribution(cleaned_req)
                except pkg_resources.DistributionNotFound as e:
                    logger.debug(
                        f"Unable to find required dependency "
                        f"'{e.req}' for requirement '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False
                except pkg_resources.VersionConflict as e:
                    logger.debug(
                        f"Package version '{e.dist}' does not match "
                        f"version '{e.req}' required by '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False

        # These handlers cover the lookup of the base package `r` itself;
        # failures of its dependencies are handled in the inner loop above.
        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"Package version '{e.dist}' does not match version "
                f"'{e.req}' necessary for integration {cls.NAME}."
            )
            return False

    # Reached only when every requirement and dependency was resolved.
    logger.debug(
        f"Integration {cls.NAME} is installed correctly with "
        f"requirements {cls.get_requirements()}."
    )
    return True
flavors() -> List[Type[Flavor]] classmethod

Abstract method to declare new stack component flavors.

Returns:

Type Description
List[Type[Flavor]]

A list of new stack component flavors.

Source code in src/zenml/integrations/integration.py
179
180
181
182
183
184
185
186
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors contributed by the integration.

    The base implementation contributes none.

    Returns:
        A list of new stack component flavors; a fresh empty list here.
    """
    declared_flavors: List[Type[Flavor]] = []
    return declared_flavors
get_requirements(target_os: Optional[str] = None, python_version: Optional[str] = None) -> List[str] classmethod

Method to get the requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None
python_version Optional[str]

The Python version to use for the requirements.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@classmethod
def get_requirements(
    cls,
    target_os: Optional[str] = None,
    python_version: Optional[str] = None,
) -> List[str]:
    """Return the requirements declared for the integration.

    This implementation does not inspect either argument; it hands back the
    class-level `REQUIREMENTS` attribute as-is (the same object, not a
    copy).

    Args:
        target_os: The target operating system to get the requirements for.
        python_version: The Python version to use for the requirements.

    Returns:
        A list of requirements.
    """
    declared: List[str] = cls.REQUIREMENTS
    return declared
get_uninstall_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Method to get the uninstall requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
@classmethod
def get_uninstall_requirements(
    cls, target_os: Optional[str] = None
) -> List[str]:
    """Return the requirements that should be removed on uninstall.

    Starts from `cls.get_requirements(target_os=target_os)` and drops every
    requirement string that starts with one of the entries in
    `cls.REQUIREMENTS_IGNORED_ON_UNINSTALL`. The original order is kept.

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A list of requirements.
    """
    return [
        requirement
        for requirement in cls.get_requirements(target_os=target_os)
        # Keep the requirement unless some ignored prefix matches it.
        if not any(
            requirement.startswith(prefix)
            for prefix in cls.REQUIREMENTS_IGNORED_ON_UNINSTALL
        )
    ]
plugin_flavors() -> List[Type[BasePluginFlavor]] classmethod

Abstract method to declare new plugin flavors.

Returns:

Type Description
List[Type[BasePluginFlavor]]

A list of new plugin flavors.

Source code in src/zenml/integrations/integration.py
188
189
190
191
192
193
194
195
@classmethod
def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]:
    """Declare the plugin flavors contributed by the integration.

    The base implementation contributes none.

    Returns:
        A list of new plugin flavors; a fresh empty list here.
    """
    declared_plugins: List[Type["BasePluginFlavor"]] = []
    return declared_plugins

OpenAIIntegration

Bases: Integration

Definition of OpenAI integration for ZenML.

Modules

hooks

Initialization of the OpenAI hooks module.

Functions
Modules
open_ai_failure_hook

Functionality for OpenAI standard hooks.

Classes Functions
openai_alerter_failure_hook_helper(exception: BaseException, model_name: str) -> None

Standard failure hook that sends a message to an Alerter.

Your OpenAI API key must be stored in the secret store under the name "openai" and with the key "api_key".

Parameters:

Name Type Description Default
exception BaseException

The exception that was raised.

required
model_name str

The OpenAI model to use for the chatbot.

required

This implementation uses the OpenAI v1 SDK with automatic retries and backoff.

Source code in src/zenml/integrations/openai/hooks/open_ai_failure_hook.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def openai_alerter_failure_hook_helper(
    exception: BaseException,
    model_name: str,
) -> None:
    """Standard failure hook that sends a message to an Alerter.

    Asks an OpenAI chat model for suggestions on how to fix the failure and
    posts them, together with details of the failed step, through the
    alerter of the active stack. If no API key or no alerter is available,
    a warning is logged and nothing is sent.

    Your OpenAI API key must be stored in the secret store under the name
    "openai" and with the key "api_key".

    This implementation uses the OpenAI v1 SDK with automatic retries and
    backoff.

    Args:
        exception: The exception that was raised.
        model_name: The OpenAI model to use for the chatbot.
    """
    client = Client()
    context = get_step_context()

    # get the api_key from the secret store
    try:
        openai_secret = client.get_secret(
            "openai", allow_partial_name_match=False
        )
        openai_api_key: Optional[str] = openai_secret.secret_values.get(
            "api_key"
        )
    except (KeyError, NotImplementedError):
        openai_api_key = None

    # A missing key is reported first, whether or not an alerter exists.
    if not openai_api_key:
        logger.warning(
            "Specified OpenAI failure hook but no OpenAI API key found. Skipping..."
        )
        return

    alerter = client.active_stack.alerter
    if not alerter:
        logger.warning(
            "Specified OpenAI failure hook but no alerter configured in the stack. Skipping..."
        )
        return

    # Capture the rich traceback by pointing the console at an in-memory
    # buffer. Swapping `sys.stdout` instead would leave the whole process
    # redirected if rendering raised, and is not safe with other threads.
    output_captured = io.StringIO()
    Console(file=output_captured).print_exception(show_locals=False)
    rich_traceback = output_captured.getvalue()

    # Initialize OpenAI client with timeout and retry settings
    openai_client = OpenAI(
        api_key=openai_api_key,
        max_retries=3,  # Will retry 3 times with exponential backoff
        timeout=60.0,  # 60 second timeout
    )

    # Create chat completion using the new client pattern
    response = openai_client.chat.completions.create(
        model=model_name,
        messages=[
            {
                "role": "user",
                "content": f"This is an error message (following an exception of type '{type(exception)}') "
                f"I encountered while executing a ZenML step. Please suggest ways I might fix the problem. "
                f"Feel free to give code snippets as examples, and note that your response will be piped "
                f"to a Slack bot so make sure the formatting is appropriate: {exception} -- {rich_traceback}. "
                f"Thank you!",
            }
        ],
    )

    suggestion = response.choices[0].message.content

    # Format the alert message
    message = "\n".join(
        [
            "*Failure Hook Notification! Step failed!*",
            "",
            f"Run name: `{context.pipeline_run.name}`",
            f"Step name: `{context.step_run.name}`",
            f"Parameters: `{context.step_run.config.parameters}`",
            f"Exception: `({type(exception)}) {exception}`",
            "",
            f"*OpenAI ChatGPT's suggestion (model = `{model_name}`) on how to fix it:*\n `{suggestion}`",
        ]
    )

    alerter.post(message)
openai_chatgpt_alerter_failure_hook(exception: BaseException) -> None

Alerter hook that uses the OpenAI ChatGPT model (`gpt-3.5-turbo`).

Parameters:

Name Type Description Default
exception BaseException

The exception that was raised.

required
Source code in src/zenml/integrations/openai/hooks/open_ai_failure_hook.py
120
121
122
123
124
125
126
127
128
def openai_chatgpt_alerter_failure_hook(
    exception: BaseException,
) -> None:
    """Failure hook that asks the OpenAI ChatGPT model for a suggestion.

    Delegates to `openai_alerter_failure_hook_helper` with the model name
    fixed to "gpt-3.5-turbo".

    Args:
        exception: The exception that was raised.
    """
    openai_alerter_failure_hook_helper(
        exception=exception,
        model_name="gpt-3.5-turbo",
    )
openai_gpt4_alerter_failure_hook(exception: BaseException) -> None

Alerter hook that uses the OpenAI GPT-4o model (`gpt-4o`).

Parameters:

Name Type Description Default
exception BaseException

The exception that was raised.

required
Source code in src/zenml/integrations/openai/hooks/open_ai_failure_hook.py
131
132
133
134
135
136
137
138
139
def openai_gpt4_alerter_failure_hook(
    exception: BaseException,
) -> None:
    """Failure hook that asks the OpenAI GPT-4o model for a suggestion.

    Delegates to `openai_alerter_failure_hook_helper` with the model name
    fixed to "gpt-4o".

    Args:
        exception: The exception that was raised.
    """
    openai_alerter_failure_hook_helper(
        exception=exception,
        model_name="gpt-4o",
    )