Skip to content

Hooks

zenml.hooks special

The hooks package exposes some standard hooks that can be used in ZenML.

Hooks are functions that run after a step has exited.

alerter_hooks

Functionality for standard hooks.

alerter_failure_hook(context, params, exception)

Standard failure hook that executes after a step fails.

This hook uses any BaseAlerter that is configured within the active stack to post a message.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| context | StepContext | Context of the step. | required |
| params | BaseParameters | Parameters used in the step. | required |
| exception | BaseException | Original exception that led to the step failing. | required |
Source code in zenml/hooks/alerter_hooks.py
def alerter_failure_hook(
    context: StepContext, params: BaseParameters, exception: BaseException
) -> None:
    """Standard failure hook that executes after a step fails.

    This hook uses any `BaseAlerter` that is configured within the active
    stack to post a message describing the failed step. If the active stack
    has no alerter, a warning is logged and nothing is posted.

    Args:
        context: Context of the step.
        params: Parameters used in the step.
        exception: Original exception that led to the step failing.
    """
    if context.stack and context.stack.alerter:
        # Render the traceback into an in-memory buffer. Handing the buffer
        # to the console directly (rather than temporarily swapping out the
        # global `sys.stdout`) guarantees stdout can never be left redirected
        # if rendering the traceback raises.
        # NOTE(review): `print_exception` renders the exception currently
        # being handled, so this assumes the hook is invoked from within the
        # caller's `except` block — confirm against the hook runner.
        output_captured = io.StringIO()
        console = Console(file=output_captured)
        console.print_exception(show_locals=False)
        rich_traceback = output_captured.getvalue()

        message = "*Failure Hook Notification! Step failed!*" + "\n\n"
        message += f"Pipeline name: `{context.pipeline_name}`" + "\n"
        message += f"Run name: `{context.run_name}`" + "\n"
        message += f"Step name: `{context.step_name}`" + "\n"
        message += f"Parameters: `{params}`" + "\n"
        message += (
            f"Exception: `({type(exception)}) {rich_traceback}`" + "\n\n"
        )
        message += (
            f"Step Cache Enabled: `{'True' if context.cache_enabled else 'False'}`"
            + "\n"
        )
        context.stack.alerter.post(message)
    else:
        logger.warning(
            "Specified standard failure hook but no alerter configured in the stack. Skipping.."
        )

alerter_success_hook(context, params)

Standard success hook that executes after a step finishes successfully.

This hook uses any BaseAlerter that is configured within the active stack to post a message.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| context | StepContext | Context of the step. | required |
| params | BaseParameters | Parameters used in the step. | required |
Source code in zenml/hooks/alerter_hooks.py
def alerter_success_hook(context: StepContext, params: BaseParameters) -> None:
    """Post a success notification after a step completes.

    The notification is sent through the `BaseAlerter` of the active stack.
    When the stack is missing or has no alerter configured, only a warning
    is logged.

    Args:
        context: Context of the step.
        params: Parameters used in the step.
    """
    alerter = context.stack.alerter if context.stack else None
    if not alerter:
        logger.warning(
            "Specified standard success hook but no alerter configured in the stack. Skipping.."
        )
        return

    # A header, a blank separator line, then one "key: `value`" line each;
    # the trailing "\n" keeps the final line newline-terminated.
    lines = [
        "*Success Hook Notification! Step completed successfully*",
        "",
        f"Pipeline name: `{context.pipeline_name}`",
        f"Run name: `{context.run_name}`",
        f"Step name: `{context.step_name}`",
        f"Parameters: `{params}`",
        f"Step Cache Enabled: `{'True' if context.cache_enabled else 'False'}`",
    ]
    alerter.post("\n".join(lines) + "\n")

hook_validators

Validation functions for hooks.

resolve_and_validate_hook(hook)

Resolves and validates a hook callback.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| hook | HookSpecification | Hook function or source. | required |

Returns:

| Type | Description |
|------|-------------|
| Source | Hook source. |

Exceptions:

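| Type | Description |
|------|-------------|
| ValueError | If `hook` does not resolve to a valid callable, or its arguments are not annotated with distinct supported types (`BaseException`, `BaseParameters`, `StepContext`). |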

Source code in zenml/hooks/hook_validators.py
def resolve_and_validate_hook(hook: "HookSpecification") -> Source:
    """Resolves and validates a hook callback.

    A hook may take any subset of `BaseException`, `BaseParameters` and
    `StepContext` arguments. Every positional argument must be annotated
    with one of these types, and each type may be used at most once.

    Args:
        hook: Hook function or source.

    Returns:
        Hook source.

    Raises:
        ValueError: If `hook` does not resolve to a callable, if any of its
            positional arguments is unannotated (or annotations exist for
            parameters that are not positional arguments), or if an
            annotation uses an unsupported or duplicated type.
    """
    if isinstance(hook, (str, Source)):
        func = source_utils.load(hook)
    else:
        func = hook

    if not callable(func):
        raise ValueError(f"{func} is not a valid function.")

    from zenml.steps.base_parameters import BaseParameters
    from zenml.steps.step_context import StepContext

    allowed_annotations = (BaseException, BaseParameters, StepContext)

    sig = inspect.getfullargspec(inspect.unwrap(func))
    # Only argument annotations matter; build a filtered copy instead of
    # mutating the dict returned by `getfullargspec`.
    sig_annotations = {
        name: annotation
        for name, annotation in sig.annotations.items()
        if name != "return"
    }

    # Comparing counts alone is not sufficient: an annotated keyword-only or
    # variadic parameter would otherwise mask an unannotated positional one.
    unannotated_args = [arg for arg in sig.args if arg not in sig_annotations]
    if sig.args and (
        len(sig.args) != len(sig_annotations) or unannotated_args
    ):
        raise ValueError(
            "If you pass args to a hook, you must annotate them with one "
            "of the following types: `BaseException`, `BaseParameters`, "
            "and/or `StepContext`."
        )

    seen_annotations = set()
    for annotation in sig_annotations.values():
        # Falsy annotations are skipped, as before.
        if not annotation:
            continue

        if annotation not in allowed_annotations:
            raise ValueError(
                "Hook parameters must be of type `BaseException`, `BaseParameters`, "
                f"and/or `StepContext`, not {annotation}"
            )

        if annotation in seen_annotations:
            raise ValueError(
                "It looks like your hook function accepts more than one argument "
                "of the same annotation type. Please ensure you pass exactly "
                "one of the following: `BaseException`, `BaseParameters`, "
                "and/or `StepContext`. Currently your function has "
                f"the following annotations: {sig_annotations}"
            )
        seen_annotations.add(annotation)

    return source_utils.resolve(func)