Skip to content

Plotting

sunburst

sunburst

sunburst(graph: ConceptGraph, df: DataFrame, *, value: str = 'count', title: str | None = None, colorscale: str | None = None, color_value: str | None = None, color_by: ColorBy = 'auto', branch_palette: Sequence[str] | None = None, hide_root: bool = True, branchvalues: str = 'total', extra_hover: list[str] | None = None, hover_fmt: dict[str, str] | None = None, layout_kwargs: dict[str, Any] | None = None) -> Figure

Render a sunburst from a ConceptGraph + a metric DataFrame.

PARAMETER DESCRIPTION
graph

The ConceptGraph to render.

TYPE: ConceptGraph

df

Tidy DataFrame produced by one of the metric functions. Must be indexed by path and contain value (and color_value if coloring is requested).

TYPE: DataFrame

value

Column used for sector size. Defaults to "count".

TYPE: str DEFAULT: 'count'

title

Figure title.

TYPE: str | None DEFAULT: None

colorscale

Plotly colorscale name (e.g. "Viridis", "Reds"). When set, sectors are colored by color_value (which defaults to value).

TYPE: str | None DEFAULT: None

color_value

Column used for color intensity. Defaults to value when colorscale is set.

TYPE: str | None DEFAULT: None

color_by

How to colour sectors. "auto" (default) picks "value" when a colorscale is given and "branch" otherwise. "branch" forces categorical-per-top-level-branch colouring (using branch_palette). "value" forces colorscale-based colouring (raises if colorscale is not given). "none" disables per-sector colour overrides (raw Plotly defaults).

TYPE: ColorBy DEFAULT: 'auto'

branch_palette

CSS color sequence used when colouring by branch. Defaults to the Plotly qualitative palette.

TYPE: Sequence[str] | None DEFAULT: None

hide_root

When True (default) the root concept is omitted and its direct children form the centre ring. Pass False to keep the legacy rendering with the root sector visible.

TYPE: bool DEFAULT: True

branchvalues

Plotly sunburst branchvalues ("total" or "remainder").

TYPE: str DEFAULT: 'total'

extra_hover

Additional columns to append to the hover tooltip.

TYPE: list[str] | None DEFAULT: None

hover_fmt

Per-column format spec strings (e.g. {"importance_sum": ".4f"}).

TYPE: dict[str, str] | None DEFAULT: None

layout_kwargs

Passed verbatim to fig.update_layout.

TYPE: dict[str, Any] | None DEFAULT: None

Source code in src/concept_graph_xai/plotting/sunburst.py
def sunburst(
    graph: ConceptGraph,
    df: pd.DataFrame,
    *,
    value: str = "count",
    title: str | None = None,
    colorscale: str | None = None,
    color_value: str | None = None,
    color_by: ColorBy = "auto",
    branch_palette: Sequence[str] | None = None,
    hide_root: bool = True,
    branchvalues: str = "total",
    extra_hover: list[str] | None = None,
    hover_fmt: dict[str, str] | None = None,
    layout_kwargs: dict[str, Any] | None = None,
) -> go.Figure:
    """Render a sunburst from a ConceptGraph + a metric DataFrame.

    Parameters
    ----------
    graph:
        The ConceptGraph whose hierarchy defines the rings.
    df:
        Tidy metric DataFrame indexed by ``path``. It must contain the
        ``value`` column, plus ``color_value`` when colouring by value.
    value:
        Column that sizes each sector. Defaults to ``"count"``.
    title:
        Figure title.
    colorscale:
        Plotly colorscale name (e.g. ``"Viridis"``, ``"Reds"``). Used when
        sectors are coloured by value.
    color_value:
        Column driving colour intensity. Falls back to ``value``.
    color_by:
        Colouring strategy, resolved by ``_resolve_color_by`` together with
        ``colorscale``. The resolved mode is one of:

        * ``"value"``: colorscale-based colouring;
        * ``"branch"``: categorical colour per top-level branch (uses
          ``branch_palette``);
        * anything else: no per-sector colour override.
    branch_palette:
        CSS colours used for branch colouring. ``None`` means the default
        palette chosen by ``branch_colors``.
    hide_root:
        If ``True`` (default) the root concept is dropped and its children
        form the innermost ring. ``False`` keeps the root sector visible.
    branchvalues:
        Plotly sunburst ``branchvalues`` (``"total"`` or ``"remainder"``).
    extra_hover:
        Additional columns appended (once each) to the hover tooltip.
    hover_fmt:
        Per-column ``format`` specs, e.g. ``{"importance_sum": ".4f"}``.
    layout_kwargs:
        Extra layout options forwarded to ``build_sunburst_figure``.

    Returns
    -------
    go.Figure
        The figure returned by ``build_sunburst_figure``.

    Raises
    ------
    KeyError
        If colouring by value and the colour column is absent.
    """
    arrays, ordered, sizes = sunburst_layout(graph, df, value=value, hide_root=hide_root)

    mode = _resolve_color_by(color_by, colorscale)
    marker: dict[str, Any] = {}
    if mode == "branch":
        marker["colors"] = branch_colors(graph, arrays["ids"], palette=branch_palette)
    elif mode == "value":
        colour_col = color_value or value
        if colour_col not in ordered.columns:
            raise KeyError(f"color_value column {colour_col!r} not in DataFrame")
        raw = ordered[colour_col]
        # Missing values are painted as 0. The min/max below skip NaN, so
        # they only reflect observed values.
        colour_values = raw.fillna(0).to_numpy(dtype=float)
        low = float(raw.min())
        high = float(raw.max())
        # Centre a diverging scale on 0 only when the data crosses zero.
        midpoint = 0 if (low < 0 < high) else None
        marker = {
            "colors": colour_values,
            "colorscale": colorscale,
            "showscale": True,
            "cmid": midpoint,
            "colorbar": {"title": colour_col},
        }

    # Tooltip columns: the size column first, then any standard metadata
    # columns present in the frame, then caller extras. Duplicates are
    # dropped, keeping the first occurrence.
    standard_cols = ("kind", "feature_count", "used_feature_count", "is_used")
    hover_columns = list(
        dict.fromkeys(
            [
                value,
                *(col for col in standard_cols if col in ordered.columns),
                *(extra_hover or []),
            ]
        )
    )
    tooltip = hover_text(ordered, hover_columns, fmt=hover_fmt)

    return build_sunburst_figure(
        arrays,
        sizes,
        marker=marker,
        hover=tooltip,
        title=title,
        branchvalues=branchvalues,
        layout_kwargs=layout_kwargs,
    )

utilization_map

utilization_map

utilization_map(graph: ConceptGraph, df: DataFrame, *, value: str = 'feature_count', used_color: str | None = None, unused_color: str = '#d3d3d3', branch_palette: Sequence[str] | None = None, hide_root: bool = True, title: str | None = None, layout_kwargs: dict[str, Any] | None = None) -> Figure

Render a sunburst where unused branches are grey.

The DataFrame must be the output of concept_graph_xai.metrics.utilization (it requires the is_used column). By default, sector area encodes feature_count and colour encodes both branch identity (hue) and is-used status (grey when not used) — the chart subsumes the standalone sunburst(..., feature_counts(...)) structural view.

PARAMETER DESCRIPTION
used_color

If None (default), used sectors are coloured by their top-level branch with hierarchical shading (sub-concepts get lighter shades of the branch hue). Pass a CSS colour to fall back to a single solid colour for every used sector (legacy behaviour).

TYPE: str | None DEFAULT: None

branch_palette

Custom palette for branch base hues. Defaults to the Plotly qualitative palette.

TYPE: Sequence[str] | None DEFAULT: None

hide_root

When True (default) the root concept is omitted; pass False to keep the legacy root sector.

TYPE: bool DEFAULT: True

Source code in src/concept_graph_xai/plotting/utilization_map.py
def utilization_map(
    graph: ConceptGraph,
    df: pd.DataFrame,
    *,
    value: str = "feature_count",
    used_color: str | None = None,
    unused_color: str = "#d3d3d3",
    branch_palette: Sequence[str] | None = None,
    hide_root: bool = True,
    title: str | None = None,
    layout_kwargs: dict[str, Any] | None = None,
) -> go.Figure:
    """Render a sunburst where unused branches are grey.

    The DataFrame must be the output of
    :func:`concept_graph_xai.metrics.utilization` (it requires the ``is_used``
    column). By default sector area encodes ``feature_count`` and colour
    encodes both branch identity (hue) and is-used status (grey when not used)
    — the chart subsumes the standalone ``sunburst(..., feature_counts(...))``
    structural view.

    Parameters
    ----------
    graph:
        The ConceptGraph to render.
    df:
        Output of ``metrics.utilization``; must contain an ``is_used`` column.
    value:
        Column used for sector size. Defaults to ``"feature_count"``.
    used_color:
        If ``None`` (default), used sectors take the colour returned by
        ``branch_colors`` for their id (per top-level branch). Pass a CSS
        colour to fall back to a single solid colour for every used sector
        (legacy behaviour).
    unused_color:
        Colour for sectors whose ``is_used`` value is falsy.
    branch_palette:
        Custom palette for branch base hues, forwarded to ``branch_colors``.
    hide_root:
        When ``True`` (default) the root concept is omitted; pass ``False``
        to keep the legacy root sector.
    title:
        Figure title. Defaults to ``"Concept utilization (grey = unused)"``.
    layout_kwargs:
        Forwarded to ``build_sunburst_figure``.

    Returns
    -------
    go.Figure
        The figure returned by ``build_sunburst_figure``.

    Raises
    ------
    KeyError
        If ``df`` has no ``is_used`` column.
    """

    if "is_used" not in df.columns:
        raise KeyError(
            "utilization_map expects DataFrame from metrics.utilization (no is_used col)"
        )

    arrays, ordered, sizes = sunburst_layout(graph, df, value=value, hide_root=hide_root)

    is_used = ordered["is_used"].to_numpy()
    if used_color is None:
        # branch_colors returns one colour per sector id, aligned with `ordered`.
        used_palette = branch_colors(graph, arrays["ids"], palette=branch_palette)
        colors = [used_palette[i] if bool(u) else unused_color for i, u in enumerate(is_used)]
    else:
        colors = [used_color if bool(u) else unused_color for u in is_used]

    # With the default value="feature_count", the size column is also in the
    # fixed list. dict.fromkeys drops the repeat so the tooltip line appears
    # once, while keeping first-occurrence order.
    candidates = (value, "is_used", "used_feature_count", "feature_count", "importance_sum")
    hover_cols = [c for c in dict.fromkeys(candidates) if c in ordered.columns]
    hover = hover_text(ordered, hover_cols, fmt={"importance_sum": ".4f"})

    return build_sunburst_figure(
        arrays,
        sizes,
        marker={"colors": colors},
        hover=hover,
        title=title or "Concept utilization (grey = unused)",
        layout_kwargs=layout_kwargs,
    )

auc_drop_map

auc_drop_map

auc_drop_map(graph: ConceptGraph, df: DataFrame, *, value: str = 'auc_drop_mean', size: str = 'feature_count', colorscale: str = 'Reds', hide_root: bool = True, title: str | None = None, layout_kwargs: dict[str, Any] | None = None) -> Figure

Render a sunburst where each concept is colored by its AUC drop.

Sector area uses size (feature count by default), and colour intensity uses value (mean AUC drop by default). Set hide_root=False to keep the root sector visible.

Source code in src/concept_graph_xai/plotting/auc_drop_map.py
def auc_drop_map(
    graph: ConceptGraph,
    df: pd.DataFrame,
    *,
    value: str = "auc_drop_mean",
    size: str = "feature_count",
    colorscale: str = "Reds",
    hide_root: bool = True,
    title: str | None = None,
    layout_kwargs: dict[str, Any] | None = None,
) -> go.Figure:
    """Render a sunburst where each concept is colored by its AUC drop.

    Sector area uses ``size`` (feature count by default); the colour intensity
    uses ``value`` (mean AUC drop by default). Set ``hide_root=False`` to
    keep the root sector visible.

    Parameters
    ----------
    graph:
        The ConceptGraph to render.
    df:
        Metric DataFrame containing both the ``value`` and ``size`` columns.
    value:
        Column used for colour intensity. NaN values are drawn as 0.
    size:
        Column used for sector size.
    colorscale:
        Plotly colorscale name.
    hide_root:
        When ``True`` (default) the root concept is omitted.
    title:
        Figure title. Defaults to ``"AUC drop per concept"``.
    layout_kwargs:
        Forwarded to ``build_sunburst_figure``.

    Returns
    -------
    go.Figure
        The figure returned by ``build_sunburst_figure``.

    Raises
    ------
    KeyError
        If ``value`` or ``size`` is not a column of ``df``.
    """

    if value not in df.columns:
        raise KeyError(f"{value!r} not in DataFrame; run metrics.auc_drop first")
    if size not in df.columns:
        raise KeyError(f"{size!r} not in DataFrame")

    arrays, ordered, sizes = sunburst_layout(graph, df, value=size, hide_root=hide_root)

    drop_vals = ordered[value].to_numpy(dtype=float)
    drop_for_color = np.where(np.isnan(drop_vals), 0.0, drop_vals)

    # Colour range, computed from the non-NaN drops only:
    #   * all drops >= 0        -> [0, max|drop|]
    #   * at least one drop < 0 -> symmetric [-max|drop|, max|drop|]
    # The 1e-6 floor keeps cmin < cmax when every drop is exactly 0. An
    # all-NaN or empty column falls back to [0, 1]. Previously an empty
    # column made np.nanmin raise.
    observed = drop_vals[~np.isnan(drop_vals)]
    if observed.size:
        cmax = max(1e-6, float(np.abs(observed).max()))
        cmin = -cmax if float(observed.min()) < 0 else 0.0
    else:
        cmax, cmin = 1.0, 0.0

    hover_cols = [
        c
        for c in dict.fromkeys(
            (
                value,
                "auc_drop_std",
                "ablated_score_mean",
                "baseline_score",
                "feature_count",
                "strategy",
            )
        )
        if c in ordered.columns
    ]
    hover = hover_text(
        ordered,
        hover_cols,
        fmt={
            value: "+.4f",
            "auc_drop_std": ".4f",
            "ablated_score_mean": ".4f",
            "baseline_score": ".4f",
        },
    )

    return build_sunburst_figure(
        arrays,
        sizes,
        marker={
            "colors": drop_for_color,
            "colorscale": colorscale,
            "cmin": cmin,
            "cmax": cmax,
            "showscale": True,
            "colorbar": {"title": value},
        },
        hover=hover,
        title=title or "AUC drop per concept",
        layout_kwargs=layout_kwargs,
    )

correlation_block

correlation_block

correlation_block(result: CorrelationResult, *, title: str | None = None, show_block_labels: bool = True, annotate_mean_abs: bool = True, colorscale: str = 'RdBu', zmid: float = 0.0, layout_kwargs: dict[str, Any] | None = None) -> Figure

Render a correlation matrix with concept-block separators.

Works on the output of any of feature_correlation, nullity_correlation, or shap_correlation — they all return a CorrelationResult.

PARAMETER DESCRIPTION
result

Output of one of the correlation metrics.

TYPE: CorrelationResult

title

Figure title.

TYPE: str | None DEFAULT: None

show_block_labels

Draw the concept name above each diagonal block.

TYPE: bool DEFAULT: True

annotate_mean_abs

Print mean(|r|) inside each diagonal block.

TYPE: bool DEFAULT: True

colorscale

Plotly colorscale name. Default RdBu is symmetric around zero.

TYPE: str DEFAULT: 'RdBu'

zmid

Mid value for the colorscale. Use 0 for a diverging palette.

TYPE: float DEFAULT: 0.0

layout_kwargs

Passed verbatim to fig.update_layout.

TYPE: dict[str, Any] | None DEFAULT: None

Source code in src/concept_graph_xai/plotting/correlation_block.py
def correlation_block(
    result: CorrelationResult,
    *,
    title: str | None = None,
    show_block_labels: bool = True,
    annotate_mean_abs: bool = True,
    colorscale: str = "RdBu",
    zmid: float = 0.0,
    layout_kwargs: dict[str, Any] | None = None,
) -> go.Figure:
    """Render a correlation matrix with concept-block separators.

    Works on the output of any of :func:`feature_correlation`,
    :func:`nullity_correlation`, or :func:`shap_correlation` — they all return
    a :class:`CorrelationResult`.

    Parameters
    ----------
    result:
        Output of one of the correlation metrics. This function reads:

        * ``result.matrix``;
        * ``result.method`` (used in the colorbar title);
        * ``result.blocks``: ``(concept_path, start, end)`` triples, where
          ``[start, end)`` is the half-open row/column range of the block;
        * ``result.block_stats``: looked up by ``concept_path``, optionally
          carrying ``mean_abs``.
    title:
        Figure title.
    show_block_labels:
        Draw the concept name above each diagonal block.
    annotate_mean_abs:
        Print ``mean(|r|)`` inside each diagonal block spanning at least two
        features. Missing/NaN statistics are skipped.
    colorscale:
        Plotly colorscale name. Default ``RdBu`` is symmetric around zero.
    zmid:
        Mid value for the colorscale. Use ``0`` for a diverging palette.
    layout_kwargs:
        Passed verbatim to ``fig.update_layout``.

    Returns
    -------
    go.Figure
        Heatmap with one outlined rectangle per concept block, plus optional
        label and ``mean(|r|)`` annotations.
    """
    import pandas as pd  # pd.isna: NaN-tolerant check for block statistics

    matrix = result.matrix
    n = matrix.shape[0]

    fig = go.Figure(
        go.Heatmap(
            z=matrix.to_numpy(),
            x=list(matrix.columns),
            y=list(matrix.index),
            colorscale=colorscale,
            zmid=zmid,
            zmin=-1.0,
            zmax=1.0,
            colorbar={"title": f"{result.method} ρ"},
            hovertemplate="%{x} ↔ %{y}<br>%{z:.3f}<extra></extra>",
        )
    )

    shapes: list[dict[str, Any]] = []
    annotations: list[dict[str, Any]] = []
    stats_lookup = result.block_stats.set_index("concept_path").to_dict("index")

    # Depth = number of slashes in concept_path. The root has depth 0,
    # top-level concepts under the root have depth 1, sub-concepts depth 2,
    # and so on.
    #
    # Block labels are stacked in horizontal rows at negative y, outside the
    # matrix rows 0..n-1. The y-axis is reversed, so negative y is drawn
    # above row 0, i.e. above the heatmap. The deepest concepts get the row
    # closest to the heatmap (y = label_top_y); top-level concepts get the
    # row furthest from it. This stops nested blocks (e.g. "Behaviour" +
    # "Delinquency") from writing their labels on top of each other.
    block_depths = [path.count("/") for path, _s, _e in result.blocks]
    visible_depths = [d for d in block_depths if d > 0]
    max_depth = max(visible_depths) if visible_depths else 1
    row_height = 1.0
    label_top_y = -1.2  # label row closest to the heatmap (deepest concepts)

    for (path, start, end), depth in zip(result.blocks, block_depths, strict=True):
        # Diagonal block border. Cell centres sit on integers, so +/-0.5
        # aligns the rectangle with cell edges.
        shapes.append(
            {
                "type": "rect",
                "xref": "x",
                "yref": "y",
                "x0": start - 0.5,
                "x1": end - 0.5,
                "y0": start - 0.5,
                "y1": end - 0.5,
                "line": {"color": "black", "width": 1.5},
                "fillcolor": "rgba(0,0,0,0)",
            }
        )
        if show_block_labels and end - start >= 1 and depth >= 1:
            label = path.split("/")[-1]
            # Deeper concepts -> y closer to 0 (nearer the heatmap).
            # Top-level branches (depth=1) -> most negative y (furthest from
            # the heatmap, i.e. highest on screen).
            row = max_depth - depth  # 0 for deepest, max_depth-1 for top-level
            y_pos = label_top_y - row * row_height
            font_size = 12 if depth == 1 else max(8, 11 - (depth - 1))
            annotations.append(
                {
                    "x": (start + end - 1) / 2,
                    "y": y_pos,
                    "xref": "x",
                    "yref": "y",
                    "text": f"<b>{label}</b>" if depth == 1 else label,
                    "showarrow": False,
                    "font": {"size": font_size},
                }
            )
        if annotate_mean_abs and end - start >= 2:
            mean_abs = stats_lookup.get(path, {}).get("mean_abs")
            # Skip absent or NaN statistics instead of printing "|ρ̄|=nan".
            if mean_abs is not None and not pd.isna(mean_abs):
                annotations.append(
                    {
                        "x": (start + end - 1) / 2,
                        "y": (start + end - 1) / 2,
                        "xref": "x",
                        "yref": "y",
                        "text": f"|ρ̄|={mean_abs:.2f}",
                        "showarrow": False,
                        "font": {"size": 10, "color": "black"},
                        "bgcolor": "rgba(255,255,255,0.6)",
                    }
                )

    # Extent of negative-y space reserved for the label rows (drawn above
    # the heatmap on the reversed axis).
    label_band = max(1, max_depth) * row_height + 0.5
    bottom_margin = int(60 + 22 * max_depth)

    fig.update_layout(
        title=title,
        xaxis={"side": "bottom", "tickangle": 45, "showgrid": False, "range": [-0.5, n - 0.5]},
        yaxis={
            "autorange": "reversed",
            "showgrid": False,
            "range": [n - 0.5, label_top_y - label_band],
        },
        shapes=shapes,
        annotations=annotations,
        margin={"t": 40, "l": 40, "r": 40, "b": bottom_margin},
    )
    if layout_kwargs:
        fig.update_layout(**layout_kwargs)
    return fig

joint_missing_map

joint_missing_map

joint_missing_map(graph: ConceptGraph, df: DataFrame, *, value: str = 'joint_missing_rate', size: str = 'feature_count', colorscale: str = 'Reds', hide_root: bool = True, title: str | None = None, layout_kwargs: dict[str, Any] | None = None) -> Figure

Render a sunburst where each concept is coloured by its joint-missing rate.

The DataFrame must come from joint_missing_rate. Sector size uses size (feature_count by default) so the shape matches the existing sunburst plots; colour intensity uses value (joint_missing_rate by default). Set hide_root=False to keep the root sector visible.

Source code in src/concept_graph_xai/plotting/joint_missing_map.py
def joint_missing_map(
    graph: ConceptGraph,
    df: pd.DataFrame,
    *,
    value: str = "joint_missing_rate",
    size: str = "feature_count",
    colorscale: str = "Reds",
    hide_root: bool = True,
    title: str | None = None,
    layout_kwargs: dict[str, Any] | None = None,
) -> go.Figure:
    """Render a sunburst where each concept is coloured by its joint-missing rate.

    The DataFrame must come from :func:`joint_missing_rate`. Sector size uses
    ``size`` (``feature_count`` by default) so the shape matches the existing
    sunburst plots; colour intensity uses ``value`` (``joint_missing_rate``
    by default). Set ``hide_root=False`` to keep the root sector visible.

    Parameters
    ----------
    graph:
        The ConceptGraph to render.
    df:
        Metric DataFrame containing both the ``value`` and ``size`` columns.
    value:
        Column used for colour intensity. NaN is drawn as 0.
    size:
        Column used for sector size.
    colorscale:
        Plotly colorscale name.
    hide_root:
        When ``True`` (default) the root concept is omitted.
    title:
        Figure title. Defaults to ``"Joint missingness per concept"``.
    layout_kwargs:
        Forwarded to ``build_sunburst_figure``.

    Returns
    -------
    go.Figure
        The figure returned by ``build_sunburst_figure``.

    Raises
    ------
    KeyError
        If ``value`` or ``size`` is not a column of ``df``.
    """

    if value not in df.columns:
        raise KeyError(f"{value!r} not in DataFrame; run joint_missing_rate first")
    if size not in df.columns:
        raise KeyError(f"{size!r} not in DataFrame")

    arrays, ordered, sizes = sunburst_layout(graph, df, value=size, hide_root=hide_root)
    rates = ordered[value].fillna(0).to_numpy(dtype=float)
    # The colour range is [0, cmax]. The 1e-6 floor keeps cmin < cmax when
    # every rate is 0. An empty frame is handled separately because
    # ``rates.max()`` raises on a zero-size array.
    cmax = max(1e-6, float(rates.max())) if rates.size else 1e-6

    # dict.fromkeys de-duplicates when value == "feature_count".
    hover_cols = [c for c in dict.fromkeys((value, "feature_count")) if c in ordered.columns]
    hover = hover_text(ordered, hover_cols, fmt={value: ".3f"})

    return build_sunburst_figure(
        arrays,
        sizes,
        marker={
            "colors": rates,
            "colorscale": colorscale,
            "cmin": 0.0,
            "cmax": cmax,
            "showscale": True,
            "colorbar": {"title": value},
        },
        hover=hover,
        title=title or "Joint missingness per concept",
        layout_kwargs=layout_kwargs,
    )

coherence_importance_scatter

coherence_importance_scatter

coherence_importance_scatter(df: DataFrame, *, only_concepts: bool = True, label_points: bool = True, title: str | None = None, layout_kwargs: dict[str, Any] | None = None) -> Figure

Render the coherence × importance quadrant scatter.

PARAMETER DESCRIPTION
df

Output of coherence_importance. Must carry coherence, importance_sum and quadrant columns. Threshold values are read from df.attrs["coherence_threshold"] and df.attrs["importance_threshold"].

TYPE: DataFrame

only_concepts

Drop rows where kind == "feature" so the chart shows only business concepts.

TYPE: bool DEFAULT: True

label_points

Annotate every point with the concept name.

TYPE: bool DEFAULT: True

Source code in src/concept_graph_xai/plotting/coherence_importance_scatter.py
def coherence_importance_scatter(
    df: pd.DataFrame,
    *,
    only_concepts: bool = True,
    label_points: bool = True,
    title: str | None = None,
    layout_kwargs: dict[str, Any] | None = None,
) -> go.Figure:
    """Render the coherence × importance quadrant scatter.

    Parameters
    ----------
    df:
        Output of :func:`coherence_importance`. Must carry ``coherence``,
        ``importance_sum`` and ``quadrant`` columns. Threshold values are
        read from ``df.attrs["coherence_threshold"]`` and
        ``df.attrs["importance_threshold"]`` (each defaulting to 0.0). The
        point name comes from a ``name`` column when present, otherwise
        from the index.
    only_concepts:
        Drop rows where ``kind == "feature"`` so the chart shows only
        business concepts. Implemented as keeping rows whose ``kind`` is
        ``"concept"``, and applied only if a ``kind`` column exists.
    label_points:
        Annotate every point with the concept name. The name is shown in
        the hover tooltip either way.
    title:
        Figure title. Defaults to ``"Concept coherence vs importance"``.
    layout_kwargs:
        Extra options applied last via ``fig.update_layout``.

    Returns
    -------
    go.Figure
        One scatter trace per non-empty quadrant in ``_QUADRANT_COLOR``,
        plus dashed threshold lines.

    Raises
    ------
    KeyError
        If any of the required columns is missing.
    """

    needed = {"coherence", "importance_sum", "quadrant"}
    missing = needed - set(df.columns)
    if missing:
        raise KeyError(f"missing columns from coherence_importance: {missing}")

    plot_df = df.copy()
    if only_concepts and "kind" in plot_df.columns:
        plot_df = plot_df.loc[plot_df["kind"] == "concept"].copy()

    coh_thr = float(df.attrs.get("coherence_threshold", 0.0))
    imp_thr = float(df.attrs.get("importance_threshold", 0.0))

    fig = go.Figure()
    for quadrant, color in _QUADRANT_COLOR.items():
        sub = plot_df.loc[plot_df["quadrant"] == quadrant]
        if sub.empty:
            continue
        # Point names feed both the on-chart labels and the tooltip. The
        # tooltip reads them from customdata rather than %{text}: with
        # label_points=False, ``text`` is None and the name would otherwise
        # be missing from the hover box.
        names = (
            sub["name"]
            if "name" in sub.columns
            else pd.Series(sub.index.map(str), index=sub.index)
        )
        fig.add_trace(
            go.Scatter(
                x=sub["coherence"],
                y=sub["importance_sum"],
                mode="markers+text" if label_points else "markers",
                text=names if label_points else None,
                textposition="top center",
                marker={
                    "size": 12,
                    "color": color,
                    "line": {"color": "black", "width": 0.5},
                },
                name=quadrant.replace("_", " "),
                customdata=np.stack(
                    [
                        sub.get("feature_count", pd.Series([0] * len(sub))).to_numpy(),
                        sub.get("kind", pd.Series([""] * len(sub))).to_numpy(),
                        names.to_numpy(),
                    ],
                    axis=1,
                ),
                hovertemplate=(
                    "<b>%{customdata[2]}</b><br>"
                    "coherence: %{x:.3f}<br>"
                    "importance: %{y:.4f}<br>"
                    "feature_count: %{customdata[0]}<br>"
                    "kind: %{customdata[1]}"
                    "<extra></extra>"
                ),
            )
        )

    fig.add_hline(y=imp_thr, line={"color": "black", "dash": "dash", "width": 1})
    fig.add_vline(x=coh_thr, line={"color": "black", "dash": "dash", "width": 1})

    fig.update_layout(
        title=title or "Concept coherence vs importance",
        xaxis_title=f"within-concept mean(|ρ|)  ({df.attrs.get('method', 'spearman')})",
        yaxis_title="summed |SHAP|",
        margin={"t": 60, "l": 60, "r": 30, "b": 60},
    )
    if layout_kwargs:
        fig.update_layout(**layout_kwargs)
    return fig

regulatory_tag_overlay

regulatory_tag_overlay

regulatory_tag_overlay(graph: ConceptGraph, df: DataFrame | None = None, *, tag_key: str = 'tag', palette: dict[str, str] | None = None, untagged_color: str = '#dddddd', value: str = 'count', hide_root: bool = True, title: str | None = None, layout_kwargs: dict[str, Any] | None = None) -> Figure

Render a sunburst whose sectors are coloured by a node-metadata tag.

PARAMETER DESCRIPTION
graph

ConceptGraph; tag is read from graph.view(node).metadata[tag_key].

TYPE: ConceptGraph

df

Optional DataFrame providing the feature_count column (or any value column). Defaults to a count-based sunburst.

TYPE: DataFrame | None DEFAULT: None

tag_key

Metadata key carrying the categorical tag.

TYPE: str DEFAULT: 'tag'

palette

Optional tag -> css_color mapping. Unmapped tags get colours from a default palette.

TYPE: dict[str, str] | None DEFAULT: None

untagged_color

Colour for nodes that carry no value under tag_key.

TYPE: str DEFAULT: '#dddddd'

hide_root

When True (default) the root concept is omitted; pass False to keep the legacy root sector.

TYPE: bool DEFAULT: True

Source code in src/concept_graph_xai/plotting/regulatory_tag_overlay.py
def regulatory_tag_overlay(
    graph: ConceptGraph,
    df: pd.DataFrame | None = None,
    *,
    tag_key: str = "tag",
    palette: dict[str, str] | None = None,
    untagged_color: str = "#dddddd",
    value: str = "count",
    hide_root: bool = True,
    title: str | None = None,
    layout_kwargs: dict[str, Any] | None = None,
) -> go.Figure:
    """Render a sunburst whose sectors are coloured by a node-metadata tag.

    Parameters
    ----------
    graph:
        ConceptGraph to render. Each node's tag is looked up as
        ``graph.view(node).metadata.get(tag_key)``.
    df:
        Optional metric DataFrame providing the ``value`` column used for
        sector size. When omitted, ``feature_counts(graph)`` is used.
    tag_key:
        Metadata key holding the categorical tag.
    palette:
        Optional ``tag -> css_color`` mapping. Tags it does not cover get
        colours from ``_DEFAULT_PALETTE`` in order of first appearance,
        cycling when there are more tags than colours.
    untagged_color:
        Colour for nodes whose ``tag_key`` metadata is missing, ``None`` or
        an empty string.
    value:
        Column used for sector size; also shown in the tooltip.
    hide_root:
        When ``True`` (default) the root concept is omitted; pass ``False``
        to keep the legacy root sector.
    title:
        Figure title. Defaults to ``"Concepts coloured by '<tag_key>'"``.
    layout_kwargs:
        Forwarded to ``build_sunburst_figure``.

    Returns
    -------
    go.Figure
        The figure returned by ``build_sunburst_figure``.
    """
    if df is None:
        from concept_graph_xai.metrics.counts import feature_counts

        df = feature_counts(graph)

    arrays, ordered, sizes = sunburst_layout(graph, df, value=value, hide_root=hide_root)

    # NOTE(review): tags come from graph.nodes_in_order(), not from
    # ``ordered``. This assumes sunburst_layout emits its rows in that same
    # node order with the same root handling. Confirm; a length mismatch
    # would surface in ``ordered.assign`` below.
    shown_nodes = (
        node
        for node in graph.nodes_in_order()
        if not (hide_root and node == graph.root)
    )
    raw_tags = (graph.view(node).metadata.get(tag_key) for node in shown_nodes)
    tags: list[str] = ["" if raw is None else str(raw) for raw in raw_tags]

    # Caller-provided colours win. Unseen tags get default colours in order
    # of first appearance. The empty tag ("untagged") is never auto-assigned.
    colour_of: dict[str, str] = dict(palette) if palette else {}
    auto_count = 0
    for tag in tags:
        if not tag or tag in colour_of:
            continue
        colour_of[tag] = _DEFAULT_PALETTE[auto_count % len(_DEFAULT_PALETTE)]
        auto_count += 1

    sector_colours = [colour_of.get(tag, untagged_color) for tag in tags]
    tooltip = hover_text(ordered.assign(tag=tags), [value, "tag"])

    return build_sunburst_figure(
        arrays,
        sizes,
        marker={"colors": sector_colours},
        hover=tooltip,
        title=title or f"Concepts coloured by {tag_key!r}",
        layout_kwargs=layout_kwargs,
    )