from __future__ import annotations
import copy
import json
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Any, List, Optional, Set, Tuple, Literal
import logging
import numpy as np
import pandas as pd
from rdkit import DataStructs
from rdkit.Chem import AllChem, rdMolDescriptors
from rdkit import Chem
try:
from Levenshtein import ratio as _levenshtein # type: ignore
except ImportError: # pragma: no cover - fallback when python-Levenshtein is absent
from difflib import SequenceMatcher
def _levenshtein(a: str, b: str) -> float:
return SequenceMatcher(None, a, b).ratio()
from rdkit.Chem.Scaffolds.MurckoScaffold import (
MakeScaffoldGeneric as _GraphFramework,
GetScaffoldForMol as _GetScaffoldForMol,
)
try:
import psa
except ImportError: # pragma: no cover - optional dependency
psa = None
from utils.pydantic_compat import ConfigDict, Field, HAVE_PYDANTIC, PydanticBaseModel
def _to_python_scalar(val: Any) -> Any:
if isinstance(val, (list, tuple, np.ndarray, pd.Series)):
return [_to_python_scalar(v) for v in list(val)]
if isinstance(val, (np.generic,)):
return val.item()
if isinstance(val, (pd.Timestamp,)):
return val.to_pydatetime()
try:
if pd.isna(val):
return None
except TypeError:
pass
return val
def _is_sequence_label(value: Any) -> bool:
return isinstance(value, (list, tuple, np.ndarray, pd.Series))
def _normalize_label_list(value: Any) -> List[Any]:
if isinstance(value, pd.Series):
value = value.tolist()
if isinstance(value, np.ndarray):
value = value.tolist()
if isinstance(value, (list, tuple)):
seq = list(value)
else:
seq = [value]
out: List[Any] = []
for item in seq:
if isinstance(item, (list, tuple, np.ndarray, pd.Series)):
out.extend(_normalize_label_list(item))
continue
if item is None:
out.append(None)
continue
if isinstance(item, (np.generic,)):
item = item.item()
if isinstance(item, str):
s = item.strip()
if s == "" or s.lower() in {"nan", "na", "null"}:
out.append(None)
continue
try:
fv = float(s)
if fv.is_integer():
out.append(int(fv))
else:
out.append(fv)
continue
except ValueError:
out.append(s)
continue
try:
if pd.isna(item):
out.append(None)
continue
except TypeError:
pass
out.append(item)
return out
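# Illustrative behaviour of the label normalizer above (comment-only sketch; values
# are made-up placeholders): nested containers are flattened and numeric strings are
# coerced, e.g.
#   _normalize_label_list(["1", "2.5", None, ["nan"]])  ->  [1, 2.5, None, None]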
def _label_to_tuple(value: Any) -> Tuple[Any, ...]:
return tuple(_normalize_label_list(value))
def _series_has_sequence_labels(series: pd.Series) -> bool:
if series.empty:
return False
return series.apply(_is_sequence_label).any()
def _label_series_to_matrix(series: pd.Series) -> np.ndarray:
if series.empty:
return np.zeros((0, 0), dtype=float)
rows = [_normalize_label_list(v) for v in series.tolist()]
max_len = max((len(row) for row in rows), default=0)
if max_len == 0:
return np.zeros((len(rows), 0), dtype=float)
arr = np.full((len(rows), max_len), np.nan, dtype=float)
for i, row in enumerate(rows):
for j, val in enumerate(row):
if val is None:
arr[i, j] = np.nan
else:
try:
arr[i, j] = float(val)
except (TypeError, ValueError):
arr[i, j] = np.nan
return arr
def _classification_delta(a: Tuple[Any, ...], b: Tuple[Any, ...]) -> List[Optional[float]]:
length = max(len(a), len(b))
out: List[Optional[float]] = []
for idx in range(length):
va = a[idx] if idx < len(a) else None
vb = b[idx] if idx < len(b) else None
if va is None or vb is None:
out.append(None)
continue
try:
out.append(int(va) - int(vb))
except (TypeError, ValueError):
try:
out.append(float(va) - float(vb))
except (TypeError, ValueError):
out.append(None)
return out
def _regression_delta(a: Any, b: Any) -> List[Optional[float]]:
seq_a = _normalize_label_list(a)
seq_b = _normalize_label_list(b)
length = max(len(seq_a), len(seq_b))
out: List[Optional[float]] = []
for idx in range(length):
va = seq_a[idx] if idx < len(seq_a) else None
vb = seq_b[idx] if idx < len(seq_b) else None
if va is None or vb is None:
out.append(None)
continue
try:
out.append(float(va) - float(vb))
except (TypeError, ValueError):
out.append(None)
return out
def _delta_exceeds_threshold(delta: List[Optional[float]], sigma: Any) -> bool:
if not delta:
return False
if isinstance(sigma, (list, tuple, np.ndarray)):
sigmas = list(sigma)
else:
sigmas = [sigma] * len(delta)
if not sigmas:
return False
for idx, d in enumerate(delta):
if d is None:
continue
thr = sigmas[idx] if idx < len(sigmas) else sigmas[-1]
if thr is None or thr == 0:
continue
try:
if abs(float(d)) >= float(thr):
return True
except (TypeError, ValueError):
continue
return False
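# Illustrative behaviour of the threshold check above (comment-only sketch; numbers
# are placeholders). A None delta entry or a missing/zero threshold is skipped:
#   _delta_exceeds_threshold([0.5, None, 2.0], 1.5)  ->  True   (|2.0| >= 1.5)
#   _delta_exceeds_threshold([0.5, None, 1.0], 1.5)  ->  False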
def _validate_analyzer_config_values(cfg: "AnalyzerConfig") -> None:
if float(cfg.sim_threshold) < 0 or float(cfg.sim_threshold) > 1:
raise ValueError("AnalyzerConfig.sim_threshold must be between 0 and 1")
if int(cfg.fp_radius) <= 0:
raise ValueError("AnalyzerConfig.fp_radius must be > 0")
if int(cfg.fp_nbits) <= 0:
raise ValueError("AnalyzerConfig.fp_nbits must be > 0")
if cfg.label_cols is not None:
if isinstance(cfg.label_cols, tuple):
cfg.label_cols = list(cfg.label_cols)
if not isinstance(cfg.label_cols, list):
raise TypeError("AnalyzerConfig.label_cols must be a list of strings")
cfg.label_cols = [str(col) for col in cfg.label_cols]
if HAVE_PYDANTIC: # pragma: no cover - exercised when pydantic is installed
_ANALYZER_CONFIG_FIELD_ORDER = (
"task_type",
"typ",
"sim_threshold",
"fp_radius",
"fp_nbits",
"smiles_col",
"label_col",
"id_col",
"label_cols",
"sequence_col",
"target_id_col",
"name",
"unique_sequences_jsonl",
"foldseek_m8_path",
)
_ANALYSIS_RESULT_FIELD_ORDER = (
"summary",
"per_record_df",
"conflicts_rows",
"cliffs_rows",
"sequence_alignment_rows",
"structure_alignment_rows",
)
def _merge_positional_fields(field_names: tuple[str, ...], args: tuple[Any, ...], data: dict[str, Any]) -> dict[str, Any]:
if len(args) > len(field_names):
raise TypeError(f"Too many positional arguments (expected at most {len(field_names)})")
merged = dict(data)
for name, value in zip(field_names, args):
if name in merged:
raise TypeError(f"Multiple values for argument '{name}'")
merged[name] = value
return merged
class AnalyzerConfig(PydanticBaseModel):
"""Minimal, YAML-friendly config for SMILES analysis."""
task_type: Literal["classification", "regression"]
typ: Literal["tdc", "tabular", "polaris"]
sim_threshold: float = Field(default=0.9, ge=0.0, le=1.0) if Field is not None else 0.9
fp_radius: int = Field(default=2, gt=0) if Field is not None else 2
fp_nbits: int = Field(default=2048, gt=0) if Field is not None else 2048
smiles_col: Optional[str] = None
label_col: Optional[str] = None
id_col: Optional[str] = None
label_cols: Optional[List[str]] = None
sequence_col: Optional[str] = None
target_id_col: Optional[str] = None
name: Optional[str] = None
unique_sequences_jsonl: Optional[str] = None
foldseek_m8_path: Optional[str] = None
if ConfigDict is not None: # pydantic v2
model_config = ConfigDict(validate_assignment=True)
else: # pydantic v1
class Config:
validate_assignment = True
def __init__(self, *args, **data):
data = _merge_positional_fields(_ANALYZER_CONFIG_FIELD_ORDER, args, data)
super().__init__(**data)
_validate_analyzer_config_values(self)
class AnalysisResult(PydanticBaseModel):
summary: Dict[str, Any]
per_record_df: pd.DataFrame
conflicts_rows: List[Dict[str, Any]]
cliffs_rows: List[Dict[str, Any]]
sequence_alignment_rows: Optional[List[Dict[str, Any]]] = None
structure_alignment_rows: Optional[List[Dict[str, Any]]] = None
if ConfigDict is not None: # pydantic v2
model_config = ConfigDict(arbitrary_types_allowed=True)
else: # pydantic v1
class Config:
arbitrary_types_allowed = True
def __init__(self, *args, **data):
data = _merge_positional_fields(_ANALYSIS_RESULT_FIELD_ORDER, args, data)
super().__init__(**data)
else:
@dataclass
class AnalyzerConfig:
"""Minimal, YAML-friendly config for SMILES analysis."""
task_type: Literal["classification", "regression"]
typ: Literal["tdc", "tabular", "polaris"]
sim_threshold: float = 0.9
fp_radius: int = 2
fp_nbits: int = 2048
smiles_col: Optional[str] = None
label_col: Optional[str] = None
id_col: Optional[str] = None
label_cols: Optional[List[str]] = None
sequence_col: Optional[str] = None
target_id_col: Optional[str] = None
name: Optional[str] = None
unique_sequences_jsonl: Optional[str] = None
foldseek_m8_path: Optional[str] = None
def __post_init__(self):
_validate_analyzer_config_values(self)
@dataclass
class AnalysisResult:
summary: Dict[str, Any]
per_record_df: pd.DataFrame
conflicts_rows: List[Dict[str, Any]]
cliffs_rows: List[Dict[str, Any]]
sequence_alignment_rows: Optional[List[Dict[str, Any]]] = None
structure_alignment_rows: Optional[List[Dict[str, Any]]] = None
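# A minimal, illustrative configuration (comment-only sketch; the column names are
# placeholders for whatever the input tables actually use; only task_type and typ
# are required):
#   cfg = AnalyzerConfig(
#       task_type="regression", typ="tabular",
#       smiles_col="smiles", label_col="y", id_col="mol_id",
#   )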
def _normalize_columns(df: pd.DataFrame, cfg: AnalyzerConfig, split: str) -> pd.DataFrame:
"""Ensure expected columns exist: "smiles_clean", "label_raw", "id", "split"."""
out = df.copy()
# Rename if needed
if "smiles_clean" not in out.columns and cfg.smiles_col in out.columns:
out = out.rename(columns={cfg.smiles_col: "smiles_clean"})
if "label_raw" not in out.columns and cfg.label_col in out.columns:
out = out.rename(columns={cfg.label_col: "label_raw"})
if "id" not in out.columns and cfg.id_col in out.columns:
out = out.rename(columns={cfg.id_col: "id"})
# Minimal guards
if "smiles_clean" not in out.columns:
raise ValueError("Missing 'smiles_clean' (or cfg.smiles_col). Provide cleaned SMILES or set config columns.")
if "label_raw" not in out.columns:
raise ValueError("Missing 'label_raw' (or cfg.label_col). Provide labels or set config columns.")
# Assign ids if needed
if "id" not in out.columns:
out["id"] = np.arange(len(out), dtype=np.int64)
out["split"] = split
return out
def morgan_fps(smiles_list: List[str], radius: int, n_bits: int) -> List[Optional[DataStructs.ExplicitBitVect]]:
"""Compute Morgan/ECFP fingerprints. Returns None for invalid SMILES."""
fps: List[Optional[DataStructs.ExplicitBitVect]] = []
for smi in smiles_list:
try:
mol = Chem.MolFromSmiles(smi)
if mol is None:
fps.append(None)
else:
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
fps.append(fp)
except Exception:
fps.append(None)
return fps
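# Illustrative usage of morgan_fps (comment-only sketch; SMILES are placeholders).
# Invalid SMILES yield a None slot so downstream indices stay aligned:
#   morgan_fps(["CCO", "not-a-smiles"], radius=2, n_bits=2048)
#   # -> [ExplicitBitVect, None]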
def _scaffold_fp_for_mol(mol: Optional[Chem.Mol], radius: int, n_bits: int) -> Optional[DataStructs.ExplicitBitVect]:
if mol is None:
return None
try:
try:
scaffold = _GraphFramework(mol)
except Exception:
scaffold = _GetScaffoldForMol(mol)
if scaffold is None:
return None
return AllChem.GetMorganFingerprintAsBitVect(scaffold, radius=radius, nBits=n_bits)
except Exception:
return None
def scaffold_fps(smiles_list: List[str], radius: int, n_bits: int) -> List[Optional[DataStructs.ExplicitBitVect]]:
fps: List[Optional[DataStructs.ExplicitBitVect]] = []
for smi in smiles_list:
mol = Chem.MolFromSmiles(smi)
fps.append(_scaffold_fp_for_mol(mol, radius, n_bits))
return fps
def _safe_exact_mol_wt(smiles: Any) -> Optional[float]:
"""Return exact molecular weight for valid SMILES, otherwise None."""
try:
if smiles is None or (isinstance(smiles, float) and np.isnan(smiles)):
return None
mol = Chem.MolFromSmiles(str(smiles))
if mol is None:
return None
return float(rdMolDescriptors.CalcExactMolWt(mol))
except Exception:
return None
def _nn_tanimoto_stats(
src_fps: List[Optional[DataStructs.ExplicitBitVect]],
qry_fps: List[Optional[DataStructs.ExplicitBitVect]],
) -> Dict[str, Any]:
"""For each query fp, compute max Tanimoto vs any source fp; return mean/std over queries.
If no valid comparisons exist, returns {'mean': None, 'std': None, 'n': 0}.
"""
src_valid = [fp for fp in src_fps if fp is not None]
if len(src_valid) == 0:
return {"mean": None, "std": None, "n": 0}
nn_vals: List[float] = []
for q in qry_fps:
if q is None:
continue
sims = DataStructs.BulkTanimotoSimilarity(q, src_valid)
if len(sims) == 0:
continue
nn_vals.append(max(sims))
if len(nn_vals) == 0:
return {"mean": None, "std": None, "n": 0}
arr = np.asarray(nn_vals, dtype=float)
return {"mean": float(arr.mean()), "std": float(arr.std()) if len(arr) > 1 else 0.0, "n": int(len(arr))}
def _pairs_above_thresh(
fpsA: List[Optional[DataStructs.ExplicitBitVect]],
fpsB: List[Optional[DataStructs.ExplicitBitVect]],
scafA: List[Optional[DataStructs.ExplicitBitVect]],
scafB: List[Optional[DataStructs.ExplicitBitVect]],
smiA: List[str],
smiB: List[str],
thr: float,
intra: bool,
) -> Set[Tuple[int, int]]:
"""Find index pairs (i, j) that are "similar" by MoleculeACE consensus:
- molecular ECFP Tanimoto >= thr OR
- scaffold ECFP Tanimoto >= thr OR
- normalized SMILES Levenshtein similarity >= thr
Intra => enforce j > i to avoid dup/self.
"""
pairs: Set[Tuple[int, int]] = set()
for i, (fa, fsa, sa) in enumerate(zip(fpsA, scafA, smiA)):
j_start = i + 1 if intra else 0
if fa is None and fsa is None:
# no fingerprint information for this record; no similarity test below can succeed
continue
for j in range(j_start, len(fpsB)):
fb, fsb, sb = fpsB[j], scafB[j], smiB[j]
# skip if both smiles are identical (handled as conflicts elsewhere)
if intra and sa == sb:
continue
tani_ok = False
leve_ok = False
scaf_ok = False
if fa is not None and fb is not None:
tani_ok = DataStructs.TanimotoSimilarity(fa, fb) >= thr
leve_ok = _levenshtein(sa, sb) >= thr
if fsa is not None and fsb is not None:
scaf_ok = DataStructs.TanimotoSimilarity(fsa, fsb) >= thr
if tani_ok or scaf_ok or leve_ok:
# print(f"Pair found {tani_ok} {scaf_ok} {leve_ok} {thr} {sa} {sb}")
pairs.add((i, j))
return pairs
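# Illustrative call (comment-only sketch; the argument names are placeholders): a
# pair survives if ANY of the three metrics reaches the threshold, e.g. with the
# same lists passed on both sides for an intra-split scan:
#   _pairs_above_thresh(fps, fps, scafs, scafs, smiles, smiles, thr=0.9, intra=True)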
@dataclass(frozen=True)
class StretcherAlignment:
score: float
identity_pct: float
similarity_pct: float
length: int
gaps_pct: float
n_gaps: int
aligned_query: str
aligned_subject: str
query_start: int
query_end: int
subject_start: int
subject_end: int
class PSAStretcherAligner:
"""Thin wrapper around psa.stretcher with caching."""
def __init__(self):
if psa is None:
raise ImportError(
"pairwise-sequence-alignment is required. "
"Install EMBOSS (`sudo apt install emboss`) and the Python wrapper "
"(`pip install pairwise-sequence-alignment`) before running DTI analysis."
)
self.moltype = "prot"
self._cache: Dict[Tuple[str, str], StretcherAlignment] = {}
def _normalize_seq(self, seq: str) -> str:
if not isinstance(seq, str):
seq = "" if pd.isna(seq) else str(seq)
return "".join(seq.split()).upper()
def _invert_alignment(self, aln: StretcherAlignment) -> StretcherAlignment:
return StretcherAlignment(
score=aln.score,
identity_pct=aln.identity_pct,
similarity_pct=aln.similarity_pct,
length=aln.length,
gaps_pct=aln.gaps_pct,
n_gaps=aln.n_gaps,
aligned_query=aln.aligned_subject,
aligned_subject=aln.aligned_query,
query_start=aln.subject_start,
query_end=aln.subject_end,
subject_start=aln.query_start,
subject_end=aln.query_end,
)
def _empty_alignment(self, q: str, s: str) -> StretcherAlignment:
length = max(len(q), len(s))
return StretcherAlignment(
score=0.0,
identity_pct=0.0,
similarity_pct=0.0,
length=length,
gaps_pct=100.0 if length > 0 else 0.0,
n_gaps=length,
aligned_query=q or "-" * length,
aligned_subject=s or "-" * length,
query_start=1 if q else 0,
query_end=len(q),
subject_start=1 if s else 0,
subject_end=len(s),
)
def align(self, query_seq: str, subject_seq: str) -> StretcherAlignment:
q = self._normalize_seq(query_seq)
s = self._normalize_seq(subject_seq)
key = (q, s)
if key in self._cache:
return self._cache[key]
rev_key = (s, q)
if rev_key in self._cache:
flipped = self._invert_alignment(self._cache[rev_key])
self._cache[key] = flipped
return flipped
if not q or not s:
empty_aln = self._empty_alignment(q, s)
self._cache[key] = empty_aln
self._cache[rev_key] = self._invert_alignment(empty_aln)
return empty_aln
aln = psa.stretcher(
moltype=self.moltype,
qseq=q,
sseq=s,
)
result = StretcherAlignment(
score=float(aln.score),
identity_pct=float(aln.pidentity),
similarity_pct=float(getattr(aln, "psimilarity", float("nan"))),
length=int(aln.length),
gaps_pct=float(aln.pgaps),
n_gaps=int(aln.ngaps),
aligned_query=str(aln.qseq),
aligned_subject=str(aln.sseq),
query_start=int(aln.qstart),
query_end=int(aln.qend),
subject_start=int(aln.sstart),
subject_end=int(aln.send),
)
self._cache[key] = result
self._cache[rev_key] = self._invert_alignment(result)
return result
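# Illustrative usage (comment-only sketch; the sequences are placeholders and EMBOSS
# plus the `pairwise-sequence-alignment` wrapper must be installed):
#   aligner = PSAStretcherAligner()
#   aln = aligner.align("MKTAYIAKQR", "MKTAYIVKQR")
#   aln.identity_pct, aln.score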
def _nn_sequence_alignment_stats(
ref_sequences: Set[str],
qry_sequences: Set[str],
aligner: PSAStretcherAligner,
ref_split: str,
qry_split: str,
) -> Tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
if not ref_sequences or not qry_sequences:
return (
{
"mean_identity_pct": None,
"std_identity_pct": None,
"mean_similarity_pct": None,
"std_similarity_pct": None,
"mean_score": None,
"std_score": None,
"n": 0,
},
[],
)
ref_list = sorted(ref_sequences)
qry_list = sorted(qry_sequences)
identities: List[float] = []
scores: List[float] = []
similarities: List[float] = []
details: List[Dict[str, Any]] = []
for seq_q in qry_list:
best_alignment: Optional[StretcherAlignment] = None
best_seq = None
for seq_r in ref_list:
aln = aligner.align(seq_q, seq_r)
if (
best_alignment is None
or aln.identity_pct > best_alignment.identity_pct
or (
aln.identity_pct == best_alignment.identity_pct
and aln.score > best_alignment.score
)
):
best_alignment = aln
best_seq = seq_r
if best_alignment is not None and best_alignment.identity_pct >= 99.99:
break
if best_alignment is None or best_seq is None:
continue
identities.append(float(best_alignment.identity_pct))
similarities.append(float(best_alignment.similarity_pct))
scores.append(float(best_alignment.score))
details.append({
"split_reference": ref_split,
"split_query": qry_split,
"sequence_reference": best_seq,
"sequence_query": seq_q,
"identity_pct": float(best_alignment.identity_pct),
"similarity_pct": float(best_alignment.similarity_pct),
"score": float(best_alignment.score),
"alignment_length": int(best_alignment.length),
"gaps_pct": float(best_alignment.gaps_pct),
"n_gaps": int(best_alignment.n_gaps),
"aligned_query": best_alignment.aligned_query,
"aligned_reference": best_alignment.aligned_subject,
"query_start": int(best_alignment.query_start),
"query_end": int(best_alignment.query_end),
"reference_start": int(best_alignment.subject_start),
"reference_end": int(best_alignment.subject_end),
})
if not identities:
return (
{
"mean_identity_pct": None,
"std_identity_pct": None,
"mean_similarity_pct": None,
"std_similarity_pct": None,
"mean_score": None,
"std_score": None,
"n": 0,
},
details,
)
arr_id = np.asarray(identities, dtype=float)
arr_score = np.asarray(scores, dtype=float)
arr_sim = np.asarray(similarities, dtype=float)
finite_sim = arr_sim[~np.isnan(arr_sim)]
if finite_sim.size > 0:
mean_sim = float(finite_sim.mean())
std_sim = float(finite_sim.std()) if finite_sim.size > 1 else 0.0
else:
mean_sim = None
std_sim = None
stats = {
"mean_identity_pct": float(arr_id.mean()),
"std_identity_pct": float(arr_id.std()) if len(arr_id) > 1 else 0.0,
"mean_similarity_pct": mean_sim,
"std_similarity_pct": std_sim,
"mean_score": float(arr_score.mean()),
"std_score": float(arr_score.std()) if len(arr_score) > 1 else 0.0,
"n": int(len(arr_id)),
}
return stats, details
def _compute_sigma3(tv_labels: pd.Series):
"""Return (3 * std, std) for regression labels. Supports scalar or vector labels."""
if tv_labels.empty:
return 0.0, 0.0
if _series_has_sequence_labels(tv_labels):
matrix = _label_series_to_matrix(tv_labels)
if matrix.size == 0 or matrix.shape[0] < 2:
return [0.0] * matrix.shape[1], [0.0] * matrix.shape[1]
std = np.nanstd(matrix, axis=0, ddof=1)
sigma3 = (std * 3.0).tolist()
return sigma3, std.tolist()
numeric = pd.to_numeric(tv_labels, errors="coerce").dropna()
if len(numeric) < 2:
return 0.0, 0.0
std = float(np.std(numeric.astype(float), ddof=1))
return 3.0 * std, std
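# Illustrative behaviour (comment-only sketch; values are placeholders). With scalar
# labels the sample standard deviation (ddof=1) is used:
#   _compute_sigma3(pd.Series([1.0, 2.0, 3.0]))  ->  (3.0, 1.0)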
def _intra_conflict_smiles(df: pd.DataFrame, is_cls: bool, sigma3: Optional[float]) -> Set[str]:
"""Find same-SMILES conflicts within a single split."""
conflicts: Set[str] = set()
if df is None or df.empty:
return conflicts
for smi, g in df.groupby("smiles_clean"):
if g.empty:
continue
ys = g["label_raw"].tolist()
if is_cls:
tuples = [_label_to_tuple(v) for v in ys]
if len(set(tuples)) > 1:
conflicts.add(smi)
else:
if len(ys) >= 2 and sigma3 is not None:
found = False
for i in range(len(ys)):
if found:
break
for j in range(i + 1, len(ys)):
delta = _regression_delta(ys[i], ys[j])
if _delta_exceeds_threshold(delta, sigma3):
conflicts.add(smi)
found = True
break
return conflicts
def _cross_conflict_smiles(A: pd.DataFrame, B: pd.DataFrame, is_cls: bool, sigma3: Optional[float]) -> Set[str]:
"""Find same-SMILES conflicts across two splits."""
merged = A[["smiles_clean", "label_raw"]].merge(
B[["smiles_clean", "label_raw"]], on="smiles_clean", suffixes=("_A", "_B")
)
if merged.empty:
return set()
if is_cls:
tuples_a = merged["label_raw_A"].apply(_label_to_tuple)
tuples_b = merged["label_raw_B"].apply(_label_to_tuple)
mask = [a != b for a, b in zip(tuples_a.tolist(), tuples_b.tolist())]
return set(merged.loc[mask, "smiles_clean"].unique())
else:
if sigma3 is None:
return set()
mask = [
_delta_exceeds_threshold(_regression_delta(a, b), sigma3)
for a, b in zip(merged["label_raw_A"].tolist(), merged["label_raw_B"].tolist())
]
return set(merged.loc[mask, "smiles_clean"].unique())
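# Illustrative call (comment-only sketch; the frame names are placeholders): a
# cross-split conflict is the same cleaned SMILES appearing in both frames with
# labels that differ (classification) or whose difference exceeds the 3-sigma
# threshold (regression):
#   _cross_conflict_smiles(train_df, test_df, is_cls=False, sigma3=3.0)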
def _build_conflict_rows(tag: str, smi_set: Set[str], *dfs: pd.DataFrame) -> List[Dict[str, Any]]:
rows: List[Dict[str, Any]] = []
for df in dfs:
if df.empty:
continue
sub = df[df["smiles_clean"].isin(smi_set)]
for _, r in sub.iterrows():
rows.append({
"kind": tag,
"split": r["split"],
"id": _to_python_scalar(r["id"]),
"smiles_clean": r["smiles_clean"],
"label_raw": r["label_raw"],
})
return rows
def _cliff_pairs(
A: pd.DataFrame,
B: pd.DataFrame,
fpsA: List[Optional[DataStructs.ExplicitBitVect]],
fpsB: List[Optional[DataStructs.ExplicitBitVect]],
scafA: List[Optional[DataStructs.ExplicitBitVect]],
scafB: List[Optional[DataStructs.ExplicitBitVect]],
thr: float,
is_cls: bool,
sigma3: Optional[float],
intra: bool,
) -> List[Dict[str, Any]]:
"""Find cliff pairs between A and B and return detailed rows.
Similarity uses MoleculeACE consensus (see _pairs_above_thresh):
- molecular ECFP Tanimoto
- scaffold (generic Murcko) ECFP Tanimoto
- normalized SMILES Levenshtein similarity
A row is emitted only if labels differ (classification) or |Δ| ≥ 3σ (regression).
"""
rows: List[Dict[str, Any]] = []
smiA = A["smiles_clean"].tolist()
smiB = B["smiles_clean"].tolist()
# Candidate similar pairs by consensus (any metric ≥ thr)
pairs = _pairs_above_thresh(fpsA, fpsB, scafA, scafB, smiA, smiB, thr, intra=intra)
for i, j in pairs:
ra, rb = A.iloc[i], B.iloc[j]
# Decide if this pair is a cliff
if is_cls:
labels_a = _label_to_tuple(ra["label_raw"])
labels_b = _label_to_tuple(rb["label_raw"])
if labels_a == labels_b:
continue
delta = _classification_delta(labels_a, labels_b)
else:
if sigma3 is None:
continue
delta = _regression_delta(ra["label_raw"], rb["label_raw"])
if not _delta_exceeds_threshold(delta, sigma3):
continue
fa, fb = fpsA[i], fpsB[j]
fsa, fsb = scafA[i], scafB[j]
tanimoto = (
float(DataStructs.TanimotoSimilarity(fa, fb))
if (fa is not None and fb is not None) else float("nan")
)
scaffold_tanimoto = (
float(DataStructs.TanimotoSimilarity(fsa, fsb))
if (fsa is not None and fsb is not None) else float("nan")
)
levenshtein_sim = float(_levenshtein(ra["smiles_clean"], rb["smiles_clean"]))
rows.append({
"kind": "intra" if intra and A is B else "cross",
"id_A": _to_python_scalar(ra["id"]),
"split_A": ra["split"],
"smiles_A": ra["smiles_clean"],
"y_A": ra["label_raw"],
"id_B": _to_python_scalar(rb["id"]),
"split_B": rb["split"],
"smiles_B": rb["smiles_clean"],
"y_B": rb["label_raw"],
"tanimoto": tanimoto, # molecular ECFP Tanimoto
"scaffold_tanimoto": scaffold_tanimoto, # generic Murcko ECFP Tanimoto
"levenshtein_sim": levenshtein_sim, # SMILES Levenshtein ratio
"delta": _to_python_scalar(delta),
})
return rows
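# Each returned row pairs two similar records with their three similarity values and
# the label delta; one way to inspect them (comment-only sketch, names are
# placeholders) is:
#   pd.DataFrame(_cliff_pairs(tv_df, test_df, fps_tv, fps_te, scaf_tv, scaf_te,
#                             thr=0.9, is_cls=False, sigma3=3.0, intra=False))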
class SMILESAnalyzer:
"""Simplified, modular SMILES analyzer (MoleculeACE-style similarity)."""
def __init__(self, cfg: AnalyzerConfig, logger: Optional[logging.Logger] = None):
self.cfg = cfg
self.log = logger or logging.getLogger(__name__)
def _featurize_for_similarity(self, smiles: List[str]):
"""Return tuple of (molecular_fps, scaffold_fps) using same fp settings."""
mol_fps = morgan_fps(smiles, self.cfg.fp_radius, self.cfg.fp_nbits)
scaf_fps = scaffold_fps(smiles, self.cfg.fp_radius, self.cfg.fp_nbits)
return mol_fps, scaf_fps
def run(self, splits_raw: Dict[str, pd.DataFrame]) -> AnalysisResult:
self.log.info("Starting SMILES analysis.")
# Normalize columns + tag split
splits: Dict[str, pd.DataFrame] = {}
for split in splits_raw.keys():
splits[split] = _normalize_columns(splits_raw[split], self.cfg, split)
self.log.info("Split %s: n=%d", split, len(splits[split]))
train_df, test_df = splits["train"], splits["test"]
valid_df = splits.get("valid")
if valid_df is not None:
tv_df = pd.concat([train_df, valid_df], ignore_index=True)
else:
tv_df = train_df
is_cls = (self.cfg.task_type == "classification")
# Hygiene: duplicates/contamination + REOS stats if present
all_smiles = tv_df["smiles_clean"].tolist() + test_df["smiles_clean"].tolist()
n_dup = len(all_smiles) - len(set(all_smiles))
contaminated = set(tv_df["smiles_clean"]) & set(test_df["smiles_clean"])
if "n_reos_warnings" in tv_df.columns:
reos_mean = float(pd.to_numeric(tv_df["n_reos_warnings"], errors="coerce").fillna(0).mean())
reos_std = float(pd.to_numeric(tv_df["n_reos_warnings"], errors="coerce").fillna(0).std(ddof=1)) if len(tv_df) > 1 else 0.0
else:
reos_mean, reos_std = 0.0, 0.0
self.log.info(
"Hygiene: duplicates=%d contaminated=%d reos_mean=%.3f reos_std=%.3f",
n_dup, len(contaminated), reos_mean, reos_std,
)
# Statistics
label_series = tv_df["label_raw"]
if _series_has_sequence_labels(label_series):
label_matrix = _label_series_to_matrix(label_series)
if label_matrix.size == 0:
tv_mean: Any = []
base_std: Any = []
else:
tv_mean = np.nanmean(label_matrix, axis=0).tolist()
if label_matrix.shape[0] > 1:
base_std = np.nanstd(label_matrix, axis=0, ddof=1).tolist()
else:
base_std = [0.0] * label_matrix.shape[1]
else:
numeric = pd.to_numeric(label_series, errors="coerce")
numeric_clean = numeric.dropna()
if numeric_clean.empty:
tv_mean = None
base_std = 0.0
else:
mean_val = numeric_clean.mean()
tv_mean = float(mean_val) if pd.notna(mean_val) else None
base_std = float(numeric_clean.std(ddof=1)) if len(numeric_clean) > 1 else 0.0
if not is_cls:
sigma3, std_vals = _compute_sigma3(label_series)
tv_std = std_vals
self.log.info("Label stats: mean=%s std=%s 3σ=%s", tv_mean, tv_std, sigma3)
else:
sigma3 = None
tv_std = base_std
mol_tv_size = [_safe_exact_mol_wt(smi) for smi in tv_df["smiles_clean"]]
mol_tv_size_valid = [mw for mw in mol_tv_size if mw is not None]
if len(mol_tv_size_valid) == 0:
mol_tv_size_mean = None
mol_tv_size_std = None
else:
mol_arr = np.asarray(mol_tv_size_valid, dtype=float)
mol_tv_size_mean = float(mol_arr.mean())
mol_tv_size_std = float(mol_arr.std())
# Features for similarity (TV and Test): molecular + scaffold
fps_tv_mol, fps_tv_scaf = self._featurize_for_similarity(tv_df["smiles_clean"].tolist())
fps_te_mol, fps_te_scaf = self._featurize_for_similarity(test_df["smiles_clean"].tolist())
# Also featurize individual splits for NN similarity reports
fps_tr_mol, _ = self._featurize_for_similarity(train_df["smiles_clean"].tolist())
if valid_df is not None:
fps_va_mol, _ = self._featurize_for_similarity(valid_df["smiles_clean"].tolist())
else:
fps_va_mol = []
# Nearest-neighbor ECFP Tanimoto similarity summaries
sim_valid_to_train = _nn_tanimoto_stats(fps_tr_mol, fps_va_mol) if valid_df is not None else None
sim_test_to_train = _nn_tanimoto_stats(fps_tr_mol, fps_te_mol)
sim_test_to_tv = _nn_tanimoto_stats(fps_tv_mol, fps_te_mol)
# Conflicts (same cleaned SMILES)
intra_train = _intra_conflict_smiles(train_df, is_cls, sigma3)
intra_valid = _intra_conflict_smiles(valid_df, is_cls, sigma3) if valid_df is not None else set()
intra_test = _intra_conflict_smiles(test_df, is_cls, sigma3)
cross_tv = _cross_conflict_smiles(train_df, valid_df, is_cls, sigma3) if valid_df is not None else set()
cross_tt = _cross_conflict_smiles(train_df, test_df, is_cls, sigma3)
cross_vt = _cross_conflict_smiles(valid_df, test_df, is_cls, sigma3) if valid_df is not None else set()
severe_tv_test = _cross_conflict_smiles(tv_df, test_df, is_cls, sigma3)
self.log.info(
"Conflicts: intra_train=%d intra_valid=%d intra_test=%d cross_tv=%d cross_tt=%d cross_vt=%d severe_tv_test=%d",
len(intra_train), len(intra_valid), len(intra_test), len(cross_tv), len(cross_tt), len(cross_vt), len(severe_tv_test)
)
# Cliffs (consensus similar-but-different molecules with label delta)
intra_tv_rows = _cliff_pairs(
tv_df, tv_df,
fps_tv_mol, fps_tv_mol,
fps_tv_scaf, fps_tv_scaf,
self.cfg.sim_threshold, is_cls, sigma3, intra=True
)
if self.cfg.typ != "polaris":
intra_te_rows = _cliff_pairs(
test_df, test_df,
fps_te_mol, fps_te_mol,
fps_te_scaf, fps_te_scaf,
self.cfg.sim_threshold, is_cls, sigma3, intra=True
)
cross_rows = _cliff_pairs(
tv_df, test_df,
fps_tv_mol, fps_te_mol,
fps_tv_scaf, fps_te_scaf,
self.cfg.sim_threshold, is_cls, sigma3, intra=False
)
else:
intra_te_rows = []
cross_rows = []
self.log.info(
"Cliffs: intra_tv=%d intra_te=%d cross=%d",
len(intra_tv_rows), len(intra_te_rows), len(cross_rows)
)
# 7) Aggregate summary
summary: Dict[str, Any] = {
"counts": {
"train": int(len(train_df)),
"valid": int(len(valid_df)) if valid_df is not None else None,
"test": int(len(test_df))
},
"hygiene": {
"n_all_valid_smiles": int(len(all_smiles)),
"n_unique_valid_smiles": int(len(set(all_smiles))),
"n_duplicate_valid_smiles": int(n_dup),
"n_contaminated_tv_vs_test": int(len(contaminated)),
"reos_mean": float(reos_mean),
"reos_std": float(reos_std),
},
"similarity": {
"valid_to_train": (
{"mean": sim_valid_to_train["mean"], "std": sim_valid_to_train["std"], "n": sim_valid_to_train["n"]}
if sim_valid_to_train is not None else None
),
"test_to_train": {
"mean": sim_test_to_train["mean"], "std": sim_test_to_train["std"], "n": sim_test_to_train["n"],
},
"test_to_trainvalid": {
"mean": sim_test_to_tv["mean"], "std": sim_test_to_tv["std"], "n": sim_test_to_tv["n"],
},
},
"task": {
"type": self.cfg.task_type,
"label_tv_mean": _to_python_scalar(tv_mean),
"label_tv_std": _to_python_scalar(tv_std),
"label_tv_3sigma": _to_python_scalar(sigma3) if sigma3 is not None else None,
"mol_tv_size_mean": _to_python_scalar(mol_tv_size_mean),
"mol_tv_size_std": _to_python_scalar(mol_tv_size_std),
},
"conflicts": {
"intra_train": int(len(intra_train)),
"intra_valid": int(len(intra_valid)),
"intra_test": int(len(intra_test)),
"cross_train_valid": int(len(cross_tv)),
"cross_train_test": int(len(cross_tt)),
"cross_valid_test": int(len(cross_vt)),
"severe_trainvalid_test": int(len(severe_tv_test)),
},
"cliffs": {
"intra_train_valid": int(len(intra_tv_rows)),
"intra_test": int(len(intra_te_rows)) if intra_te_rows != [] else None,
"cross_tv_test": int(len(cross_rows)) if cross_rows != [] else None,
"sim_threshold": float(self.cfg.sim_threshold),
},
}
# 8) Per-record table for drill-down (concat in deterministic order)
per_record = pd.concat([train_df, valid_df, test_df], ignore_index=True)
# Conflict rows (detailed)
conflict_rows: List[Dict[str, Any]] = []
conflict_rows += _build_conflict_rows("intra_train", intra_train, train_df)
conflict_rows += _build_conflict_rows("intra_valid", intra_valid, valid_df) if valid_df is not None else []
conflict_rows += _build_conflict_rows("intra_test", intra_test, test_df)
conflict_rows += _build_conflict_rows("cross_train_valid", cross_tv, train_df, valid_df) if valid_df is not None else []
conflict_rows += _build_conflict_rows("cross_train_test", cross_tt, train_df, test_df)
conflict_rows += _build_conflict_rows("cross_valid_test", cross_vt, valid_df, test_df) if valid_df is not None else []
conflict_rows += _build_conflict_rows("severe_trainvalid_test", severe_tv_test, tv_df, test_df)
# Cliff rows are already detailed
cliff_rows: List[Dict[str, Any]] = []
cliff_rows += [{**r, "kind": "intra_tv"} for r in intra_tv_rows]
cliff_rows += [{**r, "kind": "intra_test"} for r in intra_te_rows]
cliff_rows += [{**r, "kind": "cross_tv_test"} for r in cross_rows]
self.log.info("SMILES analysis complete.")
return AnalysisResult(
summary=summary,
per_record_df=per_record,
conflicts_rows=conflict_rows,
cliffs_rows=cliff_rows,
sequence_alignment_rows=None,
)
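# Illustrative end-to-end usage (comment-only sketch; the dataframes are placeholders
# and must already carry cleaned SMILES and labels, see _normalize_columns):
#   analyzer = SMILESAnalyzer(cfg)
#   result = analyzer.run({"train": train_df, "valid": valid_df, "test": test_df})
#   result.summary["cliffs"]; result.per_record_df.head()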
class DTIAnalyzer:
"""Drug–target interaction analysis: combines molecular + sequence hygiene."""
_SEQ_COLUMN_CANDIDATES = (
"sequence_aa",
"sequence",
"target_sequence",
"protein_sequence",
"aa_sequence",
)
def __init__(self, cfg: AnalyzerConfig, logger: Optional[logging.Logger] = None):
self.cfg = cfg
self.log = logger or logging.getLogger(__name__)
self._smiles = SMILESAnalyzer(cfg, self.log)
self._aligner = PSAStretcherAligner()
def _prepare_split(self, split: str, df: pd.DataFrame) -> pd.DataFrame:
out = df.copy()
if "sequence_aa" not in out.columns:
candidates: List[str] = []
if self.cfg.sequence_col:
candidates.append(self.cfg.sequence_col)
candidates.extend([c for c in self._SEQ_COLUMN_CANDIDATES if c not in candidates])
for col in candidates:
if col in out.columns:
out = out.rename(columns={col: "sequence_aa"})
break
if "sequence_aa" not in out.columns:
raise ValueError(
f"Missing amino-acid sequence column for split '{split}'. "
"Provide 'sequence_aa' or set AnalyzerConfig.sequence_col."
)
seq_series = out["sequence_aa"].map(lambda x: "" if pd.isna(x) else str(x))
seq_series = seq_series.str.upper().str.replace(r"\s+", "", regex=True)
out["sequence_aa"] = seq_series
if self.cfg.target_id_col and self.cfg.target_id_col in out.columns:
out = out.rename(columns={self.cfg.target_id_col: "target_id"})
empty_count = int((out["sequence_aa"] == "").sum())
if empty_count > 0:
self.log.warning("Split %s has %d rows with empty target sequences.", split, empty_count)
return out
def _build_drug_summary(self, splits: Dict[str, pd.DataFrame], base_summary: Dict[str, Any]) -> Dict[str, Any]:
train_df = splits["train"]
valid_df = splits.get("valid")
test_df = splits["test"]
unique_counts = {
"train": int(train_df["smiles_clean"].nunique()),
"valid": int(valid_df["smiles_clean"].nunique()) if valid_df is not None else None,
"test": int(test_df["smiles_clean"].nunique()),
}
train_set = set(train_df["smiles_clean"])
valid_set = set(valid_df["smiles_clean"]) if valid_df is not None else set()
test_set = set(test_df["smiles_clean"])
shared_counts = {
"train_valid": int(len(train_set & valid_set)) if valid_df is not None else None,
"train_test": int(len(train_set & test_set)),
"valid_test": int(len(valid_set & test_set)) if valid_df is not None else None,
"trainvalid_test": int(len((train_set | valid_set) & test_set)),
}
hygiene_base = base_summary.get("hygiene", {})
hygiene = {
"n_all_smiles": int(hygiene_base.get("n_all_valid_smiles", 0)),
"n_unique_smiles": int(hygiene_base.get("n_unique_valid_smiles", 0)),
"n_duplicate_smiles": int(hygiene_base.get("n_duplicate_valid_smiles", 0)),
"n_contaminated_trainvalid_vs_test": int(hygiene_base.get("n_contaminated_tv_vs_test", 0)),
}
return {
"unique_counts": unique_counts,
"shared_counts": shared_counts,
"hygiene": hygiene,
"similarity": base_summary.get("similarity", {}),
}
def _analyze_sequences(
self,
splits: Dict[str, pd.DataFrame],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, int]]:
train_df = splits["train"]
valid_df = splits.get("valid")
test_df = splits["test"]
seq_train = set(train_df["sequence_aa"])
seq_valid = set(valid_df["sequence_aa"]) if valid_df is not None else set()
seq_test = set(test_df["sequence_aa"])
seq_train_valid = seq_train | seq_valid
unique_counts = {
"train": int(len(seq_train)),
"valid": int(len(seq_valid)) if valid_df is not None else None,
"test": int(len(seq_test)),
}
shared_counts = {
"train_valid": int(len(seq_train & seq_valid)) if valid_df is not None else None,
"train_test": int(len(seq_train & seq_test)),
"valid_test": int(len(seq_valid & seq_test)) if valid_df is not None else None,
"trainvalid_test": int(len(seq_train_valid & seq_test)),
}
total_sequences = len(train_df) + (len(valid_df) if valid_df is not None else 0) + len(test_df)
all_sequences: List[str] = train_df["sequence_aa"].tolist()
if valid_df is not None:
all_sequences.extend(valid_df["sequence_aa"].tolist())
all_sequences.extend(test_df["sequence_aa"].tolist())
total_unique = len(set(all_sequences))
duplicates_by_split = {
"train": int(len(train_df) - train_df["sequence_aa"].nunique()),
"valid": int(len(valid_df) - valid_df["sequence_aa"].nunique()) if valid_df is not None else None,
"test": int(len(test_df) - test_df["sequence_aa"].nunique()),
}
hygiene = {
"n_all_sequences": int(total_sequences),
"n_unique_sequences": int(total_unique),
"n_duplicate_sequences": int(total_sequences - total_unique),
"n_contaminated_trainvalid_vs_test": int(len(seq_train_valid & seq_test)),
"n_contaminated_train_valid": int(len(seq_train & seq_valid)) if valid_df is not None else None,
"duplicate_sequences_by_split": duplicates_by_split,
}
similarity: Dict[str, Optional[Dict[str, Any]]] = {}
alignment_rows: List[Dict[str, Any]] = []
if valid_df is not None and seq_valid:
stats_vt, details_vt = _nn_sequence_alignment_stats(seq_train, seq_valid, self._aligner, "train", "valid")
similarity["valid_to_train"] = stats_vt
alignment_rows.extend(details_vt)
else:
similarity["valid_to_train"] = None
stats_tt, details_tt = _nn_sequence_alignment_stats(seq_train, seq_test, self._aligner, "train", "test")
similarity["test_to_train"] = stats_tt
alignment_rows.extend(details_tt)
stats_ttv, details_ttv = _nn_sequence_alignment_stats(seq_train_valid, seq_test, self._aligner, "train_valid", "test")
similarity["test_to_trainvalid"] = stats_ttv
alignment_rows.extend(details_ttv)
if alignment_rows:
alignment_rows = sorted(
alignment_rows,
key=lambda r: (r["identity_pct"], r["score"]),
reverse=True,
)[:50]
pair_index: Dict[Tuple[str, str], Dict[str, Any]] = {}
sequence_index: Dict[str, Dict[str, Any]] = defaultdict(lambda: {"splits": set(), "smiles": set(), "rows": []})
for split_name, df in splits.items():
for _, row in df.iterrows():
smi = row["smiles_clean"]
seq = row["sequence_aa"]
pair_entry = pair_index.setdefault((smi, seq), {"splits": set(), "rows": []})
pair_entry["splits"].add(split_name)
pair_entry["rows"].append((split_name, row))
seq_entry = sequence_index[seq]
seq_entry["splits"].add(split_name)
seq_entry["smiles"].add(smi)
seq_entry["rows"].append((split_name, row))
pair_conflicts = {k: v for k, v in pair_index.items() if len(v["splits"]) > 1}
sequence_multi = {
seq: data
for seq, data in sequence_index.items()
if len(data["splits"]) > 1 and len(data["smiles"]) > 1
}
conflict_rows: List[Dict[str, Any]] = []
for (smi, seq), data in pair_conflicts.items():
for split_name, row in data["rows"]:
rec = {
"kind": "dti_pair_cross_split",
"split": split_name,
"id": _to_python_scalar(row["id"]),
"smiles_clean": smi,
"sequence_aa": seq,
"label_raw": _to_python_scalar(row["label_raw"]),
}
if "target_id" in row and not pd.isna(row["target_id"]):
rec["target_id"] = _to_python_scalar(row["target_id"])
conflict_rows.append(rec)
for seq, data in sequence_multi.items():
for split_name, row in data["rows"]:
rec = {
"kind": "sequence_multi_smiles_cross_split",
"split": split_name,
"id": _to_python_scalar(row["id"]),
"smiles_clean": row["smiles_clean"],
"sequence_aa": seq,
"label_raw": _to_python_scalar(row["label_raw"]),
}
if "target_id" in row and not pd.isna(row["target_id"]):
rec["target_id"] = _to_python_scalar(row["target_id"])
conflict_rows.append(rec)
seq_multi_examples = sorted(
(
{
"sequence": seq,
"n_unique_smiles": len(data["smiles"]),
"splits": sorted(data["splits"]),
}
for seq, data in sequence_multi.items()
),
key=lambda item: item["n_unique_smiles"],
reverse=True,
)[:5]
sequence_summary = {
"unique_counts": unique_counts,
"shared_counts": shared_counts,
"hygiene": hygiene,
"similarity": similarity,
"conflicts": {
"cross_split_pairs": int(len(pair_conflicts)),
"sequence_multi_smiles": int(len(sequence_multi)),
},
"examples": {
"sequence_multi_smiles": seq_multi_examples,
},
}
self.log.info(
"Target sequences: unique train=%s valid=%s test=%s shared(train-test)=%d shared(tv-test)=%d pair_conflicts=%d",
unique_counts["train"],
unique_counts["valid"],
unique_counts["test"],
shared_counts["train_test"],
shared_counts["trainvalid_test"],
len(pair_conflicts),
)
conflict_counts = {
"pair_conflicts": int(len(pair_conflicts)),
"sequence_multi_smiles": int(len(sequence_multi)),
}
return sequence_summary, conflict_rows, alignment_rows, conflict_counts
def _analyze_structures_foldseek(
self,
splits: Dict[str, pd.DataFrame],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Structural leakage analysis using Foldseek, mirroring _analyze_sequences.
Semantics:
- valid -> train
- test -> train
- test -> train_valid (train ∪ valid)
Uses nearest-neighbor structural similarity (Foldseek probability).
"""
# ------------------------------------------------------------------
# 1) Guard: skip if Foldseek inputs not provided
# ------------------------------------------------------------------
if not self.cfg.unique_sequences_jsonl or not self.cfg.foldseek_m8_path:
self.log.info("Foldseek inputs not configured; skipping structural analysis.")
return (
{
"similarity": {
"valid_to_train": None,
"test_to_train": None,
"test_to_trainvalid": None,
},
"foldseek_coverage": {},
},
[],
)
# ------------------------------------------------------------------
# 2) Build split sets (EXACTLY like _analyze_sequences)
# ------------------------------------------------------------------
train_df = splits["train"]
valid_df = splits.get("valid")
test_df = splits["test"]
seq_train = set(train_df["sequence_aa"])
seq_valid = set(valid_df["sequence_aa"]) if valid_df is not None else set()
seq_test = set(test_df["sequence_aa"])
seq_train_valid = seq_train | seq_valid
# ------------------------------------------------------------------
# 3) Load Foldseek sequence IDs for THIS dataset
# ------------------------------------------------------------------
seq_to_seqid: Dict[str, str] = {}
seqid_to_seq: Dict[str, str] = {}
dataset_name = self.cfg.name
with open(self.cfg.unique_sequences_jsonl) as f:
for line in f:
if not line.strip():
continue
obj = json.loads(line)
sid = obj["sequence_id"]
seq = obj["sequence"]
sources = obj.get("sources", []) or []
matched = False
for src in sources:
dataset = src.get("dataset")
if dataset_name and dataset != dataset_name:
continue
if seq not in seq_to_seqid:
seq_to_seqid[seq] = sid
if sid not in seqid_to_seq:
seqid_to_seq[sid] = seq
matched = True
break
if not matched and not dataset_name and seq not in seq_to_seqid:
seq_to_seqid[seq] = sid
if sid not in seqid_to_seq:
seqid_to_seq[sid] = seq
def to_ids(seqs: Set[str]) -> Set[str]:
return {seq_to_seqid[s] for s in seqs if s in seq_to_seqid}
train_ids = to_ids(seq_train)
valid_ids = to_ids(seq_valid)
test_ids = to_ids(seq_test)
train_valid_ids = to_ids(seq_train_valid)
# ------------------------------------------------------------------
# 4) NN search helper (inline, like _nn_sequence_alignment_stats)
# ------------------------------------------------------------------
def _nn_structure_alignment_stats(ref_ids: Set[str], qry_ids: Set[str], ref_split: str, qry_split: str):
best = {}
details = []
with open(self.cfg.foldseek_m8_path) as f:
for line in f:
parts = line.rstrip("\n").split("\t")
if len(parts) < 11:
continue
qid, tid = parts[0], parts[1]
if qid not in qry_ids or tid not in ref_ids:
continue
prob = float(parts[3])
score = float(parts[2])
evalue = float(parts[4])
if (
qid not in best
or prob > best[qid]["probability"]
or (prob == best[qid]["probability"] and score > best[qid]["alignment_score"])
):
best[qid] = {
"split_reference": ref_split,
"split_query": qry_split,
"query_id": qid,
"reference_id": tid,
"probability": prob,
"alignment_score": score,
"evalue": evalue,
"query_start": int(parts[5]),
"query_end": int(parts[6]),
"query_length": int(parts[7]),
"reference_start": int(parts[8]),
"reference_end": int(parts[9]),
"reference_length": int(parts[10]),
}
if not best:
return (
{
"mean_probability": None,
"std_probability": None,
"mean_alignment_score": None,
"std_alignment_score": None,
"mean_evalue": None,
"std_evalue": None,
"n": 0,
},
[],
)
probs = [v["probability"] for v in best.values()]
scores = [v["alignment_score"] for v in best.values()]
evals = [v["evalue"] for v in best.values()]
stats = {
"mean_probability": float(np.mean(probs)),
"std_probability": float(np.std(probs)) if len(probs) > 1 else 0.0,
"mean_alignment_score": float(np.mean(scores)),
"std_alignment_score": float(np.std(scores)) if len(scores) > 1 else 0.0,
"mean_evalue": float(np.mean(evals)),
"std_evalue": float(np.std(evals)) if len(evals) > 1 else 0.0,
"n": int(len(probs)),
}
return stats, list(best.values())
# ------------------------------------------------------------------
# 5) Run NN checks (exactly like sequence)
# ------------------------------------------------------------------
similarity = {}
alignment_rows = []
if valid_ids:
stats_vt, rows_vt = _nn_structure_alignment_stats(train_ids, valid_ids, "train", "valid")
similarity["valid_to_train"] = stats_vt
alignment_rows.extend(rows_vt)
else:
similarity["valid_to_train"] = None
stats_tt, rows_tt = _nn_structure_alignment_stats(train_ids, test_ids, "train", "test")
similarity["test_to_train"] = stats_tt
alignment_rows.extend(rows_tt)
stats_ttv, rows_ttv = _nn_structure_alignment_stats(train_valid_ids, test_ids, "train_valid", "test")
similarity["test_to_trainvalid"] = stats_ttv
alignment_rows.extend(rows_ttv)
# ------------------------------------------------------------------
# 6) Keep top-50 examples (same rule as sequence)
# ------------------------------------------------------------------
if alignment_rows:
ranked = sorted(
alignment_rows,
key=lambda r: (r["probability"], r["alignment_score"]),
reverse=True,
)
filtered: List[Dict[str, Any]] = []
for row in ranked:
q_seq = seqid_to_seq.get(row["query_id"])
r_seq = seqid_to_seq.get(row["reference_id"])
if q_seq is not None and r_seq is not None and q_seq == r_seq:
continue
filtered.append(row)
if len(filtered) >= 50:
break
alignment_rows = filtered
structure_summary = {
"similarity": similarity,
"foldseek_coverage": {
"n_train_sequences_in_foldseek": len(train_ids),
"n_valid_sequences_in_foldseek": len(valid_ids),
"n_test_sequences_in_foldseek": len(test_ids),
"n_train_valid_sequences_in_foldseek": len(train_valid_ids),
},
}
return structure_summary, alignment_rows
def run(self, splits_raw: Dict[str, pd.DataFrame]) -> AnalysisResult:
self.log.info("Starting DTI analysis.")
prepared: Dict[str, pd.DataFrame] = {}
for split, df in splits_raw.items():
prepared[split] = self._prepare_split(split, df)
smiles_result = self._smiles.run(prepared)
normalized: Dict[str, pd.DataFrame] = {}
for split, df in prepared.items():
normalized[split] = _normalize_columns(df, self.cfg, split)
sequence_summary, seq_conflict_rows, alignment_rows, conflict_counts = self._analyze_sequences(normalized)
structure_summary, structure_alignment_rows = self._analyze_structures_foldseek(normalized)
summary = copy.deepcopy(smiles_result.summary)
summary["drugs"] = self._build_drug_summary(normalized, summary)
summary["targets"] = sequence_summary
summary["targets"]["structures"] = structure_summary
conflicts_block = summary.setdefault("conflicts", {})
conflicts_block["cross_split_dti_pairs"] = conflict_counts["pair_conflicts"]
conflicts_block["sequence_cross_split_multi_smiles"] = conflict_counts["sequence_multi_smiles"]
combined_conflict_rows = list(smiles_result.conflicts_rows)
combined_conflict_rows.extend(seq_conflict_rows)
self.log.info("DTI analysis complete.")
return AnalysisResult(
summary=summary,
per_record_df=smiles_result.per_record_df,
conflicts_rows=combined_conflict_rows,
cliffs_rows=smiles_result.cliffs_rows,
sequence_alignment_rows=alignment_rows,
structure_alignment_rows=structure_alignment_rows,
)
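# Illustrative end-to-end usage (comment-only sketch; the dataframes are placeholders
# and must contain SMILES, labels and an amino-acid sequence column; EMBOSS/psa is
# required and the Foldseek inputs are optional):
#   dti = DTIAnalyzer(cfg)
#   result = dti.run({"train": train_df, "valid": valid_df, "test": test_df})
#   result.summary["targets"]["similarity"]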