Skip to content

Plotting

plotting

Visualization utilities for SWIFT monitoring results.

Provides two public plotting functions:

  • plot_bucket_profile — SHAP response curve + density per bucket for a feature.
  • plot_feature_swift_scores — Bar chart of SWIFT scores per feature.

Both return (Figure, Axes) tuples for further customization.

plot_bucket_profile

plot_bucket_profile(bucket_set: BucketSet, feature_values: ndarray, shap_values: ndarray, compare_values: ndarray | None = None, primary_values: ndarray | None = None, labels: tuple[str, str] = ('Reference', 'Comparison'), figsize: tuple[float, float] = (10, 5), title: str | None = None, max_label_buckets: int = 20, x_axis: str = 'bucket') -> tuple[Figure, Axes]

Plot the bucketing profile for a single feature.

Shows the mean SHAP value per bucket (line with a ±2 std error band) and the observation density per bucket (line with a filled area beneath it). Optionally overlays the density of a comparison sample.

PARAMETER DESCRIPTION
bucket_set

Fitted bucket set for the feature.

TYPE: BucketSet

feature_values

Raw feature values from the reference sample (used for the SHAP curve and, when primary_values is None, for the primary density).

TYPE: ndarray

shap_values

SHAP values for this feature on the reference sample.

TYPE: ndarray

compare_values

Raw feature values from a comparison sample. If provided, shows a second density line.

TYPE: ndarray or None DEFAULT: None

primary_values

If provided, these values are used for the primary density instead of feature_values. Useful for showing density of an arbitrary sample while the SHAP curve stays anchored to the reference.

TYPE: ndarray or None DEFAULT: None

labels

Legend labels for (primary, comparison) densities.

TYPE: tuple of str DEFAULT: ('Reference', 'Comparison')

figsize

Figure size.

TYPE: tuple DEFAULT: (10, 5)

title

Custom title. Defaults to "Bucketing Profile: {feature_name}".

TYPE: str or None DEFAULT: None

max_label_buckets

If the number of buckets exceeds this value, tick labels switch from interval notation to bucket indices.

TYPE: int DEFAULT: 20

x_axis

"bucket" (default) uses integer bucket indices on the x-axis. "natural" uses actual feature-value positions (bucket midpoints, data min/max for edge buckets, NULL placed at the left with a gap).

TYPE: {"bucket", "natural"} DEFAULT: "bucket"

RETURNS DESCRIPTION
(Figure, Axes) — the figure and the SHAP (left y-axis) axes; the density axes is a twin of it sharing the x-axis.
Source code in src/swift/plotting.py
def plot_bucket_profile(
    bucket_set: BucketSet,
    feature_values: np.ndarray,
    shap_values: np.ndarray,
    compare_values: np.ndarray | None = None,
    primary_values: np.ndarray | None = None,
    labels: tuple[str, str] = ("Reference", "Comparison"),
    figsize: tuple[float, float] = (10, 5),
    title: str | None = None,
    max_label_buckets: int = 20,
    x_axis: str = "bucket",
) -> tuple[Figure, Axes]:
    """Plot the bucketing profile for a single feature.

    Shows mean SHAP per bucket (line + ±2 std error band, left y-axis) and
    observation density per bucket (line + filled area, right y-axis).
    Optionally overlays a comparison sample's density.

    Parameters
    ----------
    bucket_set : BucketSet
        Fitted bucket set for the feature.
    feature_values : np.ndarray
        Raw feature values from the reference sample (used for the SHAP
        curve and, when *primary_values* is ``None``, for the primary
        density).
    shap_values : np.ndarray
        SHAP values for this feature on the reference sample.
    compare_values : np.ndarray or None
        Raw feature values from a comparison sample.  If provided, shows
        a second density line.
    primary_values : np.ndarray or None
        If provided, these values are used for the primary density
        instead of *feature_values*.  Useful for showing density of an
        arbitrary sample while the SHAP curve stays anchored to the
        reference.
    labels : tuple of str
        Legend labels for (primary, comparison) densities.  The primary
        label is only used when *compare_values* is given; otherwise the
        primary density is labelled ``"Observation density"``.
    figsize : tuple, default (10, 5)
        Figure size.
    title : str or None
        Custom title.  Defaults to ``"Bucketing Profile: {feature_name}"``.
    max_label_buckets : int, default 20
        Switch from interval notation to bucket indices if exceeded.
    x_axis : {"bucket", "natural"}, default "bucket"
        ``"bucket"`` uses integer bucket indices on the x-axis.
        ``"natural"`` uses actual feature-value positions (bucket
        midpoints, data min/max for edge buckets, NULL placed at the
        left with a gap).

    Returns
    -------
    (Figure, Axes)
        The figure and the SHAP (left y-axis) axes.  The density axes is
        a ``twinx`` of the returned axes.

    Raises
    ------
    ValueError
        If *x_axis* is not ``"bucket"`` or ``"natural"``.
    """
    # Validate before touching any global plotting state: an unrecognised
    # value (e.g. the typo "Natural") used to fall back silently to
    # bucket-index mode.
    if x_axis not in ("bucket", "natural"):
        raise ValueError(
            f"x_axis must be 'bucket' or 'natural', got {x_axis!r}"
        )

    sns.set_theme(style="whitegrid")
    palette = sns.color_palette()

    # The primary density may come from an arbitrary sample, but the SHAP
    # curve below is always computed from the reference feature_values.
    density_values = primary_values if primary_values is not None else feature_values

    stats = _compute_bucket_stats(bucket_set, feature_values, shap_values)
    primary_densities = _compute_sample_densities(bucket_set, density_values)
    bucket_labels = _format_bucket_labels(bucket_set, max_label_buckets)
    num_buckets = bucket_set.num_buckets

    # -- X positions --
    use_natural = x_axis == "natural"
    if use_natural:
        x_positions = _compute_natural_x_positions(bucket_set, feature_values)
    else:
        x_positions = np.arange(num_buckets, dtype=float)

    fig, ax_shap = plt.subplots(figsize=figsize)
    # Density gets its own right-hand y-axis sharing the same x positions.
    ax_density = ax_shap.twinx()

    # -- SHAP = 0 reference line (excluded from the legend) --
    ax_shap.axhline(
        y=0,
        color="#888888",
        linewidth=1.4,
        linestyle="--",
        alpha=0.7,
        zorder=1.5,
        label="_nolegend_",
    )

    # -- Density lines + filled area (right y-axis) --
    comparison_mode = compare_values is not None

    # Primary density
    ax_density.plot(
        x_positions,
        primary_densities,
        color=palette[0],
        linewidth=1.5,
        alpha=0.8,
        zorder=1,
        label=labels[0] if comparison_mode else "Observation density",
    )
    ax_density.fill_between(
        x_positions,
        0,
        primary_densities,
        color=palette[0],
        alpha=0.15,
        zorder=0.9,
    )

    if comparison_mode:
        compare_densities = _compute_sample_densities(
            bucket_set, compare_values,
        )
        ax_density.plot(
            x_positions,
            compare_densities,
            color=palette[1],
            linewidth=1.5,
            alpha=0.8,
            zorder=1,
            label=labels[1],
        )
        ax_density.fill_between(
            x_positions,
            0,
            compare_densities,
            color=palette[1],
            alpha=0.15,
            zorder=0.9,
        )

    ax_density.set_ylabel("Density")

    # -- SHAP line (left y-axis) --
    shap_color = palette[2]
    mean_shaps = stats["mean_shaps"]
    shap_stds = stats["shap_stds"]

    ax_shap.plot(
        x_positions,
        mean_shaps,
        color=shap_color,
        marker="o",
        linewidth=2,
        markersize=6,
        zorder=3,
        label="Mean SHAP \u00b1 2 std",
    )

    # Error band of ±2 std, drawn only over buckets that contain observations.
    upper = mean_shaps + 2 * shap_stds
    lower = mean_shaps - 2 * shap_stds
    mask_nonzero = stats["counts"] > 0
    ax_shap.fill_between(
        x_positions,
        lower,
        upper,
        where=mask_nonzero,
        alpha=0.2,
        color=shap_color,
        zorder=2,
    )

    ax_shap.set_ylabel("SHAP Value")

    # -- X-axis --
    if use_natural:
        # Natural axis: the x-axis is labelled with the feature name and
        # tick placement/formatting is left to matplotlib.
        ax_shap.set_xlabel(bucket_set.feature_name)

        if num_buckets > 1:
            # The first position is annotated as NULL below; per the x_axis
            # docs it sits left of the numeric range with a gap.  Draw a
            # dotted separator halfway between it and the next position.
            sep_x = (x_positions[0] + x_positions[1]) / 2
            ax_shap.axvline(
                sep_x,
                color="#aaaaaa",
                linewidth=1.0,
                linestyle=":",
                alpha=0.6,
                zorder=0.5,
            )
            # Text label for the NULL point at SHAP = 0 (xytext defaults
            # to xy, so the text sits exactly on the point).
            ax_shap.annotate(
                "NULL",
                xy=(x_positions[0], 0),
                fontsize=8,
                ha="center",
                va="bottom",
                color="#666666",
            )

        # Automatic major-tick placement for the numeric range; no
        # dedicated tick is added for NULL (only the annotation above).
        ax_shap.xaxis.set_major_locator(mticker.AutoLocator())
    else:
        # Bucket-index mode: one explicit tick per bucket.
        ax_shap.set_xlabel("Bucket")
        ax_shap.set_xticks(x_positions)

        # Shrink and rotate labels as the bucket count grows to avoid overlap.
        if num_buckets > 10:
            fontsize = 7
            rotation = 90
        elif num_buckets > 3:
            fontsize = 8
            rotation = 60
        else:
            fontsize = 9
            rotation = 0

        ax_shap.set_xticklabels(
            bucket_labels,
            rotation=rotation,
            ha="right" if rotation > 0 else "center",
            fontsize=fontsize,
        )

    # -- Title --
    plot_title = title or f"Bucketing Profile: {bucket_set.feature_name}"
    ax_shap.set_title(plot_title)

    # -- Legend: twin axes keep separate handle lists, so merge them into
    # a single legend on the SHAP axes. --
    lines_shap, labels_shap = ax_shap.get_legend_handles_labels()
    lines_density, labels_density = ax_density.get_legend_handles_labels()
    ax_shap.legend(
        lines_shap + lines_density,
        labels_shap + labels_density,
        loc="upper right",
    )

    # Raise the SHAP axes above the density axes so the SHAP line draws on
    # top, and hide its background patch so the density remains visible.
    ax_shap.set_zorder(ax_density.get_zorder() + 1)
    ax_shap.patch.set_visible(False)

    fig.tight_layout()
    return fig, ax_shap

plot_feature_swift_scores

plot_feature_swift_scores(result: SWIFTResult, result_compare: SWIFTResult | None = None, labels: tuple[str, str] = ('Result A', 'Result B'), threshold: float | None = None, sort_by: str = 'score', feature_order: list[str] | None = None, figsize: tuple[float, float] = (12, 5), title: str | None = None) -> tuple[Figure, Axes]

Plot SWIFT scores per feature as a bar chart with reference lines.

Optionally compare two SWIFTResult objects side by side.

PARAMETER DESCRIPTION
result

Primary result from SWIFTMonitor.test().

TYPE: SWIFTResult

result_compare

Optional second result for side-by-side comparison.

TYPE: SWIFTResult or None DEFAULT: None

labels

Legend labels for the two results in comparison mode.

TYPE: tuple of str DEFAULT: ('Result A', 'Result B')

threshold

If given, a dotted horizontal line is drawn at this detection threshold.

TYPE: float or None DEFAULT: None

sort_by

Feature ordering on x-axis. "original" uses feature_order.

TYPE: {"score", "name", "original"} DEFAULT: "score"

feature_order

Original feature order (from monitor.feature_names_in_). Used when sort_by="original"; if omitted, features keep the order in which they appear in result.feature_results.

TYPE: list[str] or None DEFAULT: None

figsize

Figure size.

TYPE: tuple DEFAULT: (12, 5)

title

Custom title.

TYPE: str or None DEFAULT: None

RETURNS DESCRIPTION
(Figure, Axes)
Source code in src/swift/plotting.py
def plot_feature_swift_scores(
    result: SWIFTResult,
    result_compare: SWIFTResult | None = None,
    labels: tuple[str, str] = ("Result A", "Result B"),
    threshold: float | None = None,
    sort_by: str = "score",
    feature_order: list[str] | None = None,
    figsize: tuple[float, float] = (12, 5),
    title: str | None = None,
) -> tuple[Figure, Axes]:
    """Plot SWIFT scores per feature as a bar chart with reference lines.

    Optionally compare two ``SWIFTResult`` objects side by side.

    In single-result mode bars are coloured by each feature's drift flag
    and dashed lines mark ``result.swift_max`` and ``result.swift_mean``.
    In comparison mode bars are grouped per feature; only the optional
    threshold line is drawn.

    Parameters
    ----------
    result : SWIFTResult
        Primary result from ``SWIFTMonitor.test()``.  Its features define
        the x-axis.
    result_compare : SWIFTResult or None
        Optional second result for side-by-side comparison.  Features
        missing from it are drawn with height 0; features present only in
        it are not shown.
    labels : tuple of str
        Legend labels for the two results in comparison mode.
    threshold : float or None
        If given, draws a dotted horizontal line at this detection threshold.
    sort_by : {"score", "name", "original"}, default "score"
        Feature ordering on x-axis: descending primary score, alphabetical,
        or *feature_order*.
    feature_order : list[str] or None
        Original feature order (from ``monitor.feature_names_in_``), used
        when ``sort_by="original"``.  If ``None``, features keep the order
        of ``result.feature_results``.  Names absent from *result* are
        skipped.
    figsize : tuple, default (12, 5)
        Figure size.
    title : str or None
        Custom title.

    Returns
    -------
    (Figure, Axes)

    Raises
    ------
    ValueError
        If *sort_by* is not one of ``"score"``, ``"name"``, ``"original"``.
    """
    # Validate before touching global plotting state: an unknown value used
    # to fall through silently to result order.
    if sort_by not in ("score", "name", "original"):
        raise ValueError(
            f"sort_by must be 'score', 'name' or 'original', got {sort_by!r}"
        )

    sns.set_theme(style="whitegrid")

    comparison_mode = result_compare is not None

    # Build dicts: feature_name -> score / is_drifted
    scores_a = {fr.feature_name: fr.swift_score for fr in result.feature_results}
    drifted_a = {fr.feature_name: fr.is_drifted for fr in result.feature_results}

    if comparison_mode:
        scores_b = {
            fr.feature_name: fr.swift_score
            for fr in result_compare.feature_results
        }

    # Determine feature order
    feature_names = list(scores_a.keys())
    if sort_by == "score":
        feature_names = sorted(feature_names, key=lambda f: scores_a[f], reverse=True)
    elif sort_by == "name":
        feature_names = sorted(feature_names)
    elif feature_order is not None:  # sort_by == "original"
        feature_names = [f for f in feature_order if f in scores_a]
    # else: sort_by == "original" without feature_order -> keep result order

    n_features = len(feature_names)
    x_positions = np.arange(n_features)

    fig, ax = plt.subplots(figsize=figsize)

    # Legend entries for artists that carry no label of their own; labelled
    # artists (bars, reference lines) are appended when the legend is built.
    legend_proxies = []

    if comparison_mode:
        # -- Comparison mode: grouped bars --
        palette = sns.color_palette()
        bar_width = 0.35

        vals_a = [scores_a[f] for f in feature_names]
        vals_b = [scores_b.get(f, 0.0) for f in feature_names]

        ax.bar(
            x_positions - bar_width / 2,
            vals_a,
            width=bar_width,
            color=palette[0],
            edgecolor=_darken(palette[0]),
            label=labels[0],
        )
        ax.bar(
            x_positions + bar_width / 2,
            vals_b,
            width=bar_width,
            color=palette[1],
            edgecolor=_darken(palette[1]),
            label=labels[1],
        )

        # No max/mean lines here: they would be ambiguous with two results.
        plot_title = title or "SWIFT Scores Comparison"

    else:
        # -- Single result mode: drift-colored bars --
        from matplotlib.patches import Patch

        color_drifted = "#e74c3c"
        color_ok = "#3498db"
        edge_drifted = _darken_hex(color_drifted)
        edge_ok = _darken_hex(color_ok)

        vals = [scores_a[f] for f in feature_names]
        colors = [
            color_drifted if drifted_a.get(f) else color_ok
            for f in feature_names
        ]
        edges = [
            edge_drifted if drifted_a.get(f) else edge_ok
            for f in feature_names
        ]

        ax.bar(
            x_positions,
            vals,
            color=colors,
            edgecolor=edges,
            width=0.6,
        )

        # Bars are coloured per feature, so the legend needs proxy patches.
        legend_proxies = [
            Patch(facecolor=color_drifted, edgecolor=edge_drifted, label="Drifted"),
            Patch(facecolor=color_ok, edgecolor=edge_ok, label="Not drifted"),
        ]

        # Aggregate reference lines
        ax.axhline(
            result.swift_max,
            linestyle="--",
            color=color_drifted,
            linewidth=1.2,
            label=f"SWIFT max = {result.swift_max:.4f}",
        )
        ax.axhline(
            result.swift_mean,
            linestyle="--",
            color=color_ok,
            linewidth=1.2,
            label=f"SWIFT mean = {result.swift_mean:.4f}",
        )

        plot_title = title or "SWIFT Scores per Feature"

    # Threshold line is shared by both modes.
    if threshold is not None:
        ax.axhline(
            threshold,
            linestyle=":",
            color="black",
            linewidth=1.5,
            label=f"Threshold = {threshold:.4f}",
        )

    # One legend for both modes: proxies first, then every labelled artist.
    ax.legend(
        handles=legend_proxies + ax.get_legend_handles_labels()[0],
        loc="upper right",
    )

    # -- Axes --
    ax.set_xticks(x_positions)
    ax.set_xticklabels(
        feature_names,
        rotation=45 if n_features > 5 else 0,
        ha="right" if n_features > 5 else "center",
    )
    ax.set_ylabel("SWIFT Score")
    ax.set_title(plot_title)

    fig.tight_layout()
    return fig, ax