Source code for utils.rank_fragility.panels

"""Counterfactual evaluation-panel sampling utilities."""

from __future__ import annotations

import logging
import warnings
from typing import Iterable

import numpy as np
import pandas as pd

from .config import PanelConfig

LOG = logging.getLogger(__name__)


def _warn(message: str, collected: list[str]) -> None:
    """Report ``message`` through every warning channel this module uses.

    The message is logged at WARNING level on the module logger, appended to
    ``collected`` (mutated in place), and finally emitted as a
    ``RuntimeWarning``.

    Args:
        message: Text of the warning.
        collected: List that accumulates the warning messages.
    """
    LOG.warning(message)
    collected.append(message)
    # stacklevel=3 attributes the warning to the caller of the function that
    # invoked this helper, not to the helper or that function itself.
    warnings.warn(message, category=RuntimeWarning, stacklevel=3)


def _rate_label(rate: float | str) -> str:
    """Return a short text label for ``rate``.

    Strings are returned untouched; anything else is converted to ``float``
    and rendered with at most six significant digits.
    """
    return rate if isinstance(rate, str) else f"{float(rate):.6g}"


def _parse_rate(rate: float | str, observed: float) -> float:
    """Resolve a configured rate to a number.

    The string ``"observed"`` (any letter case) resolves to ``observed``;
    every other value is converted with ``float``.

    Raises:
        ValueError: If ``rate`` is a string that ``float`` cannot parse.
    """
    wants_observed = isinstance(rate, str) and rate.lower() == "observed"
    return float(observed if wants_observed else rate)


def _numeric_rate_for_output(rate: float | str) -> float | str:
    """Normalise a configured rate for storage in the output rows.

    The string ``"observed"`` (any letter case) becomes the lowercase literal
    ``"observed"``; every other value is converted with ``float``.
    """
    if not isinstance(rate, str):
        return float(rate)
    if rate.lower() != "observed":
        return float(rate)
    return "observed"


def _strata(df: pd.DataFrame, task: str, label_col: str) -> pd.Series:
    """Assign every row of ``df`` to a stratum label used for stratified sampling.

    Args:
        df: Rows to stratify.
        task: ``"classification"`` stratifies on the label value itself; any
            other value bins the numeric label into quantiles.
        label_col: Name of the label column.

    Returns:
        A Series of stratum labels. It is empty when ``df`` is empty, and is
        the single stratum ``"all"`` when ``label_col`` is absent or the
        quantile binning fails.
    """
    if df.empty:
        return pd.Series(dtype=object)
    if label_col not in df.columns:
        return pd.Series(["all"] * len(df), index=df.index, dtype=object)
    if task == "classification":
        labels = df[label_col]
        # Decide which rows are missing from the *original* values. Casting to
        # str first can turn a missing label into an ordinary string such as
        # "nan" or "None", after which a fillna finds nothing left to fill.
        return labels.astype(str).where(labels.notna(), "missing")
    y = pd.to_numeric(df[label_col], errors="coerce")
    try:
        # Bin on ranks (ties broken by order of appearance) so repeated label
        # values do not collapse bin edges; use at most five bins and never
        # more bins than there are non-missing labels.
        return pd.qcut(y.rank(method="first"), q=min(5, max(1, y.notna().sum())), duplicates="drop").astype(str)
    except Exception:
        # Deliberate best effort: if binning is impossible, sample unstratified.
        return pd.Series(["all"] * len(df), index=df.index, dtype=object)


def _allocate_counts(pool: pd.DataFrame, n: int, task: str, label_col: str) -> dict[object, int]:
    """Split a request for ``n`` rows across the strata of ``pool``.

    Each stratum first receives the floor of its proportional share; the rows
    still unassigned go, one each, to the strata with the largest fractional
    parts. No stratum is asked for more rows than it holds, and any shortfall
    that capping causes is pushed to the strata with the most spare rows.

    Args:
        pool: Candidate rows.
        n: Total number of rows wanted.
        task: Passed through to ``_strata``.
        label_col: Passed through to ``_strata``.

    Returns:
        Mapping of stratum label to the number of rows to draw from it;
        strata with a zero allocation are omitted. Empty when ``n <= 0``.
    """
    if n <= 0:
        return {}

    available = _strata(pool, task, label_col).value_counts()
    ideal = available / available.sum() * n
    granted = np.floor(ideal).astype(int)

    # Largest-remainder rounding: hand the leftover rows to the strata whose
    # ideal share was truncated the most.
    leftover = int(n - granted.sum())
    if leftover > 0:
        by_fraction = (ideal - granted).sort_values(ascending=False).index.tolist()
        for stratum in by_fraction[:leftover]:
            granted.loc[stratum] += 1

    table = pd.concat([granted, available], axis=1)
    table.columns = ["requested", "available"]
    # Never request more rows from a stratum than it actually contains.
    table["requested"] = table[["requested", "available"]].min(axis=1)

    shortfall = int(n - table["requested"].sum())
    if shortfall > 0:
        headroom = (table["available"] - table["requested"]).sort_values(ascending=False)
        for stratum, room in headroom.items():
            if shortfall <= 0:
                break
            extra = min(int(room), shortfall)
            table.loc[stratum, "requested"] += extra
            shortfall -= extra

    return {stratum: int(count) for stratum, count in table["requested"].items() if int(count) > 0}


def _sample_pool(pool: pd.DataFrame, n: int, rng: np.random.Generator, task: str, label_col: str) -> pd.DataFrame:
    """Draw ``n`` distinct rows from ``pool``, stratified by label, in shuffled order.

    Args:
        pool: Candidate rows.
        n: Number of rows to draw.
        rng: Generator that supplies a fresh integer seed for every pandas
            ``sample`` call, so results are reproducible for a given state.
        task: Passed through to ``_strata`` / ``_allocate_counts``.
        label_col: Passed through to ``_strata`` / ``_allocate_counts``.

    Returns:
        The sampled rows; an empty frame with the pool's columns when ``n <= 0``.

    Raises:
        ValueError: If ``pool`` holds fewer than ``n`` rows.
    """

    def fresh_seed() -> int:
        return int(rng.integers(0, 2**31 - 1))

    if n <= 0:
        return pool.iloc[0:0].copy()
    if len(pool) < n:
        raise ValueError(f"cannot sample {n} unique molecules from pool of size {len(pool)}")
    if n == len(pool):
        # Everything is taken, so stratification is moot: just shuffle.
        return pool.sample(frac=1.0, random_state=fresh_seed()).copy()

    tagged = pool.assign(_stratum=_strata(pool, task, label_col))
    quotas = _allocate_counts(tagged, n, task, label_col)

    drawn = []
    for stratum, quota in quotas.items():
        members = tagged[tagged["_stratum"] == stratum]
        take = min(quota, len(members))
        if not take:
            continue
        drawn.append(members.sample(n=take, replace=False, random_state=fresh_seed()))
    chosen = pd.concat(drawn, axis=0) if drawn else tagged.iloc[0:0].copy()

    missing = n - len(chosen)
    if missing > 0:
        # Top up from rows (matched by index label) that no stratum draw picked.
        leftovers = tagged.loc[~tagged.index.isin(chosen.index)]
        top_up = leftovers.sample(n=missing, replace=False, random_state=fresh_seed())
        chosen = pd.concat([chosen, top_up], axis=0)

    # Drop the helper column, then shuffle so rows are not grouped by stratum.
    without_tag = chosen.drop(columns=["_stratum"], errors="ignore")
    return without_tag.sample(frac=1.0, random_state=fresh_seed())


def _max_feasible_panel_size(pos_n: int, neg_n: int, rates: Iterable[float], total_cap: int) -> int:
    """Return the largest panel size that every valid target rate can be built at.

    A size ``n`` is feasible for a rate when ``round(n * rate)`` positives fit
    within ``pos_n`` and the remaining ``n - round(n * rate)`` rows fit within
    ``neg_n``.

    Args:
        pos_n: Number of available positive rows.
        neg_n: Number of available negative (background) rows.
        rates: Target positive rates. Values outside ``[0, 1]``, including NaN,
            are ignored: they cannot describe a panel and must not shrink, or
            crash, the size chosen for the valid rates.
        total_cap: Largest size to consider.

    Returns:
        The largest feasible size in ``1..total_cap``, or ``0`` if none exists.
    """
    # Materialise once: the rates are re-scanned for every candidate size, so a
    # one-shot iterator would be exhausted after the first candidate.
    usable = [rate for rate in rates if 0 <= rate <= 1]
    for n in range(int(total_cap), 0, -1):
        for rate in usable:
            k = int(round(n * rate))
            if k > pos_n or (n - k) > neg_n:
                break
        else:
            # No rate broke out of the loop: this size works for all of them,
            # and sizes are tried from largest to smallest.
            return n
    return 0


def _fixed_or_auto_size(config: PanelConfig, default: int) -> int:
    """Return the panel size to use, as an ``int``.

    ``default`` is used when ``config.panel_size`` is the string ``"auto"``
    (any letter case); otherwise ``config.panel_size`` itself is used.
    """
    requested = config.panel_size
    wants_auto = isinstance(requested, str) and requested.lower() == "auto"
    return int(default if wants_auto else requested)


def _panel_rows(panel_id: str, mode: str, target_rate, observed_rate: float, sample: pd.DataFrame, matched_mode=None):
    """Build one long-form manifest row per molecule in ``sample``.

    Args:
        panel_id: Identifier repeated on every row of the panel.
        mode: Sampling mode that produced the panel.
        target_rate: Requested rate, stored as given.
        observed_rate: Realised rate, stored as a ``float``.
        sample: Panel members; must contain a ``molecule_id`` column.
        matched_mode: Stored as given; ``None`` by default.

    Returns:
        A list of dicts, one per row of ``sample``, in sample order.
    """
    molecule_ids = sample["molecule_id"].astype(str).tolist()
    return [
        {
            "panel_id": panel_id,
            "mode": mode,
            "target_rate": target_rate,
            "observed_rate": float(observed_rate),
            "molecule_id": molecule_id,
            "matched_mode": matched_mode,
        }
        for molecule_id in molecule_ids
    ]


def _prepare_test_df(audited_test_df: pd.DataFrame, config: PanelConfig) -> pd.DataFrame:
    """Validate the audited test frame and give it a string ``molecule_id`` key.

    Args:
        audited_test_df: Test rows; must contain ``config.id_col`` and the
            columns ``audit_group``, ``exact_train_test_leak``,
            ``near_train_analogue`` and ``label_conflict``.
        config: Supplies ``id_col``.

    Returns:
        A copy with a ``molecule_id`` column (``id_col`` cast to ``str``), one
        row per ``molecule_id`` (first occurrence kept) and a fresh 0..n-1 index.

    Raises:
        ValueError: If the id column or any audit column is missing.
    """
    present = audited_test_df.columns
    if config.id_col not in present:
        raise ValueError(f"audited_test_df is missing id column {config.id_col!r}")

    audit_columns = {"audit_group", "exact_train_test_leak", "near_train_analogue", "label_conflict"}
    absent = sorted(audit_columns.difference(present))
    if absent:
        raise ValueError(f"audited_test_df is missing audit column(s): {', '.join(absent)}")

    # Work on a copy so the caller's frame is left untouched.
    prepared = audited_test_df.copy()
    prepared["molecule_id"] = prepared[config.id_col].astype(str)
    deduplicated = prepared.drop_duplicates("molecule_id", keep="first")
    return deduplicated.reset_index(drop=True)


def _sample_with_fallback(
    primary: pd.DataFrame,
    fallback: pd.DataFrame,
    n: int,
    rng: np.random.Generator,
    task: str,
    label_col: str,
) -> pd.DataFrame:
    """Draw ``n`` rows from ``primary``, topping up from ``fallback`` if it is too small.

    When ``primary`` alone can satisfy the request it is the only pool
    sampled. Otherwise all of ``primary`` is taken and the rest comes from
    ``fallback`` rows whose ``molecule_id`` is not already selected.

    Raises:
        ValueError: If the two pools together hold fewer than ``n`` rows.
    """
    n_primary = len(primary)
    if n <= n_primary:
        return _sample_pool(primary, n, rng, task, label_col)
    if n_primary + len(fallback) < n:
        raise ValueError("not enough molecules in primary + fallback pools")

    if n_primary:
        head = _sample_pool(primary, n_primary, rng, task, label_col)
    else:
        # Nothing to sample: use the empty primary frame as it is.
        head = primary

    already_taken = fallback["molecule_id"].isin(head["molecule_id"])
    unused = fallback.loc[~already_taken]
    tail = _sample_pool(unused, n - len(head), rng, task, label_col)
    return pd.concat([head, tail], axis=0)


def generate_counterfactual_panels(audited_test_df: pd.DataFrame, config: PanelConfig) -> pd.DataFrame:
    """Generate a long-form counterfactual panel manifest.

    Panels are produced, ``config.n_panels`` at a time, for five modes:

    * ``clean_reference`` - only ``audit_clean`` molecules.
    * ``observed_composition`` - audit groups in their observed proportions.
    * ``leakage_curve`` - for each target rate, that fraction of molecules
      flagged ``exact_train_test_leak`` or ``near_train_analogue``, drawn from
      the molecules without a label conflict.
    * ``conflict_curve`` - for each target rate, that fraction of molecules
      flagged ``label_conflict``.
    * ``random_matched_control`` - stratified random panels from the whole test
      set, one batch per feasible curve rate, at the matching curve's size.

    Infeasible modes and rates are skipped with a warning, not an error.

    Args:
        audited_test_df: Audited test molecules; validated by ``_prepare_test_df``.
        config: Panel settings. This function reads ``random_seed``,
            ``n_panels``, ``task``, ``label_col`` and ``target_rates``; the
            helpers it calls also read ``id_col`` and ``panel_size``.

    Returns:
        A DataFrame with one row per (panel, molecule) and the columns
        ``panel_id``, ``mode``, ``target_rate``, ``observed_rate``,
        ``molecule_id`` and ``matched_mode``. The warning messages collected
        during generation are stored in ``attrs["warnings"]``.

    Raises:
        ValueError: If ``audited_test_df`` lacks the id or audit columns.
    """
    test = _prepare_test_df(audited_test_df, config)
    warnings_list: list[str] = []
    rng = np.random.default_rng(config.random_seed)
    rows: list[dict] = []

    # --- clean_reference -------------------------------------------------
    clean = test[test["audit_group"] == "audit_clean"].copy()
    n_clean = _fixed_or_auto_size(config, len(clean))
    if n_clean > 0 and len(clean) >= n_clean:
        for i in range(config.n_panels):
            sample = _sample_pool(clean, n_clean, rng, config.task, config.label_col)
            rows.extend(_panel_rows(f"clean_reference_{i:04d}", "clean_reference", 0.0, 0.0, sample))
    elif len(clean) == 0:
        _warn("clean_reference is infeasible because there are no audit_clean test molecules", warnings_list)
    else:
        # Clean molecules exist, so the requested panel size is the real cause.
        _warn("clean_reference is infeasible for the requested panel_size", warnings_list)

    # --- observed_composition --------------------------------------------
    observed_size = _fixed_or_auto_size(config, len(test))
    if observed_size <= len(test) and observed_size > 0:
        observed_nonclean = float((test["audit_group"] != "audit_clean").mean()) if len(test) else np.nan
        groups = [g for _, g in test.groupby("audit_group", sort=True)]
        # Group sizes never change between panels, so compute them once.
        group_sizes = np.asarray([len(g) for g in groups])
        group_probs = group_sizes.astype(float)
        group_probs = group_probs / group_probs.sum()
        for i in range(config.n_panels):
            desired = np.floor(group_probs * observed_size).astype(int)
            # Hand out the rows lost to flooring at random, weighted by group share.
            while desired.sum() < observed_size:
                desired[int(rng.choice(np.arange(len(desired)), p=group_probs))] += 1
            desired = np.minimum(desired, group_sizes)
            deficit = observed_size - int(desired.sum())
            if deficit > 0:
                # Capping at the group sizes left the panel short: sample the
                # whole test set directly instead.
                all_sample = _sample_pool(test, observed_size, rng, config.task, config.label_col)
            else:
                parts = [
                    _sample_pool(group, int(k), rng, config.task, config.label_col)
                    for group, k in zip(groups, desired)
                    if int(k) > 0
                ]
                all_sample = pd.concat(parts, axis=0) if parts else test.iloc[0:0]
            rows.extend(
                _panel_rows(
                    f"observed_composition_{i:04d}",
                    "observed_composition",
                    "observed",
                    observed_nonclean,
                    all_sample,
                )
            )
    else:
        _warn("observed_composition is infeasible for the requested panel_size", warnings_list)

    # --- leakage_curve ---------------------------------------------------
    leakage_base = test[~test["label_conflict"].astype(bool)].copy()
    shortcut = leakage_base[
        leakage_base["exact_train_test_leak"].astype(bool) | leakage_base["near_train_analogue"].astype(bool)
    ].copy()
    # Keep the background pool disjoint from the shortcut pool. Otherwise a
    # molecule could be drawn both as a positive and as background, which
    # duplicates it in the panel and pushes the shortcut rate above target.
    leakage_clean = leakage_base[
        (leakage_base["audit_group"] == "audit_clean")
        & (~leakage_base["molecule_id"].isin(shortcut["molecule_id"]))
    ].copy()
    leakage_fallback = leakage_base[
        (~leakage_base["molecule_id"].isin(shortcut["molecule_id"]))
        & (~leakage_base["molecule_id"].isin(leakage_clean["molecule_id"]))
    ].copy()
    n_leakage_background = len(leakage_clean) + len(leakage_fallback)
    observed_shortcut_rate = float(len(shortcut) / len(leakage_base)) if len(leakage_base) else 0.0
    requested_leakage_rates = [_parse_rate(r, observed_shortcut_rate) for r in config.target_rates]
    leakage_size_auto = _max_feasible_panel_size(
        len(shortcut), n_leakage_background, requested_leakage_rates, len(leakage_base)
    )
    leakage_size = _fixed_or_auto_size(config, leakage_size_auto)
    feasible_leakage_rates: list[tuple[float | str, float]] = []
    for raw_rate in config.target_rates:
        rate = _parse_rate(raw_rate, observed_shortcut_rate)
        if not 0 <= rate <= 1:
            _warn(f"skipping leakage_curve target_rate={raw_rate}: rate must be between 0 and 1", warnings_list)
            continue
        k = int(round(leakage_size * rate))
        if leakage_size <= 0 or k > len(shortcut) or leakage_size - k > n_leakage_background:
            _warn(f"skipping leakage_curve target_rate={raw_rate}: infeasible with available pools", warnings_list)
            continue
        output_rate = _numeric_rate_for_output(raw_rate)
        feasible_leakage_rates.append((output_rate, rate))
        for i in range(config.n_panels):
            pos = _sample_pool(shortcut, k, rng, config.task, config.label_col)
            remaining = _sample_with_fallback(
                leakage_clean, leakage_fallback, leakage_size - k, rng, config.task, config.label_col
            )
            sample = pd.concat([pos, remaining], axis=0)
            observed = float(
                (
                    sample["exact_train_test_leak"].astype(bool) | sample["near_train_analogue"].astype(bool)
                ).mean()
            )
            rows.extend(
                _panel_rows(
                    f"leakage_curve_{_rate_label(raw_rate)}_{i:04d}",
                    "leakage_curve",
                    output_rate,
                    observed,
                    sample,
                )
            )

    # --- conflict_curve --------------------------------------------------
    conflict = test[test["label_conflict"].astype(bool)].copy()
    conflict_bg_primary = test[
        (test["audit_group"] == "audit_clean") & (~test["exact_train_test_leak"].astype(bool))
    ].copy()
    conflict_bg_fallback = test[
        (~test["label_conflict"].astype(bool))
        & (~test["molecule_id"].isin(conflict_bg_primary["molecule_id"]))
        & (~test["exact_train_test_leak"].astype(bool))
    ].copy()
    if len(conflict_bg_primary) + len(conflict_bg_fallback) == 0:
        # Last resort: accept any molecule without a label conflict as background.
        conflict_bg_fallback = test[~test["label_conflict"].astype(bool)].copy()
    n_conflict_background = len(conflict_bg_primary) + len(conflict_bg_fallback)
    conflict_universe = pd.concat([conflict, conflict_bg_primary, conflict_bg_fallback], axis=0).drop_duplicates(
        "molecule_id"
    )
    observed_conflict_rate = float(len(conflict) / len(test)) if len(test) else 0.0
    requested_conflict_rates = [_parse_rate(r, observed_conflict_rate) for r in config.target_rates]
    conflict_size_auto = _max_feasible_panel_size(
        len(conflict), n_conflict_background, requested_conflict_rates, len(conflict_universe)
    )
    conflict_size = _fixed_or_auto_size(config, conflict_size_auto)
    feasible_conflict_rates: list[tuple[float | str, float]] = []
    for raw_rate in config.target_rates:
        rate = _parse_rate(raw_rate, observed_conflict_rate)
        if not 0 <= rate <= 1:
            _warn(f"skipping conflict_curve target_rate={raw_rate}: rate must be between 0 and 1", warnings_list)
            continue
        k = int(round(conflict_size * rate))
        if conflict_size <= 0 or k > len(conflict) or conflict_size - k > n_conflict_background:
            _warn(f"skipping conflict_curve target_rate={raw_rate}: infeasible with available pools", warnings_list)
            continue
        output_rate = _numeric_rate_for_output(raw_rate)
        feasible_conflict_rates.append((output_rate, rate))
        for i in range(config.n_panels):
            pos = _sample_pool(conflict, k, rng, config.task, config.label_col)
            remaining = _sample_with_fallback(
                conflict_bg_primary, conflict_bg_fallback, conflict_size - k, rng, config.task, config.label_col
            )
            sample = pd.concat([pos, remaining], axis=0)
            observed = float(sample["label_conflict"].astype(bool).mean())
            rows.extend(
                _panel_rows(
                    f"conflict_curve_{_rate_label(raw_rate)}_{i:04d}",
                    "conflict_curve",
                    output_rate,
                    observed,
                    sample,
                )
            )

    # --- random_matched_control ------------------------------------------
    for matched_mode, feasible_rates, size in [
        ("leakage_curve", feasible_leakage_rates, leakage_size),
        ("conflict_curve", feasible_conflict_rates, conflict_size),
    ]:
        if size <= 0 or len(test) < size:
            continue
        for raw_rate, _ in feasible_rates:
            for i in range(config.n_panels):
                sample = _sample_pool(test, size, rng, config.task, config.label_col)
                observed = float((sample["audit_group"] != "audit_clean").mean())
                rows.extend(
                    _panel_rows(
                        f"random_matched_control_{matched_mode}_{_rate_label(raw_rate)}_{i:04d}",
                        "random_matched_control",
                        raw_rate,
                        observed,
                        sample,
                        matched_mode=matched_mode,
                    )
                )

    manifest = pd.DataFrame(
        rows,
        columns=["panel_id", "mode", "target_rate", "observed_rate", "molecule_id", "matched_mode"],
    )
    if manifest.empty:
        _warn("no counterfactual panels were generated", warnings_list)
    manifest.attrs["warnings"] = warnings_list
    return manifest