# Source code for utils.rank_fragility.audit

"""Molecular audit annotations for rank-fragility analysis."""

from __future__ import annotations

import logging
from typing import Any

import numpy as np
import pandas as pd

from .chem import max_tanimoto_to_train, morgan_fingerprint, murcko_scaffold_smiles, standardize_smiles
from .config import AuditConfig

LOG = logging.getLogger(__name__)


def _require_columns(df: pd.DataFrame, columns: list[str]) -> None:
    missing = [col for col in columns if col not in df.columns]
    if missing:
        raise ValueError(f"missing required column(s): {', '.join(missing)}")


def _label_key(value: Any) -> Any:
    if value is None:
        return None
    try:
        if pd.isna(value):
            return None
    except TypeError:
        pass
    if isinstance(value, str):
        text = value.strip()
        try:
            f = float(text)
            return int(f) if f.is_integer() else f
        except ValueError:
            return text
    if isinstance(value, np.generic):
        value = value.item()
    return value


def _classification_conflict(values: pd.Series) -> tuple[bool, float]:
    """Report whether *values* hold more than one distinct class label.

    Labels are normalized with ``_label_key`` and missing labels are ignored.
    Returns ``(conflict, severity)`` where ``severity`` is the number of
    distinct labels beyond the first (``0.0`` when there is at most one).
    """
    distinct = set()
    for raw in values.tolist():
        key = _label_key(raw)
        if key is not None:
            distinct.add(key)
    surplus = max(len(distinct) - 1, 0)
    return surplus > 0, float(surplus)


def _regression_conflict(values: pd.Series, threshold: float) -> tuple[bool, float]:
    numeric = pd.to_numeric(values, errors="coerce").dropna()
    if numeric.empty:
        return False, 0.0
    severity = float(numeric.max() - numeric.min())
    return severity > float(threshold), severity


def _group_columns(df: pd.DataFrame) -> list[str]:
    cols = ["canonical_smiles"]
    if "dataset" in df.columns:
        cols.append("dataset")
    if "task" in df.columns:
        cols.append("task")
    return cols


def _assign_audit_group(row: pd.Series) -> str:
    if str(row.get("split", "")).lower() != "test":
        return "not_test"
    if bool(row.get("label_conflict", False)):
        return "label_conflict"
    if bool(row.get("exact_train_test_leak", False)):
        return "exact_leak"
    if bool(row.get("near_train_analogue", False)):
        return "near_train_analogue"
    if bool(row.get("same_scaffold_as_train", False)):
        return "same_scaffold"
    return "audit_clean"


def audit_dataset(df: pd.DataFrame, config: AuditConfig) -> pd.DataFrame:
    """Annotate dataset rows with chemistry and train-test audit flags.

    The input is not modified; a copy is returned with the caller's index and
    these added (or overwritten) columns:

    * ``split`` - lower-cased string form of ``config.split_col``.
    * ``audit_label`` - copy of ``config.label_col``.
    * ``canonical_smiles`` / ``valid_mol`` - result of ``standardize_smiles``
      and whether it is non-missing.
    * ``murcko_scaffold`` - scaffold of each valid molecule, else ``None``.
    * ``duplicate_count_full_dataset`` - occurrences of the canonical SMILES
      in the whole frame (``0`` for invalid molecules).
    * ``label_conflict`` / ``conflict_severity`` - computed per group of
      identical canonical SMILES (plus ``dataset`` / ``task`` when present),
      using the classification or regression rule chosen by ``config.task``.
    * ``exact_train_test_leak``, ``max_train_tanimoto``,
      ``near_train_analogue``, ``same_scaffold_as_train`` - populated for
      test rows only, comparing against rows whose split is ``"train"``.
    * ``audit_group`` - single category per row from ``_assign_audit_group``.

    Raises:
        ValueError: if any of ``config.id_col``, ``config.smiles_col``,
            ``config.label_col`` or ``config.split_col`` is missing from *df*.
    """
    _require_columns(df, [config.id_col, config.smiles_col, config.label_col, config.split_col])

    out = df.copy()
    # All row selection below is label-based (``.loc`` / ``groupby().groups``).
    # Use a unique positional index while annotating so that a caller index
    # with duplicate labels cannot pull unrelated rows into a group or
    # mis-size an assignment; the original index is restored before returning.
    out.index = pd.RangeIndex(len(out))

    out["split"] = out[config.split_col].astype(str).str.lower()
    out["audit_label"] = out[config.label_col]

    # --- per-molecule chemistry -------------------------------------------
    out["canonical_smiles"] = [standardize_smiles(smi) for smi in out[config.smiles_col].tolist()]
    out["valid_mol"] = out["canonical_smiles"].notna()
    # Use pd.notna (not ``is not None``): once stored in a column, a missing
    # value may come back as NaN depending on the column dtype, and the rest
    # of this function already treats missing via notna()/dropna().
    out["murcko_scaffold"] = [
        murcko_scaffold_smiles(smi) if pd.notna(smi) else None
        for smi in out["canonical_smiles"].tolist()
    ]

    canonical_counts = out["canonical_smiles"].value_counts(dropna=True).to_dict()
    out["duplicate_count_full_dataset"] = out["canonical_smiles"].map(canonical_counts).fillna(0).astype(int)

    # --- label conflicts among duplicates of the same molecule -------------
    out["label_conflict"] = False
    out["conflict_severity"] = 0.0
    valid = out[out["canonical_smiles"].notna()]
    if not valid.empty:
        for idx in valid.groupby(_group_columns(valid), dropna=False).groups.values():
            labels = out.loc[idx, config.label_col]
            if config.task == "classification":
                conflict, severity = _classification_conflict(labels)
            else:
                conflict, severity = _regression_conflict(labels, config.regression_conflict_threshold)
            out.loc[idx, "label_conflict"] = bool(conflict)
            out.loc[idx, "conflict_severity"] = float(severity)

    # --- train/test leakage -------------------------------------------------
    train = out[out["split"] == "train"]
    test_mask = out["split"] == "test"
    train_smiles = set(train["canonical_smiles"].dropna().tolist())
    train_scaffolds = {s for s in train["murcko_scaffold"].dropna().tolist() if s}

    out["exact_train_test_leak"] = False
    out.loc[test_mask, "exact_train_test_leak"] = out.loc[test_mask, "canonical_smiles"].isin(train_smiles)

    out["max_train_tanimoto"] = np.nan
    if test_mask.any():
        train_fps = [morgan_fingerprint(smi) for smi in train["canonical_smiles"].dropna().tolist()]
        test_indices = out.index[test_mask].tolist()
        test_fps = [
            morgan_fingerprint(smi) if pd.notna(smi) else None
            for smi in out.loc[test_indices, "canonical_smiles"]
        ]
        out.loc[test_indices, "max_train_tanimoto"] = max_tanimoto_to_train(test_fps, train_fps)

    out["near_train_analogue"] = False
    out.loc[test_mask, "near_train_analogue"] = (
        out.loc[test_mask, "max_train_tanimoto"] >= float(config.primary_near_leak_threshold)
    ).fillna(False)

    out["same_scaffold_as_train"] = False
    out.loc[test_mask, "same_scaffold_as_train"] = out.loc[test_mask, "murcko_scaffold"].isin(train_scaffolds)

    out["audit_group"] = out.apply(_assign_audit_group, axis=1)

    # Hand the caller back their own index.
    out.index = df.index
    return out
def _count_fraction_rows(metric: str, count: int, denominator: int, group: str | None = None) -> dict[str, Any]: fraction = float(count / denominator) if denominator else np.nan return { "metric": metric, "group": group, "label": None, "count": int(count), "fraction": fraction, "value": float(count), }
def summarize_audit(audited_df: pd.DataFrame) -> pd.DataFrame:
    """Return a long-form audit summary table.

    The result has columns ``metric``, ``group``, ``label``, ``count``,
    ``fraction`` and ``value``, with rows for:

    * total rows and rows per split (``train`` / ``valid`` / ``test``),
      as a fraction of all rows;
    * valid molecules, as a fraction of all rows;
    * test-set counts of each audit flag and of ``audit_clean`` rows,
      as a fraction of test rows;
    * when the test set is non-empty: one ``audit_group_count`` row per audit
      group and, if an ``audit_label`` (or fallback ``y``) column exists, one
      ``label_distribution_by_audit_group`` row per (group, label) pair whose
      fraction is relative to that group's size.

    *audited_df* must carry the ``split``, ``valid_mol``, ``label_conflict``,
    ``exact_train_test_leak``, ``near_train_analogue``,
    ``same_scaffold_as_train`` and ``audit_group`` columns read below.
    """
    rows: list[dict[str, Any]] = []

    n_rows = len(audited_df)
    rows.append(_count_fraction_rows("rows_total", n_rows, n_rows))

    split_counts = audited_df["split"].value_counts(dropna=False)
    for split in ("train", "valid", "test"):
        rows.append(_count_fraction_rows(f"rows_{split}", int(split_counts.get(split, 0)), n_rows, split))
    rows.append(_count_fraction_rows("valid_molecules", int(audited_df["valid_mol"].sum()), n_rows))

    # Boolean-mask selection already yields a new frame and nothing below
    # writes to it, so no explicit copy is needed.
    test = audited_df[audited_df["split"] == "test"]
    n_test = len(test)
    rows.extend(
        [
            _count_fraction_rows("test_label_conflicts", int(test["label_conflict"].sum()), n_test),
            _count_fraction_rows("exact_test_leaks", int(test["exact_train_test_leak"].sum()), n_test),
            _count_fraction_rows("near_train_analogues", int(test["near_train_analogue"].sum()), n_test),
            _count_fraction_rows("same_scaffold_test_molecules", int(test["same_scaffold_as_train"].sum()), n_test),
            _count_fraction_rows(
                "audit_clean_test_molecules", int((test["audit_group"] == "audit_clean").sum()), n_test
            ),
        ]
    )

    if not test.empty:
        group_counts = test["audit_group"].value_counts(dropna=False)
        for group, count in group_counts.items():
            rows.append(_count_fraction_rows("audit_group_count", int(count), n_test, str(group)))

        # Prefer the audit label column; fall back to "y" if that is all we have.
        label_col = next((name for name in ("audit_label", "y") if name in test.columns), None)
        if label_col is not None:
            pair_counts = test.groupby(["audit_group", label_col], dropna=False).size()
            for (group, label), count in pair_counts.items():
                # Group sizes were already computed above; reuse them instead
                # of rescanning the whole column for every (group, label) pair.
                denom = int(group_counts.get(group, 0))
                rows.append(
                    {
                        "metric": "label_distribution_by_audit_group",
                        "group": str(group),
                        "label": label,
                        "count": int(count),
                        "fraction": float(count / denom) if denom else np.nan,
                        "value": float(count),
                    }
                )

    return pd.DataFrame(rows, columns=["metric", "group", "label", "count", "fraction", "value"])