Skip to content

Ablation

auc_drop

auc_drop(graph: ConceptGraph, model: Any, X: DataFrame | ndarray, y: ndarray | Series, feature_names: Sequence[str] | None = None, *, strategy: Strategy = 'permutation', metric: ScorerLike = 'roc_auc', n_repeats: int = 10, random_state: int | None = 42, train_fn: Callable[[DataFrame | ndarray, ndarray], Any] | None = None, X_train: DataFrame | ndarray | None = None, y_train: ndarray | Series | None = None, shap_values: ndarray | None = None, base_predictions: ndarray | None = None, skip_root: bool = True) -> DataFrame

Compute concept-level metric drop under ablation.

Source code in src/concept_graph_xai/metrics/ablation.py
def auc_drop(
    graph: ConceptGraph,
    model: Any,
    X: pd.DataFrame | np.ndarray,
    y: np.ndarray | pd.Series,
    feature_names: Sequence[str] | None = None,
    *,
    strategy: Strategy = "permutation",
    metric: ScorerLike = "roc_auc",
    n_repeats: int = 10,
    random_state: int | None = 42,
    train_fn: Callable[[pd.DataFrame | np.ndarray, np.ndarray], Any] | None = None,
    X_train: pd.DataFrame | np.ndarray | None = None,
    y_train: np.ndarray | pd.Series | None = None,
    shap_values: np.ndarray | None = None,
    base_predictions: np.ndarray | None = None,
    skip_root: bool = True,
) -> pd.DataFrame:
    """Compute concept-level metric drop under ablation.

    Dispatches to one ablation back-end selected by ``strategy`` and tags
    the resulting frame with a ``"strategy"`` column.

    Args:
        graph: Concept graph; forwarded to every back-end and, when
            ``feature_names`` is not given, the source of feature names via
            ``graph.features()``.
        model: Model forwarded to the ``"permutation"`` back-end (not used
            by the other strategies here).
        X: Evaluation features.
        y: Evaluation targets; converted with ``np.asarray``.
        feature_names: Feature names to consider. When ``None`` or empty,
            ``list(graph.features())`` is used. Any sized sequence is
            accepted, including a ``pandas.Index`` or a ``numpy`` array.
        strategy: One of ``"permutation"``, ``"shap_marginal"`` or
            ``"retrain"``.
        metric: Scorer specification resolved with ``_resolve_scorer``.
        n_repeats: Forwarded to the ``"permutation"`` back-end only.
        random_state: Forwarded to the ``"permutation"`` back-end only.
        train_fn: Forwarded to the ``"retrain"`` back-end only.
        X_train: Forwarded to the ``"retrain"`` back-end only.
        y_train: Forwarded to the ``"retrain"`` back-end only.
        shap_values: Forwarded to the ``"shap_marginal"`` back-end only.
        base_predictions: Forwarded to the ``"shap_marginal"`` back-end only.
        skip_root: Forwarded to every back-end.

    Returns:
        The back-end's DataFrame with a ``"strategy"`` column set to
        ``strategy``.

    Raises:
        ValueError: If ``strategy`` is not a recognized strategy.
    """
    # Reject an unknown strategy before doing any feature/scorer resolution.
    if strategy not in ("shap_marginal", "permutation", "retrain"):
        raise ValueError(f"unknown strategy {strategy!r}")

    # ``feature_names or ...`` raises on array-likes such as ``pd.Index`` or
    # ``np.ndarray`` (ambiguous truth value). Test emptiness explicitly while
    # keeping the original fallback for ``None`` and empty sequences.
    if feature_names is None or len(feature_names) == 0:
        names: Sequence[str] = list(graph.features())
    else:
        names = feature_names

    feats_in_X = _features_in_x(X, names)
    y_arr = np.asarray(y)
    score = _resolve_scorer(metric)

    if strategy == "shap_marginal":
        df = _shap_marginal(graph, X, y_arr, score, shap_values, base_predictions, skip_root)
    elif strategy == "permutation":
        df = _permutation(
            graph, model, X, y_arr, feats_in_X, score, n_repeats, random_state, skip_root
        )
    else:  # strategy == "retrain" (validated above)
        df = _retrain(
            graph,
            X,
            y_arr,
            feats_in_X,
            score,
            train_fn,
            X_train,
            y_train,
            skip_root,
        )

    df["strategy"] = strategy
    return df