Source code for utils.rank_fragility.predictions
"""Prediction loading and audit-merge helpers."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
# Module logger; warnings about dropped test molecules are reported through it.
LOG: logging.Logger = logging.getLogger(__name__)

# Columns that every per-model prediction CSV is required to provide.
PREDICTION_COLUMNS: set[str] = {"molecule_id", "y_true", "y_pred"}
[docs]
def load_prediction_directory(pred_dir: str) -> pd.DataFrame:
    """Load one prediction CSV per model into long format.

    Every ``*.csv`` file directly inside ``pred_dir`` is treated as the
    predictions of one model, named after the file stem.  Files are read in
    sorted path order, so the output row order is deterministic.

    Args:
        pred_dir: Directory containing the per-model prediction CSVs.

    Returns:
        A dataframe with columns ``model``, ``molecule_id`` (cast to ``str``),
        ``y_true`` and ``y_pred``; any other CSV columns are discarded.

    Raises:
        FileNotFoundError: If ``pred_dir`` does not exist or is not a directory.
        ValueError: If no CSV files are found, or a file lacks a required
            column, has a blank ``molecule_id``, or repeats a ``molecule_id``.
    """
    root = Path(pred_dir)
    # is_dir() is False both for non-existent paths and for regular files.
    if not root.is_dir():
        raise FileNotFoundError(f"prediction directory not found: {pred_dir}")

    frames: list[pd.DataFrame] = []
    for path in sorted(p for p in root.glob("*.csv") if p.is_file()):
        frame = pd.read_csv(path)
        missing = sorted(PREDICTION_COLUMNS - set(frame.columns))
        if missing:
            raise ValueError(f"{path} is missing required prediction column(s): {', '.join(missing)}")

        ids = frame["molecule_id"]
        # astype(str) would silently turn a blank ID into the literal string
        # "nan", letting it masquerade as a real molecule identifier.
        blank = ids.isna()
        if blank.any():
            raise ValueError(f"{path} contains {int(blank.sum())} row(s) with missing molecule_id values")

        out = frame[["molecule_id", "y_true", "y_pred"]].copy()
        out["model"] = path.stem
        out["molecule_id"] = ids.astype(str)
        duplicated = out["molecule_id"].duplicated()
        if duplicated.any():
            dupes = out.loc[duplicated, "molecule_id"].head(5).tolist()
            raise ValueError(f"{path} contains duplicate molecule_id values, e.g. {dupes}")
        frames.append(out[["model", "molecule_id", "y_true", "y_pred"]])

    if not frames:
        raise ValueError(f"no prediction CSV files found in {pred_dir}")
    return pd.concat(frames, ignore_index=True)
def _labels_agree(
    pred: pd.Series,
    dataset: pd.Series,
    task: str,
    *,
    rtol: float = 1e-6,
    atol: float = 1e-8,
) -> np.ndarray:
    """Return an element-wise boolean mask of where two label series agree.

    The series are compared positionally (their indexes are not aligned).

    For ``task == "regression"`` both sides are converted to numbers and
    compared with ``numpy.isclose`` using ``rtol``/``atol``.  Two missing
    labels count as agreeing, but a label that is present yet not numeric
    never agrees with anything.  For every other task the normalised keys
    produced by ``_label_key`` are compared with ``==``.
    """
    if task == "regression":
        pred_num = pd.to_numeric(pred, errors="coerce")
        data_num = pd.to_numeric(dataset, errors="coerce")
        close = np.isclose(pred_num, data_num, rtol=rtol, atol=atol, equal_nan=True)
        # errors="coerce" maps unparsable text to NaN, and equal_nan=True would
        # then declare any two unparsable labels equal ("abc" vs "xyz").
        # A present-but-non-numeric label is treated as a disagreement instead.
        unparsable = (
            np.asarray(pd.isna(pred_num)) & np.asarray(pd.notna(pred))
        ) | (
            np.asarray(pd.isna(data_num)) & np.asarray(pd.notna(dataset))
        )
        return np.asarray(close, dtype=bool) & ~unparsable
    return pred.map(_label_key).to_numpy() == dataset.map(_label_key).to_numpy()
def _label_key(value: Any) -> Any:
try:
if pd.isna(value):
return None
except TypeError:
pass
if isinstance(value, str):
text = value.strip()
try:
f = float(text)
return int(f) if f.is_integer() else f
except ValueError:
return text
if isinstance(value, np.generic):
return value.item()
return value
[docs]
def merge_predictions_with_audit(pred_df: pd.DataFrame, audited_df: pd.DataFrame, config) -> pd.DataFrame:
    """Validate prediction labels and attach audit annotations for test rows.

    Steps:
      1. Select rows of ``audited_df`` whose ``split`` is ``"test"``
         (case-insensitive) and key them by ``config.id_col`` cast to ``str``;
         a repeated test ID keeps its first row (a warning is recorded).
      2. Check that every prediction's ``y_true`` agrees with the dataset
         label ``config.label_col`` (comparison mode chosen by ``config.task``).
      3. Keep only test molecules predicted by every model.
      4. Left-join the audit columns onto the surviving prediction rows.

    Args:
        pred_df: Long-format predictions with ``model``, ``molecule_id``,
            ``y_true`` and ``y_pred`` columns.  It is not modified.
        audited_df: Audited dataset with a ``split`` column plus
            ``config.id_col`` and ``config.label_col``.
        config: Object exposing ``id_col``, ``label_col`` and ``task``.

    Returns:
        Prediction rows joined with their audit annotations, sorted by
        ``model`` then ``molecule_id``.  Audit columns whose names clash with
        prediction columns (e.g. a label column called ``y_true``) are kept
        with an ``_audit`` suffix.  Warning messages are stored in
        ``out.attrs["warnings"]``.

    Raises:
        ValueError: On missing required columns, no test rows, no overlap
            between predictions and test rows, label disagreement, or no test
            molecule shared by all models.
    """
    for column in PREDICTION_COLUMNS | {"model"}:
        if column not in pred_df.columns:
            raise ValueError(f"prediction dataframe is missing required column: {column}")

    warnings_list: list[str] = []

    def _warn(message: str) -> None:
        # Log the message and also hand it back to the caller via attrs.
        LOG.warning(message)
        warnings_list.append(message)

    # Compare IDs as strings on both sides (mirroring the audit-side cast
    # below) so non-string IDs in pred_df still line up with audit IDs.
    pred_df = pred_df.assign(molecule_id=pred_df["molecule_id"].astype(str))

    # --- 1. test rows of the audited dataset ---------------------------------
    test = audited_df[audited_df["split"].astype(str).str.lower() == "test"].copy()
    if test.empty:
        raise ValueError("audited dataset contains no test rows")
    if config.id_col not in test.columns or config.label_col not in test.columns:
        raise ValueError(f"audited dataframe must contain {config.id_col!r} and {config.label_col!r}")
    test["molecule_id"] = test[config.id_col].astype(str)
    n_repeated = int(test["molecule_id"].duplicated().sum())
    if n_repeated:
        _warn(f"audited test split repeats {n_repeated} molecule_id value(s); keeping the first occurrence")
    test = test.drop_duplicates("molecule_id", keep="first")

    # --- 2. label agreement ----------------------------------------------------
    test_labels = test[["molecule_id", config.label_col]].rename(columns={config.label_col: "dataset_y_true"})
    merged_for_validation = pred_df.merge(test_labels, on="molecule_id", how="inner")
    if merged_for_validation.empty:
        raise ValueError("no prediction molecule_id values overlap test rows")
    agree = _labels_agree(merged_for_validation["y_true"], merged_for_validation["dataset_y_true"], config.task)
    if not bool(np.all(agree)):
        bad = merged_for_validation.loc[~agree, ["model", "molecule_id", "y_true", "dataset_y_true"]].head(10)
        raise ValueError(f"prediction y_true does not agree with dataset labels:\n{bad.to_string(index=False)}")

    # --- 3. test molecules shared by every model -------------------------------
    all_test_ids = set(test["molecule_id"])
    common_ids = set(all_test_ids)
    for model, group in pred_df.groupby("model"):
        ids = set(group["molecule_id"])
        n_missing = len(all_test_ids - ids)
        if n_missing:
            _warn(f"model {model} is missing {n_missing} test prediction(s); dropping them for all models")
        common_ids &= ids
    if not common_ids:
        raise ValueError("no common predicted test molecules across all models")
    dropped = len(all_test_ids - common_ids)
    if dropped:
        _warn(f"dropping {dropped} test molecule(s) absent from at least one model")

    # --- 4. attach annotations ------------------------------------------------------
    # suffixes=("", "_audit") keeps prediction columns under their own names
    # when an audit column shares one (e.g. label_col == "y_true"); the default
    # ("_x", "_y") would leave no column called "y_true" in the output.
    out = pred_df[pred_df["molecule_id"].isin(common_ids)].merge(
        test, on="molecule_id", how="left", suffixes=("", "_audit")
    )
    out = out.sort_values(["model", "molecule_id"], kind="mergesort").reset_index(drop=True)
    out.attrs["warnings"] = warnings_list
    return out