from __future__ import annotations
from typing import Dict, Any, List, Optional, Sequence
import importlib
from pathlib import Path
import pandas as pd
import numpy as np
from utils.config_models import normalize_loader_config, normalize_split_column
from utils.splitting import split_indices
try:
import polaris as po
except ImportError: # pragma: no cover - optional dependency
po = None
def _coerce_label_value(value):
"""Convert label entries to Python scalars (int/float/None) when possible."""
if isinstance(value, (list, tuple, np.ndarray)):
# Flatten first level if nested collections sneaked in
return [_coerce_label_value(v) for v in value]
if value is None:
return None
if isinstance(value, (np.generic,)):
value = value.item()
if isinstance(value, str):
s = value.strip()
if s == "" or s.lower() in {"nan", "na", "null"}:
return None
try:
# Prefer ints when they fit exactly
iv = int(s)
fv = float(s)
if abs(fv - iv) < 1e-12:
return iv
return fv
except ValueError:
try:
fv = float(s)
if fv.is_integer():
return int(fv)
return fv
except ValueError:
return s
try:
if pd.isna(value):
return None
except TypeError:
pass
return value
class BaseLoader:
    """Common plumbing shared by the dataset loaders.

    Subclasses implement :meth:`get_splits` and may reuse the helpers for
    dotted-path imports and optional SMILES cleaning.
    """

    def __init__(self, cfg: Dict[str, Any]):
        self.cfg = normalize_loader_config(cfg)
        self.info = self.cfg.get("info", {})

    def get_splits(self) -> Dict[str, pd.DataFrame]:
        """Return a split-name -> DataFrame mapping; subclasses must override."""
        raise NotImplementedError

    def _import_from_str(self, dotted: str):
        """Resolve a dotted path like ``pkg.mod.Attr`` to the attribute object."""
        module_path, _, attr_name = dotted.rpartition(".")
        if not module_path:
            raise ImportError(dotted)
        module = importlib.import_module(module_path)
        return getattr(module, attr_name)

    def _maybe_clean(self, smiles: List[str]) -> pd.DataFrame:
        """Run the configured cleaner over *smiles*, standardizing to 'smiles_clean'."""
        # Default cleaner is SMILESCleaner unless explicitly disabled
        cleaner_path = self.info.get("cleaner", "utils.cleaner.SMILESCleaner")
        keep_invalid = bool(self.info.get("keep_invalid", True))
        disabled = not cleaner_path or str(cleaner_path).lower() in {"none", "null", "false"}
        if disabled:
            df = pd.DataFrame({"smiles": smiles})
        else:
            cleaner_cls = self._import_from_str(cleaner_path)
            cleaner = cleaner_cls(smiles)
            # keep_invalid=True (the default) retains every row so failures
            # can be audited; otherwise only cleaner-validated rows survive.
            frame = cleaner.get_data() if keep_invalid else cleaner.get_valid()
            df = frame.copy()
        if "smiles_clean" not in df.columns and "smiles" in df.columns:
            df = df.rename(columns={"smiles": "smiles_clean"})
        return df
class TDCLoader(BaseLoader):
    """Loader for Therapeutics Data Commons single-prediction datasets."""

    def _init_dataset(self):
        """Probe each TDC task family until one recognizes the dataset name."""
        name = self.cfg["name"]
        path = self.cfg.get("data_path", "data/tdc/")
        from tdc.single_pred import ADME, Tox, HTS, QM
        for task_cls in (ADME, Tox, HTS, QM):
            try:
                dataset = task_cls(name=name, path=path)
                _ = dataset.y  # probe that labels are actually accessible
                return dataset
            except Exception:
                # Wrong task family for this name; keep probing.
                continue
        raise RuntimeError(f"TDC dataset not found: {name}")

    def _pick(self, cols: List[str], frame: pd.DataFrame) -> str:
        """Return the first candidate column present in *frame*."""
        for candidate in cols:
            if candidate in frame.columns:
                return candidate
        raise KeyError(f"none of {cols} in {list(frame.columns)}")

    def get_splits(self) -> Dict[str, pd.DataFrame]:
        """Fetch TDC splits and standardize to smiles_clean/label_raw[/id]."""
        dataset = self._init_dataset()
        method = self.info.get("split")
        raw = dataset.get_split(method=method) if method else dataset.get_split()
        result: Dict[str, pd.DataFrame] = {}
        for split_name in ("train", "valid", "test"):
            part = raw[split_name]
            smiles_col = self.info.get("smiles_col") or self._pick(["Drug", "SMILES", "smiles"], part)
            label_col = self.info.get("label_col") or self._pick(["Y", "y", "label"], part)
            configured_id = self.info.get("id_col")
            id_col = configured_id if configured_id in part.columns else None
            cleaned = self._maybe_clean(part[smiles_col].tolist())
            cleaned["label_raw"] = part[label_col].tolist()
            if id_col:
                cleaned["id"] = part[id_col].tolist()
            result[split_name] = cleaned
        return result
class TabularLoader(BaseLoader):
    """Loader for tabular files (csv/tsv/parquet) containing SMILES + labels.

    Two configurations are supported:
      * cfg['paths'] = {'train': ..., 'valid': ..., 'test': ...} — one file
        per split;
      * cfg['path'] = single file, either carrying a split column
        (info.split_col, default 'split') or split on the fly via
        info.split_method / split_fracs / split_seed.
    """

    DEFAULT_SMILES_COLS = ["smiles", "SMILES", "drug", "Drug"]
    DEFAULT_LABEL_COLS = ["label_raw", "label", "Label", "y", "Y"]
    DEFAULT_ID_COLS = ["id", "ID", "compound_id", "compoundID"]
    DEFAULT_SEQUENCE_COLS = [
        "sequence_aa",
        "sequence",
        "Sequence",
        "protein_sequence",
        "ProteinSequence",
        "target_sequence",
        "TargetSequence",
        "AASequence",
    ]
    DEFAULT_TARGET_ID_COLS = [
        "target_id",
        "target",
        "TargetID",
        "protein_id",
        "ProteinID",
    ]

    def _read_like(self, path: Path) -> pd.DataFrame:
        """Read a .csv/.tsv/.parquet file into a DataFrame by suffix."""
        suffix = path.suffix.lower()
        if suffix in {".csv", ".tsv"}:
            return pd.read_csv(path, sep="," if suffix == ".csv" else "\t")
        if suffix == ".parquet":
            return pd.read_parquet(path)
        raise ValueError(f"unsupported file: {path}")

    def _resolve_column(self, df: pd.DataFrame, key: str, candidates: List[str]) -> Optional[str]:
        """Pick a column: explicit info[key] first, then candidates (case-insensitive)."""
        info_val = self.info.get(key)
        if info_val and info_val in df.columns:
            return info_val
        lower_map = {col.lower(): col for col in df.columns}
        for cand in candidates:
            if cand in df.columns:
                return cand
            cli = cand.lower()
            if cli in lower_map:
                return lower_map[cli]
        return None

    def _standardize_cols(self, df: pd.DataFrame) -> pd.DataFrame:
        """Rename/derive canonical columns: smiles, label_raw, id, sequence_aa, target_id.

        Raises:
            KeyError: if the SMILES or label column cannot be determined, or
                configured info.label_cols are missing from the frame.
            TypeError: if info.label_cols is neither a string nor a sequence.
        """
        smiles_col = self._resolve_column(df, "smiles_col", self.DEFAULT_SMILES_COLS)
        if smiles_col and smiles_col != "smiles":
            df = df.rename(columns={smiles_col: "smiles"})
        if "smiles" not in df.columns:
            raise KeyError("Could not determine SMILES column. Set info.smiles_col in the config.")
        label_cols_cfg = self.info.get("label_cols")
        normalized_label_cols: Optional[List[str]] = None
        if label_cols_cfg:
            # Multitask mode: label_raw becomes a list of coerced values per row.
            if isinstance(label_cols_cfg, str):
                normalized_label_cols = [label_cols_cfg]
            elif isinstance(label_cols_cfg, Sequence):
                normalized_label_cols = [str(c) for c in label_cols_cfg]
            else:
                raise TypeError("info.label_cols must be a string or a list of strings")
            missing = [c for c in normalized_label_cols if c not in df.columns]
            if missing:
                raise KeyError(f"Missing label columns {missing} in dataframe.")
            # Persist the normalized form so downstream consumers of info see lists.
            self.info["label_cols"] = normalized_label_cols
            df["label_raw"] = df[normalized_label_cols].apply(
                lambda row: [_coerce_label_value(row[col]) for col in normalized_label_cols],
                axis=1,
            )
        else:
            label_col = self._resolve_column(df, "label_col", self.DEFAULT_LABEL_COLS)
            if label_col and label_col != "label_raw":
                df = df.rename(columns={label_col: "label_raw"})
            if "label_raw" not in df.columns:
                raise KeyError("Could not determine label column. Set info.label_col in the config.")
            df["label_raw"] = df["label_raw"].apply(_coerce_label_value)
        id_col = self._resolve_column(df, "id_col", self.DEFAULT_ID_COLS)
        if id_col and id_col != "id":
            df = df.rename(columns={id_col: "id"})
        seq_col = self._resolve_column(df, "sequence_col", self.DEFAULT_SEQUENCE_COLS)
        if seq_col and seq_col != "sequence_aa":
            df = df.rename(columns={seq_col: "sequence_aa"})
        target_col = self._resolve_column(df, "target_id_col", self.DEFAULT_TARGET_ID_COLS)
        if target_col and target_col != "target_id":
            df = df.rename(columns={target_col: "target_id"})
        return df

    def _project_part(self, frame: pd.DataFrame) -> pd.DataFrame:
        """Clean SMILES from a standardized *frame* and carry over the other columns.

        Sequence/target columns require the cleaned frame to keep one row per
        input row (keep_invalid=True), otherwise alignment would silently break.
        """
        df_clean = self._maybe_clean(frame["smiles"].tolist())
        df_clean["label_raw"] = frame["label_raw"].tolist()
        if "id" in frame.columns:
            df_clean["id"] = frame["id"].tolist()
        for col in ("sequence_aa", "target_id"):
            if col in frame.columns:
                if len(df_clean) != len(frame):
                    raise ValueError("Sequence-aware tabular loader expects keep_invalid=True to retain row alignment.")
                df_clean[col] = frame[col].tolist()
        return df_clean

    def get_splits(self) -> Dict[str, pd.DataFrame]:
        """Produce standardized train/valid/test DataFrames per the config.

        Raises:
            ValueError: unsupported file type, empty split, missing config,
                or row misalignment with sequence/target columns.
            KeyError: unresolved smiles/label/split columns.
        """
        # Mode 1: three explicit files, one per split.
        paths_cfg = self.cfg.get("paths")
        if paths_cfg is not None:
            out: Dict[str, pd.DataFrame] = {}
            for split in ("train", "valid", "test"):
                df = self._standardize_cols(self._read_like(Path(paths_cfg[split])))
                out[split] = self._project_part(df)
            return out
        # Mode 2: single file, with a split column present or generated here.
        path_cfg = self.cfg.get("path")
        if path_cfg is not None:
            df = self._standardize_cols(self._read_like(Path(path_cfg)))
            split_col = self.info.get("split_col", "split")
            if split_col not in df.columns:
                split_method = self.info.get("split_method")
                if not split_method:
                    raise KeyError(f"missing split_col '{split_col}' in {path_cfg}")
                split_fracs = self.info.get("split_fracs", [0.8, 0.1, 0.1])
                split_seed = self.info.get("split_seed", 123)
                # NOTE(review): split_indices presumably returns positional
                # indices — reset the index so the .loc assignments below hit
                # positions even if the source file carried a non-default
                # index (e.g. parquet round-trip). TODO confirm against
                # utils.splitting.split_indices.
                df = df.reset_index(drop=True)
                train_idx, valid_idx, test_idx = split_indices(
                    df["smiles"].tolist(),
                    method=split_method,
                    fracs=split_fracs,
                    seed=split_seed,
                )
                df[split_col] = "train"
                df.loc[valid_idx, split_col] = "valid"
                df.loc[test_idx, split_col] = "test"
            df[split_col] = normalize_split_column(df[split_col])
            out = {}
            for split in ("train", "valid", "test"):
                part = df[df[split_col] == split]
                if part.empty:
                    raise ValueError(f"no rows for split '{split}' in {path_cfg}")
                out[split] = self._project_part(part)
            return out
        raise ValueError("tabular loader needs 'paths' or 'path'")
class PolarisLoader(BaseLoader):
    """Minimal Polaris loader.

    Expects cfg = {"type": "polaris", "name": "<vendor/benchmark-id>"}.
    Returns only {'train', 'test'} with columns: smiles_clean, label_raw, id.
    """

    def get_splits(self) -> Dict[str, pd.DataFrame]:
        """Load the configured Polaris benchmark and standardize both subsets.

        Raises:
            ImportError: if polaris-lib is not installed.
            ValueError: if a subset exposes no SMILES or no labels.
        """
        if po is None:
            raise ImportError(
                "polaris-lib is required for Polaris datasets. Install with 'pip install polaris-lib'."
            )
        bench = po.load_benchmark(self.cfg["name"])
        train, test = bench.get_train_test_split()

        def _to_df(subset) -> pd.DataFrame:
            smiles = subset.inputs
            # Check SMILES before touching them: the fallback below calls
            # len(smiles), which would raise TypeError on None otherwise.
            if smiles is None:
                raise ValueError("Missing SMILES or labels")
            try:
                y = subset.targets  # TODO: Handle multitask labels...
            except Exception:
                # Polaris may refuse to expose targets (e.g. blinded test
                # sets); fall back to null labels rather than crashing.
                # Was a bare `except:` — narrowed so Ctrl-C still propagates.
                y = [None] * len(smiles)
            if y is None:
                raise ValueError("Missing SMILES or labels")
            df = self._maybe_clean(smiles)
            df["label_raw"] = y
            if "id" not in df.columns:
                # Synthesize stable integer ids when the cleaner provides none.
                df["id"] = np.arange(len(df), dtype=np.int64)
            return df

        return {"train": _to_df(train), "test": _to_df(test)}
class DTILoader(TabularLoader):
    """DTI loader built on TabularLoader with sensible defaults."""

    def __init__(self, cfg: Dict[str, Any]):
        super().__init__(cfg)
        # Keep invalid molecules by default so that auxiliary columns
        # (sequence/target) remain aligned with the cleaned frame.
        self.info.setdefault("keep_invalid", True)

    def _standardize_cols(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize columns and require an amino-acid sequence column."""
        standardized = super()._standardize_cols(df)
        if "sequence_aa" not in standardized.columns:
            raise KeyError(
                "DTI loader requires an amino-acid sequence column. "
                "Set info.sequence_col or name the column one of "
                f"{self.DEFAULT_SEQUENCE_COLS}."
            )
        return standardized