Source code for utils.rank_fragility.attribution

"""Advantage decomposition helpers for audited prediction tables."""

from __future__ import annotations

import numpy as np
import pandas as pd

from .metrics import per_sample_loss


def _similarity_bins(values: pd.Series) -> pd.Series:
    """Map similarity values onto fixed, labelled ranges.

    Values are coerced to numbers first; anything that cannot be parsed, is
    missing, or falls outside every bin is reported as ``"unknown"``.

    Args:
        values: Similarity scores (numeric or numeric-like).

    Returns:
        A string Series holding one bin label (or ``"unknown"``) per input row.
    """
    # Bins are left-closed, so the top edge sits slightly above 1.0 to keep an
    # exact 1.0 inside the final "[0.9, 1.0]" bucket.
    edges = [0.0, 0.4, 0.6, 0.8, 0.9, 1.0000001]
    names = ["[0, 0.4)", "[0.4, 0.6)", "[0.6, 0.8)", "[0.8, 0.9)", "[0.9, 1.0]"]

    numeric = pd.to_numeric(values, errors="coerce")
    binned = pd.cut(numeric, bins=edges, labels=names, right=False, include_lowest=True)

    unbinned = binned.isna()
    return binned.astype(object).mask(unbinned, "unknown").astype(str)


def _aggregate(per_example: pd.DataFrame, group_type: str, group_values: pd.Series, total_positive: float) -> list[dict]:
    """Summarise the ``advantage`` column of ``per_example`` per group label.

    Args:
        per_example: Table containing an ``advantage`` column.
        group_type: Name of the grouping scheme, copied into every output row.
        group_values: One label per row of ``per_example``; cast to ``str``
            before grouping.
        total_positive: Denominator for the positive-advantage share.

    Returns:
        One dict per distinct label with the group size, summed and mean
        advantage, and the group's clipped-at-zero advantage divided by
        ``total_positive`` (NaN when ``total_positive`` is not above zero).
    """
    has_positive_mass = total_positive > 0
    keyed = per_example.assign(_group=group_values.astype(str))

    summaries: list[dict] = []
    for label, members in keyed.groupby("_group", dropna=False):
        advantage = members["advantage"]
        size = len(members)

        if has_positive_mass:
            # Only gains count towards the share; losses are clipped to zero.
            share = float(advantage.clip(lower=0).sum() / total_positive)
        else:
            share = np.nan

        summaries.append(
            {
                "group_type": group_type,
                "group": str(label),
                "n": int(size),
                "total_advantage": float(advantage.sum()),
                "mean_advantage": float(advantage.mean()) if size else np.nan,
                "fraction_of_total_positive_advantage": share,
            }
        )
    return summaries


def compute_advantage_decomposition(
    pred_audit_df: pd.DataFrame,
    sota_model: str,
    baseline_model: str,
    task: str,
    loss: str,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Compute per-example SOTA advantage and aggregate by chemistry audit strata.

    The advantage of an example is ``loss_baseline - loss_sota``, so positive
    values mean the SOTA model had the lower loss on that molecule.

    Args:
        pred_audit_df: Long-format predictions with the columns ``model``,
            ``molecule_id``, ``y_true``, ``y_pred``, ``audit_group`` and
            ``max_train_tanimoto``.
        sota_model: Value of ``model`` identifying the SOTA predictions.
        baseline_model: Value of ``model`` identifying the baseline predictions.
        task: Forwarded unchanged to ``per_sample_loss``.
        loss: Forwarded unchanged to ``per_sample_loss``.

    Returns:
        A ``(per_example, decomposition)`` pair. ``per_example`` has one row
        per merged molecule with both predictions, both losses and the
        advantage. ``decomposition`` has one row per group for three
        groupings: ``audit_group``, ``clean_vs_nonclean`` and
        ``max_train_tanimoto_bin``.

    Raises:
        ValueError: If a required column is missing, or if either model has no
            rows in ``pred_audit_df``.
    """
    required = {"model", "molecule_id", "y_true", "y_pred", "audit_group", "max_train_tanimoto"}
    missing = sorted(required - set(pred_audit_df.columns))
    if missing:
        raise ValueError(f"pred_audit_df is missing required column(s): {', '.join(missing)}")

    sota = pred_audit_df[pred_audit_df["model"] == sota_model]
    baseline = pred_audit_df[pred_audit_df["model"] == baseline_model]
    if sota.empty:
        raise ValueError(f"sota_model not found in predictions: {sota_model}")
    if baseline.empty:
        raise ValueError(f"baseline_model not found in predictions: {baseline_model}")

    # Targets and audit metadata are taken from the SOTA rows; the baseline
    # contributes only its prediction. The inner join keeps molecules that
    # both models predicted.
    # NOTE(review): repeated molecule_id values within a model would multiply
    # rows in this merge -- confirm upstream guarantees uniqueness.
    merged = (
        sota[["molecule_id", "y_true", "y_pred", "audit_group", "max_train_tanimoto"]]
        .rename(columns={"y_pred": "sota_pred"})
        .merge(
            baseline[["molecule_id", "y_pred"]].rename(columns={"y_pred": "baseline_pred"}),
            on="molecule_id",
            how="inner",
        )
    )

    merged["loss_sota"] = per_sample_loss(merged["y_true"], merged["sota_pred"], task=task, loss=loss)
    merged["loss_baseline"] = per_sample_loss(merged["y_true"], merged["baseline_pred"], task=task, loss=loss)
    merged["advantage"] = merged["loss_baseline"] - merged["loss_sota"]

    per_example = merged[
        [
            "molecule_id",
            "audit_group",
            "y_true",
            "sota_pred",
            "baseline_pred",
            "loss_sota",
            "loss_baseline",
            "advantage",
            "max_train_tanimoto",
        ]
    ].copy()

    # Shared denominator: every grouping reports its share of the same total gain.
    total_positive = float(per_example["advantage"].clip(lower=0).sum())

    # Wrap the ndarray from np.where in an index-aligned Series so all three
    # groupings hand _aggregate the type its signature declares.
    clean_labels = pd.Series(
        np.where(per_example["audit_group"] == "audit_clean", "audit_clean", "nonclean"),
        index=per_example.index,
    )

    rows: list[dict] = []
    rows.extend(_aggregate(per_example, "audit_group", per_example["audit_group"], total_positive))
    rows.extend(_aggregate(per_example, "clean_vs_nonclean", clean_labels, total_positive))
    rows.extend(
        _aggregate(
            per_example,
            "max_train_tanimoto_bin",
            _similarity_bins(per_example["max_train_tanimoto"]),
            total_positive,
        )
    )

    # Explicit columns keep the schema stable even when no rows were produced.
    decomposition = pd.DataFrame(
        rows,
        columns=[
            "group_type",
            "group",
            "n",
            "total_advantage",
            "mean_advantage",
            "fraction_of_total_positive_advantage",
        ],
    )
    return per_example, decomposition