from __future__ import annotations
from typing import Dict, Iterable, List, Optional, Tuple
import numpy as np
def _normalize_fracs(fracs) -> Tuple[float, float, float]:
    """Return validated ``(train, valid, test)`` fractions.

    Args:
        fracs: ``None`` (use the default 0.8 / 0.1 / 0.1 split), a list or
            tuple of three numbers ordered train, valid, test, or a dict
            with the keys ``"train"``, ``"valid"`` and ``"test"``.

    Returns:
        The three fractions converted to ``float``.

    Raises:
        TypeError: If ``fracs`` is not ``None``, a list/tuple or a dict, or
            if a dict is missing (or maps to ``None``) one of the keys.
        ValueError: If a list/tuple does not have exactly three entries, if
            any fraction is NaN or lies outside ``[0, 1]``, or if the
            fractions do not sum to 1.0 (all within a 1e-6 tolerance).
    """
    if fracs is None:
        return 0.8, 0.1, 0.1

    names = ("train", "valid", "test")
    tol = 1e-6

    if isinstance(fracs, dict):
        missing = [name for name in names if fracs.get(name) is None]
        if missing:
            # Previously this surfaced as an opaque ``float(None)`` TypeError.
            raise TypeError(
                "split_fracs dict must provide train/valid/test "
                f"(missing: {', '.join(missing)})"
            )
        train, valid, test = (float(fracs[name]) for name in names)
    elif isinstance(fracs, (list, tuple)):
        if len(fracs) != 3:
            raise ValueError("split_fracs must have three entries: [train, valid, test]")
        train, valid, test = (float(fracs[0]), float(fracs[1]), float(fracs[2]))
    else:
        raise TypeError("split_fracs must be a list/tuple or dict with train/valid/test")

    # Reject NaN and out-of-range values explicitly: NaN makes every
    # comparison False, so it would otherwise slip through the sum check,
    # and values such as (1.5, -0.25, -0.25) sum to 1.0 but are meaningless.
    # The tolerance keeps tiny rounding residue (e.g. 1 - 0.9 - 0.1) valid.
    for name, value in zip(names, (train, valid, test)):
        if not (-tol <= value <= 1.0 + tol):
            raise ValueError(f"split_fracs {name} must be within [0, 1] (got {value})")

    total = train + valid + test
    if abs(total - 1.0) > tol:
        raise ValueError(f"split_fracs must sum to 1.0 (got {total:.6f})")
    return train, valid, test
def _rng(seed: Optional[int]):
    """Return the random source to draw from for the given seed.

    With an integer seed a fresh ``np.random.RandomState`` is created, so
    repeated calls with the same seed yield the same sequence. With
    ``seed=None`` the ``np.random`` module itself is returned, i.e. callers
    use NumPy's module-level random functions.
    """
    if seed is None:
        return np.random
    return np.random.RandomState(seed)
[docs]
def random_split_indices(
    n_items: int,
    frac_train: float,
    frac_valid: float,
    frac_test: float,
    seed: Optional[int] = 123,
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle ``range(n_items)`` and cut it into train/valid/test index lists.

    Args:
        n_items: Number of items to split; indices run from 0 to n_items - 1.
        frac_train: Fraction of items assigned to the training split.
        frac_valid: Fraction of items assigned to the validation split.
        frac_test: Not read by this function; the test split receives
            whatever remains after the train and valid cuts.
        seed: Seed passed to ``_rng``; ``None`` draws from ``np.random``.

    Returns:
        ``(train_idx, valid_idx, test_idx)`` as lists of Python ints.
    """
    order = np.arange(n_items)
    _rng(seed).shuffle(order)

    # Cut points are truncated towards zero, so rounding remainders end up
    # in the later splits.
    first_cut = int(frac_train * n_items)
    second_cut = int((frac_train + frac_valid) * n_items)

    bounds = ((None, first_cut), (first_cut, second_cut), (second_cut, None))
    train_idx, valid_idx, test_idx = (order[lo:hi].tolist() for lo, hi in bounds)
    return train_idx, valid_idx, test_idx
def _scaffold_smiles(smiles: str, include_chirality: bool = False) -> Optional[str]:
    """Return the Murcko scaffold SMILES of ``smiles``, or ``None`` if unparsable.

    Args:
        smiles: SMILES string of the molecule.
        include_chirality: Forwarded to RDKit as ``includeChirality``.

    Returns:
        The scaffold SMILES produced by RDKit's ``MurckoScaffoldSmiles``, or
        ``None`` when ``Chem.MolFromSmiles`` returns ``None`` for the input.
    """
    # RDKit is imported inside the function, so it is only required once a
    # scaffold is actually computed.
    from rdkit import Chem
    from rdkit.Chem.Scaffolds import MurckoScaffold

    parsed = Chem.MolFromSmiles(smiles)
    return (
        None
        if parsed is None
        else MurckoScaffold.MurckoScaffoldSmiles(
            mol=parsed, includeChirality=include_chirality
        )
    )
[docs]
def scaffold_split_indices(
    smiles_list: Iterable[str],
    frac_train: float,
    frac_valid: float,
    frac_test: float,
    seed: Optional[int] = 123,
) -> Tuple[List[int], List[int], List[int]]:
    """DeepChem-style scaffold split: group by Bemis-Murcko scaffold, then size-sort.

    Molecules sharing a scaffold always land in the same split. Inputs for
    which ``_scaffold_smiles`` returns ``None`` are grouped together under
    the ``None`` key.

    Args:
        smiles_list: SMILES strings; positions in this iterable are the
            indices that get returned.
        frac_train: Target fraction of items for the training split.
        frac_valid: Target fraction of items for the validation split.
        frac_test: Not read by this function; the test split receives every
            group that fits in neither train nor valid.
        seed: When not ``None``, groups are shuffled with ``_rng(seed)``
            before the size sort, which decides the order of equal-sized
            groups. ``None`` keeps first-seen order for ties.

    Returns:
        ``(train_idx, valid_idx, test_idx)`` lists of integer positions.
    """
    by_scaffold: Dict[Optional[str], List[int]] = {}
    for position, smi in enumerate(smiles_list):
        key = _scaffold_smiles(smi)
        if key not in by_scaffold:
            by_scaffold[key] = []
        by_scaffold[key].append(position)

    # Only the index groups matter from here on. Shuffle first, then rely on
    # the stable sort so the shuffle only affects ties between equal sizes.
    groups = list(by_scaffold.values())
    if seed is not None:
        _rng(seed).shuffle(groups)
    groups = sorted(groups, key=len, reverse=True)

    total = sum(map(len, groups))
    train_limit = frac_train * total
    valid_limit = (frac_train + frac_valid) * total

    train_idx: List[int] = []
    valid_idx: List[int] = []
    test_idx: List[int] = []
    for members in groups:
        size = len(members)
        # Greedy fill: train while it fits, then valid, everything else test.
        if len(train_idx) + size <= train_limit:
            bucket = train_idx
        elif len(train_idx) + len(valid_idx) + size <= valid_limit:
            bucket = valid_idx
        else:
            bucket = test_idx
        bucket.extend(members)
    return train_idx, valid_idx, test_idx
[docs]
def split_indices(
    smiles_list: Iterable[str],
    method: str,
    fracs,
    seed: Optional[int] = 123,
) -> Tuple[List[int], List[int], List[int]]:
    """Split item positions into train/valid/test using the named method.

    Args:
        smiles_list: SMILES strings to split; consumed once into a list.
        method: Splitter name, matched case-insensitively after stripping
            whitespace. ``"random"``/``"rand"`` select the random split,
            ``"scaffold"``/``"scaffold_split"`` the scaffold split.
        fracs: Split fractions in any form accepted by ``_normalize_fracs``.
        seed: Seed forwarded to the selected splitter.

    Returns:
        ``(train_idx, valid_idx, test_idx)`` lists of integer positions.

    Raises:
        ValueError: If ``method`` is not one of the recognised names.
    """
    fractions = _normalize_fracs(fracs)
    choice = method.strip().lower()
    molecules = list(smiles_list)

    if choice in ("random", "rand"):
        return random_split_indices(len(molecules), *fractions, seed=seed)
    if choice in ("scaffold", "scaffold_split"):
        return scaffold_split_indices(molecules, *fractions, seed=seed)
    raise ValueError(f"Unknown splitter method: {method}")