# Source code for utils.rank_fragility.config
"""Configuration dataclasses for rank-fragility analysis."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal
# Closed set of supported task kinds; every ``task`` field in this module is
# annotated with this alias.
Task = Literal[
    "classification",
    "regression",
]
@dataclass(frozen=True)
class AuditConfig:
    """Column names and thresholds used to annotate audited molecules.

    Groups the input-table column names, the task kind, the near-leak
    thresholds, the regression conflict threshold(s) and a random seed.
    Instances are immutable (``frozen=True``).
    """

    # Input table column names.
    id_col: str = "molecule_id"
    smiles_col: str = "smiles"
    label_col: str = "y"
    split_col: str = "split"

    # Task kind ("classification" or "regression").
    task: Task = "classification"

    # Near-leak thresholds. The primary threshold is stored separately;
    # NOTE(review): presumably expected to be one of near_leak_thresholds,
    # but that is not enforced here.
    near_leak_thresholds: tuple[float, ...] = (
        0.85,
        0.90,
    )
    primary_near_leak_threshold: float = 0.85

    # Regression label-conflict threshold, plus an optional second value
    # (None means "not set"); presumably used for a sensitivity check.
    regression_conflict_threshold: float = 1.0
    regression_conflict_threshold_sensitivity: float | None = None

    # Seed for any randomized step driven by this config.
    random_seed: int = 13
@dataclass(frozen=True)
class PanelConfig:
    """Sampling controls for generated counterfactual evaluation panels.

    Holds the id/label column names, the task kind, how many panels to draw
    and at what size, the requested target rates, a random seed and the
    output directory. Instances are immutable (``frozen=True``).
    """

    # Input table column names and task kind.
    id_col: str = "molecule_id"
    label_col: str = "y"
    task: Task = "classification"

    # Panel size: an explicit int, or the string sentinel "auto".
    panel_size: int | str = "auto"
    n_panels: int = 1000

    # Numeric rates mixed with the "observed" sentinel string;
    # NOTE(review): "observed" presumably means the rate seen in the data.
    target_rates: tuple[float | str, ...] = (
        0.0, 0.05, 0.10, 0.25, "observed", 0.50, 0.75,
    )

    random_seed: int = 13
    # Accepts either a Path or a plain string.
    output_dir: Path | str = Path("runs/rank_fragility")
@dataclass(frozen=True)
class MetricConfig:
    """Metric and model-selection settings for leaderboard comparisons.

    ``sota_model`` defaults to the string sentinel ``"auto"``; how that
    sentinel is resolved is up to the code consuming this config.
    """

    task: Task = "classification"
    metric: str = "auroc"            # metric name used for comparisons
    baseline_model: str = "ecfp_rf"  # identifier of the baseline model
    sota_model: str = "auto"         # explicit model identifier or "auto"
@dataclass(frozen=True)
class RunConfig:
    """Complete input, audit, panel, and output settings for one analysis run.

    This is the flat configuration for a run. ``audit_config``,
    ``panel_config`` and ``metric_config`` copy the relevant fields into the
    narrower per-stage configuration objects.

    ``__post_init__`` normalizes loosely-typed inputs so the frozen instance
    matches its annotations: ``str`` values of ``data``, ``pred_dir`` and
    ``output_dir`` become ``Path``, and ``list`` values of
    ``near_leak_thresholds`` / ``target_rates`` become tuples (which also keeps
    the instance hashable). Other values are left untouched.
    """

    # Required inputs.
    data: Path
    pred_dir: Path

    # Input table column names.
    id_col: str = "molecule_id"
    smiles_col: str = "smiles"
    label_col: str = "y"
    split_col: str = "split"

    # Task kind and evaluation metric.
    task: Task = "classification"
    metric: str = "auroc"

    # Audit thresholds and seed.
    near_leak_thresholds: tuple[float, ...] = (0.85, 0.90)
    primary_near_leak_threshold: float = 0.85
    regression_conflict_threshold: float = 1.0
    regression_conflict_threshold_sensitivity: float | None = None
    random_seed: int = 13

    # Panel sampling. Tuples are immutable, so a plain default is safe here
    # (no default_factory needed).
    panel_size: int | str = "auto"
    n_panels: int = 1000
    target_rates: tuple[float | str, ...] = (0.0, 0.05, 0.10, 0.25, "observed", 0.50, 0.75)

    # Model selection and output location.
    baseline_model: str = "ecfp_rf"
    sota_model: str = "auto"
    output_dir: Path = Path("runs/rank_fragility")

    def __post_init__(self) -> None:
        """Coerce ``str`` paths to ``Path`` and ``list`` sequences to tuples.

        Only those exact input kinds are converted, so callers that already
        pass ``Path`` objects and tuples observe no change.
        """
        # Frozen dataclasses block normal assignment; object.__setattr__ is the
        # same mechanism the generated __init__ uses to set fields.
        for name in ("data", "pred_dir", "output_dir"):
            value = getattr(self, name)
            if isinstance(value, str):
                object.__setattr__(self, name, Path(value))
        for name in ("near_leak_thresholds", "target_rates"):
            value = getattr(self, name)
            if isinstance(value, list):
                object.__setattr__(self, name, tuple(value))

    def audit_config(self) -> AuditConfig:
        """Return an ``AuditConfig`` built from this run's audit settings.

        Copies the column names, task, near-leak and regression-conflict
        thresholds, and random seed.
        """
        return AuditConfig(
            id_col=self.id_col,
            smiles_col=self.smiles_col,
            label_col=self.label_col,
            split_col=self.split_col,
            task=self.task,
            near_leak_thresholds=self.near_leak_thresholds,
            primary_near_leak_threshold=self.primary_near_leak_threshold,
            regression_conflict_threshold=self.regression_conflict_threshold,
            regression_conflict_threshold_sensitivity=self.regression_conflict_threshold_sensitivity,
            random_seed=self.random_seed,
        )

    def panel_config(self) -> PanelConfig:
        """Return a ``PanelConfig`` built from this run's panel settings.

        Copies the id/label columns, task, panel size and count, target rates,
        random seed and output directory.
        """
        return PanelConfig(
            id_col=self.id_col,
            label_col=self.label_col,
            task=self.task,
            panel_size=self.panel_size,
            n_panels=self.n_panels,
            target_rates=self.target_rates,
            random_seed=self.random_seed,
            output_dir=self.output_dir,
        )

    def metric_config(self) -> MetricConfig:
        """Return a ``MetricConfig`` with this run's task, metric and models."""
        return MetricConfig(
            task=self.task,
            metric=self.metric,
            baseline_model=self.baseline_model,
            sota_model=self.sota_model,
        )