"""Public type definitions for pykinbiont."""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Iterator, Optional, Union
import numpy as np
import pandas as pd
@dataclass(frozen=True)
class GrowthData:
    """Container for growth curves at common time points.

    Parameters
    ----------
    curves:
        Shape (n_curves, n_timepoints), float64. Each row is one curve.
    times:
        Shape (n_timepoints,), float64. Shared time grid.
    labels:
        Identifier per curve, length n_curves.
    clusters:
        Cluster assignment per curve (1-based), shape (n_curves,). Populated
        by preprocess() when cluster=True. None until then.
    centroids:
        Per-cluster shape centroids in z-normalised space, shape
        (n_clusters, n_timepoints). Populated alongside clusters.
    wcss:
        Within-cluster sum of squares from k-means. None until clustering runs.

    Raises
    ------
    ValueError
        From __post_init__ when array/label lengths are inconsistent.
    """

    curves: np.ndarray
    times: np.ndarray
    labels: list[str]
    clusters: Optional[np.ndarray] = None
    centroids: Optional[np.ndarray] = None
    wcss: Optional[float] = None

    def __post_init__(self) -> None:
        """Validate cross-field shape consistency and freeze the arrays."""
        if self.curves.ndim != 2:
            raise ValueError(f"curves must be 2-D, got shape {self.curves.shape}")
        n_curves, n_tp = self.curves.shape
        if len(self.times) != n_tp:
            raise ValueError(
                f"times length {len(self.times)} != n_timepoints {n_tp}"
            )
        if len(self.labels) != n_curves:
            raise ValueError(
                f"labels length {len(self.labels)} != n_curves {n_curves}"
            )
        if self.clusters is not None and len(self.clusters) != n_curves:
            raise ValueError(
                f"clusters length {len(self.clusters)} != n_curves {n_curves}"
            )
        # The dataclass is frozen, but ndarray *contents* would still be
        # mutable; lock them so the container is deeply read-only.
        # NOTE: no copy is made, so the caller's arrays are frozen too.
        self.curves.flags["WRITEABLE"] = False
        self.times.flags["WRITEABLE"] = False

    @classmethod
    def from_csv(cls, path: str) -> "GrowthData":
        """Load from CSV. First column = times; remaining columns = curves."""
        df = pd.read_csv(path)
        return cls.from_dataframe(df)

    @classmethod
    def from_dataframe(cls, df: pd.DataFrame) -> "GrowthData":
        """Load from DataFrame. First column = times; remaining = curves."""
        times = df.iloc[:, 0].to_numpy(dtype=np.float64)
        labels = list(df.columns[1:])
        curves = df.iloc[:, 1:].to_numpy(dtype=np.float64).T  # (n_curves, n_tp)
        return cls(curves=curves, times=times, labels=labels)

    def __getitem__(self, labels: list[str]) -> "GrowthData":
        """Return a new GrowthData with only the requested curves.

        Cluster fields (clusters/centroids/wcss) are not carried over to
        the subset.

        Raises
        ------
        KeyError
            If any requested label is not present.
        """
        # Single O(n) label -> index map instead of list.index() per label
        # (O(n^2)); setdefault keeps the FIRST occurrence, matching
        # list.index() semantics if labels were ever duplicated.
        pos: dict[str, int] = {}
        for i, lb in enumerate(self.labels):
            pos.setdefault(lb, i)
        missing = [lb for lb in labels if lb not in pos]
        if missing:
            raise KeyError(f"labels not found in GrowthData: {missing}")
        idx = [pos[lb] for lb in labels]
        return GrowthData(
            curves=self.curves[idx, :],
            times=self.times,
            labels=list(labels),
        )
def _normalize_01(times: np.ndarray) -> np.ndarray:
tmin, tmax = float(times.min()), float(times.max())
span = tmax - tmin
return np.zeros_like(times) if span <= 0.0 else (times - tmin) / span
def _build_union_grid(times01_list: list[np.ndarray], step: float) -> np.ndarray:
points: set[float] = set()
for t in times01_list:
for v in t:
snapped = round(float(v) / step) * step
points.add(max(0.0, min(1.0, snapped)))
grid = sorted(points)
if not grid or grid[0] > 0.0:
grid = [0.0] + grid
if not grid or grid[-1] < 1.0:
grid = grid + [1.0]
return np.array(grid, dtype=np.float64)
class IrregularGrowthData:
    """Growth curves with per-curve irregular time points.

    Resampling to a shared [0,1] union grid is performed automatically at
    construction time (pure Python / numpy, no Julia required).

    Parameters
    ----------
    raw_curves:
        Original OD values, one 1-D float64 array per curve.
    raw_times:
        Original (un-normalised) time points, one 1-D float64 array per curve.
    labels:
        Identifier per curve.
    curves:
        Resampled matrix on the [0,1] union grid, shape (n_curves, n_grid).
        Set automatically — do not pass manually.
    times:
        The [0,1] union grid, shape (n_grid,).
        Set automatically — do not pass manually.
    step:
        Union grid resolution (default 0.01). Must be > 0.
    clusters, centroids, wcss:
        Populated by preprocess() when cluster=True.

    Raises
    ------
    ValueError
        If per-curve lengths disagree, any time vector has fewer than 2
        points, or step is not positive.
    """

    def __init__(
        self,
        raw_curves: list[np.ndarray],
        raw_times: list[np.ndarray],
        labels: list[str],
        curves: Optional[np.ndarray] = None,
        times: Optional[np.ndarray] = None,
        step: float = 0.01,
        clusters: Optional[np.ndarray] = None,
        centroids: Optional[np.ndarray] = None,
        wcss: Optional[float] = None,
    ) -> None:
        n = len(raw_curves)
        if len(raw_times) != n:
            raise ValueError(
                f"raw_times length {len(raw_times)} != n_curves {n}"
            )
        if len(labels) != n:
            raise ValueError(
                f"labels length {len(labels)} != n_curves {n}"
            )
        for i, (rc, rt) in enumerate(zip(raw_curves, raw_times)):
            if len(rc) != len(rt):
                raise ValueError(
                    f"curve {i}: raw_curves and raw_times must have the same length"
                )
            if len(rt) < 2:
                raise ValueError(f"curve {i}: time vector must have at least 2 points")
        # A non-positive step would make the union-grid snapping divide by
        # zero (or produce a degenerate grid); reject it up front.
        if step <= 0.0:
            raise ValueError(f"step must be > 0, got {step}")
        # Resample to union grid (pure Python, no Julia). Skipped when both
        # curves and times are already supplied (e.g. by __getitem__).
        if curves is None or times is None:
            times01_list = [
                _normalize_01(np.asarray(rt, dtype=np.float64)) for rt in raw_times
            ]
            union_grid = _build_union_grid(times01_list, step)
            curves = np.stack([
                np.interp(union_grid, t01, np.asarray(rc, dtype=np.float64))
                for t01, rc in zip(times01_list, raw_curves)
            ])
            times = union_grid
        self.raw_curves: list[np.ndarray] = list(raw_curves)
        self.raw_times: list[np.ndarray] = list(raw_times)
        self.labels: list[str] = list(labels)
        self.curves: np.ndarray = curves
        self.times: np.ndarray = times
        self.step: float = step
        self.clusters: Optional[np.ndarray] = clusters
        self.centroids: Optional[np.ndarray] = centroids
        self.wcss: Optional[float] = wcss

    def __getitem__(self, labels: list[str]) -> "IrregularGrowthData":
        """Return a new IrregularGrowthData with only the requested curves.

        Cluster fields (clusters/centroids/wcss) are not carried over to
        the subset.

        Raises
        ------
        KeyError
            If any requested label is not present.
        """
        # Single O(n) label -> index map instead of list.index() per label;
        # setdefault keeps the FIRST occurrence, matching list.index().
        pos: dict[str, int] = {}
        for i, lb in enumerate(self.labels):
            pos.setdefault(lb, i)
        missing = [lb for lb in labels if lb not in pos]
        if missing:
            raise KeyError(f"labels not found in IrregularGrowthData: {missing}")
        idx = [pos[lb] for lb in labels]
        return IrregularGrowthData(
            raw_curves=[self.raw_curves[i] for i in idx],
            raw_times=[self.raw_times[i] for i in idx],
            labels=list(labels),
            curves=self.curves[idx, :],
            times=self.times,
            step=self.step,
        )
# Closed sets of accepted method names, validated in FitOptions.__post_init__.
_SMOOTH_METHODS = {"lowess", "rolling_avg", "gaussian", "boxcar", "none"}
_NEG_METHODS = {"remove", "thr_correction", "blank_correction"}
_SCATTER_METHODS = {"interpolation", "exp_fit"}
# Compiled once at import time instead of on every FitOptions construction.
_OPT_KEY_RE = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
# opt_params values must be plain scalars (bool is an int subclass but is
# listed explicitly for clarity).
_OPT_VALUE_TYPES = (int, float, bool)


@dataclass
class FitOptions:
    """All configuration for preprocessing and fitting.

    Every field mirrors the Julia FitOptions struct exactly — the Julia docs
    serve as the authoritative reference for field semantics.

    Raises
    ------
    ValueError
        From __post_init__ when a method name is unknown, n_clusters < 2
        while cluster=True, scattering_correction=True without a
        calibration_file, or opt_params contains a key that is not a valid
        Julia identifier / a value that is not int, float or bool.
    """

    # --- preprocessing ---
    smooth: bool = False
    smooth_method: str = "lowess"      # one of _SMOOTH_METHODS
    smooth_pt_avg: int = 7
    boxcar_window: int = 5
    lowess_frac: float = 0.05
    gaussian_h_mult: float = 2.0
    gaussian_time_grid: Optional[np.ndarray] = None
    average_replicates: bool = False
    blank_subtraction: bool = False
    blank_value: float = 0.0
    blank_from_labels: bool = False
    correct_negatives: bool = False
    negative_method: str = "remove"    # one of _NEG_METHODS
    negative_threshold: float = 0.01
    scattering_correction: bool = False
    calibration_file: str = ""         # required when scattering_correction=True
    scattering_method: str = "interpolation"  # one of _SCATTER_METHODS
    # --- stationary phase ---
    cut_stationary_phase: bool = False
    stationary_percentile_thr: float = 0.05
    stationary_pt_smooth_derivative: int = 10
    stationary_win_size: int = 5
    stationary_thr_od: float = 0.02
    # --- clustering ---
    cluster: bool = False
    n_clusters: int = 3                # must be >= 2 when cluster=True
    cluster_trend_test: bool = True
    cluster_prescreen_constant: bool = False
    cluster_tol_const: float = 1.5
    cluster_q_low: float = 0.05
    cluster_q_high: float = 0.95
    cluster_exp_prototype: bool = False
    kmeans_n_init: int = 10
    kmeans_max_iters: int = 300
    kmeans_tol: float = 1e-6
    kmeans_seed: int = 0
    # --- fitting ---
    loss: str = "RE"
    multistart: bool = False
    n_restart: int = 50
    aic_correction: bool = True
    pt_smooth_derivative: int = 7
    # Extra keyword arguments forwarded verbatim; keys must be valid Julia
    # identifiers and values plain scalars.
    opt_params: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate method names, cluster count, calibration and opt_params."""
        if self.smooth_method not in _SMOOTH_METHODS:
            raise ValueError(
                f"smooth_method must be one of {_SMOOTH_METHODS}, got {self.smooth_method!r}"
            )
        if self.negative_method not in _NEG_METHODS:
            raise ValueError(
                f"negative_method must be one of {_NEG_METHODS}, got {self.negative_method!r}"
            )
        if self.scattering_method not in _SCATTER_METHODS:
            raise ValueError(
                f"scattering_method must be one of {_SCATTER_METHODS}, got {self.scattering_method!r}"
            )
        if self.cluster and self.n_clusters < 2:
            raise ValueError(
                f"n_clusters must be >= 2 when cluster=True, got {self.n_clusters}"
            )
        if self.scattering_correction and not self.calibration_file:
            raise ValueError(
                "calibration_file must be non-empty when scattering_correction=True"
            )
        for k, v in self.opt_params.items():
            if not isinstance(k, str) or not _OPT_KEY_RE.match(k):
                raise ValueError(
                    f"opt_params key {k!r} is not a valid Julia identifier"
                )
            if not isinstance(v, _OPT_VALUE_TYPES):
                raise ValueError(
                    f"opt_params[{k!r}] has type {type(v).__name__}; "
                    f"only int, float, bool are allowed"
                )
@dataclass
class ModelSpec:
    """Which models to fit and their initial parameters.

    Parameters
    ----------
    models:
        List of AbstractGrowthModel instances.
    params:
        Initial parameter guess per model (empty list [] for LogLinModel/DDDEModel).
    lower:
        Per-model lower bounds; None slot means unconstrained for that model.
    upper:
        Per-model upper bounds; None slot means unconstrained for that model.

    Raises
    ------
    ValueError
        If params, lower or upper do not have one entry per model.
    """

    models: list[Any]
    params: list[list[float]]
    lower: Optional[list[Optional[list[float]]]] = None
    upper: Optional[list[Optional[list[float]]]] = None

    def __post_init__(self) -> None:
        """Check that every per-model list is parallel to ``models``."""
        n_models = len(self.models)
        if len(self.params) != n_models:
            raise ValueError(
                f"models and params must have the same length, "
                f"got {n_models} and {len(self.params)}"
            )
        # lower and upper get identical treatment; loop instead of
        # duplicating the check twice.
        for name, bounds in (("lower", self.lower), ("upper", self.upper)):
            if bounds is not None and len(bounds) != n_models:
                raise ValueError(
                    f"{name} bounds length {len(bounds)} != models length {n_models}"
                )
@dataclass(frozen=True)
class CurveFitResult:
    """Fitting result for a single growth curve.

    Immutable record of the selected best model for one curve, plus the
    full per-model comparison in ``all_results``.
    """

    # Curve identifier (matches GrowthData/IrregularGrowthData labels).
    label: str
    # Name of the selected best model.
    best_model: str
    # Fitted parameter values for best_model.
    best_params: list[float]
    # Parameter names for best_model — presumably parallel to best_params;
    # confirm against the fit() implementation.
    param_names: list[str]
    # AIC of the selected model.
    best_aic: float
    # Model prediction evaluated on `times`.
    fitted_curve: np.ndarray
    # Time grid the fitted curve is defined on.
    times: np.ndarray
    # Final loss value of the selected fit.
    loss: float
    # One dict per candidate model with its fit outcome.
    all_results: list[dict]
@dataclass(frozen=True)
class GrowthFitResults:
    """Top-level result returned by fit().

    Behaves like a read-only sequence of CurveFitResult: supports len(),
    iteration, and integer indexing.
    """

    data: Union[GrowthData, IrregularGrowthData]
    results: list[CurveFitResult]

    def to_dataframe(self) -> pd.DataFrame:
        """Summary table: one row per curve with best model, AIC, params.

        Parameter columns ``param_1..param_N`` are padded with None up to
        the widest ``best_params`` across results.

        Raises
        ------
        ValueError
            If a result's label is absent from ``data.labels``.
        """
        n_max = max((len(r.best_params) for r in self.results), default=0)
        # Build the label -> row-index map once (O(n)) rather than calling
        # list.index() per result (O(n^2)); setdefault keeps the first
        # occurrence, matching list.index() on duplicate labels.
        pos: dict = {}
        for i, lb in enumerate(self.data.labels):
            pos.setdefault(lb, i)
        rows = []
        for r in self.results:
            if r.label not in pos:
                raise ValueError(
                    f"result label {r.label!r} not found in data.labels"
                )
            idx = pos[r.label]
            cluster = (
                int(self.data.clusters[idx])
                if self.data.clusters is not None
                else None
            )
            row: dict = {
                "label": r.label,
                "cluster": cluster,
                "best_model": r.best_model,
                "aic": r.best_aic,
                "loss": r.loss,
            }
            for k in range(n_max):
                row[f"param_{k + 1}"] = (
                    r.best_params[k] if k < len(r.best_params) else None
                )
            rows.append(row)
        return pd.DataFrame(rows)

    def __iter__(self) -> "Iterator[CurveFitResult]":
        """Iterate over per-curve results in their stored order."""
        return iter(self.results)

    def __len__(self) -> int:
        """Number of fitted curves."""
        return len(self.results)

    def __getitem__(self, i: int) -> CurveFitResult:
        """Positional access; negative indices follow list semantics."""
        if not isinstance(i, int):
            raise TypeError(f"index must be int, got {type(i).__name__!r}")
        return self.results[i]