# Source code for pykinbiont.types

"""Public type definitions for pykinbiont."""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Any, Iterator, Optional, Union
import numpy as np
import pandas as pd


@dataclass(frozen=True)
class GrowthData:
    """Container for growth curves sampled at common time points.

    Parameters
    ----------
    curves:
        Shape (n_curves, n_timepoints), float64. Each row is one curve.
    times:
        Shape (n_timepoints,), float64. Shared time grid.
    labels:
        Identifier per curve, length n_curves.
    clusters:
        Cluster assignment per curve (1-based), shape (n_curves,).
        Populated by preprocess() when cluster=True. None until then.
    centroids:
        Per-cluster shape centroids in z-normalised space, shape
        (n_clusters, n_timepoints). Populated alongside clusters.
    wcss:
        Within-cluster sum of squares from k-means. None until clustering
        runs.
    """

    curves: np.ndarray
    times: np.ndarray
    labels: list[str]
    clusters: Optional[np.ndarray] = None
    centroids: Optional[np.ndarray] = None
    wcss: Optional[float] = None

    def __post_init__(self) -> None:
        # Validate cross-field shape invariants, then mark the arrays
        # read-only so the frozen dataclass cannot be mutated through them.
        if self.curves.ndim != 2:
            raise ValueError(f"curves must be 2-D, got shape {self.curves.shape}")
        n_curves, n_tp = self.curves.shape
        if len(self.times) != n_tp:
            raise ValueError(
                f"times length {len(self.times)} != n_timepoints {n_tp}"
            )
        if len(self.labels) != n_curves:
            raise ValueError(
                f"labels length {len(self.labels)} != n_curves {n_curves}"
            )
        if self.clusters is not None and len(self.clusters) != n_curves:
            raise ValueError(
                f"clusters length {len(self.clusters)} != n_curves {n_curves}"
            )
        self.curves.flags["WRITEABLE"] = False
        self.times.flags["WRITEABLE"] = False

    @classmethod
    def from_csv(cls, path: str) -> "GrowthData":
        """Load from CSV. First column = times; remaining columns = curves."""
        # classmethod (was staticmethod): alternate constructors should
        # honour subclassing; call syntax is unchanged for existing callers.
        return cls.from_dataframe(pd.read_csv(path))

    @classmethod
    def from_dataframe(cls, df: pd.DataFrame) -> "GrowthData":
        """Load from DataFrame. First column = times; remaining = curves."""
        times = df.iloc[:, 0].to_numpy(dtype=np.float64)
        labels = list(df.columns[1:])
        curves = df.iloc[:, 1:].to_numpy(dtype=np.float64).T  # (n_curves, n_tp)
        return cls(curves=curves, times=times, labels=labels)

    def __getitem__(self, labels: list[str]) -> "GrowthData":
        """Return a new GrowthData with only the requested curves.

        Raises
        ------
        KeyError
            If any requested label is absent from this GrowthData.
        """
        # Build the label -> row map in one O(n) pass instead of calling
        # list.index per label (quadratic). setdefault keeps the first
        # occurrence, matching list.index semantics on duplicate labels.
        pos: dict[str, int] = {}
        for i, lb in enumerate(self.labels):
            pos.setdefault(lb, i)
        missing = [lb for lb in labels if lb not in pos]
        if missing:
            raise KeyError(f"labels not found in GrowthData: {missing}")
        idx = [pos[lb] for lb in labels]
        return GrowthData(
            curves=self.curves[idx, :],
            times=self.times,
            labels=list(labels),
        )
def _normalize_01(times: np.ndarray) -> np.ndarray: tmin, tmax = float(times.min()), float(times.max()) span = tmax - tmin return np.zeros_like(times) if span <= 0.0 else (times - tmin) / span def _build_union_grid(times01_list: list[np.ndarray], step: float) -> np.ndarray: points: set[float] = set() for t in times01_list: for v in t: snapped = round(float(v) / step) * step points.add(max(0.0, min(1.0, snapped))) grid = sorted(points) if not grid or grid[0] > 0.0: grid = [0.0] + grid if not grid or grid[-1] < 1.0: grid = grid + [1.0] return np.array(grid, dtype=np.float64)
class IrregularGrowthData:
    """Growth curves whose time points differ from curve to curve.

    Every curve is resampled onto a shared union grid over [0, 1] at
    construction time (plain Python / numpy, no Julia required), unless a
    pre-computed ``curves``/``times`` pair is supplied.

    Parameters
    ----------
    raw_curves:
        Original OD values, one 1-D float64 array per curve.
    raw_times:
        Original (un-normalised) time points, one 1-D float64 array per
        curve.
    labels:
        Identifier per curve.
    curves:
        Resampled matrix on the [0,1] union grid, shape (n_curves, n_grid).
        Computed automatically when omitted — do not pass manually.
    times:
        The [0,1] union grid, shape (n_grid,). Computed automatically when
        omitted — do not pass manually.
    step:
        Union grid resolution (default 0.01).
    clusters, centroids, wcss:
        Populated by preprocess() when cluster=True.
    """

    def __init__(
        self,
        raw_curves: list[np.ndarray],
        raw_times: list[np.ndarray],
        labels: list[str],
        curves: Optional[np.ndarray] = None,
        times: Optional[np.ndarray] = None,
        step: float = 0.01,
        clusters: Optional[np.ndarray] = None,
        centroids: Optional[np.ndarray] = None,
        wcss: Optional[float] = None,
    ) -> None:
        n_curves = len(raw_curves)
        if len(raw_times) != n_curves:
            raise ValueError(
                f"raw_times length {len(raw_times)} != n_curves {n_curves}"
            )
        if len(labels) != n_curves:
            raise ValueError(
                f"labels length {len(labels)} != n_curves {n_curves}"
            )
        for k, (od_vec, t_vec) in enumerate(zip(raw_curves, raw_times)):
            if len(od_vec) != len(t_vec):
                raise ValueError(
                    f"curve {k}: raw_curves and raw_times must have the same length"
                )
            if len(t_vec) < 2:
                raise ValueError(f"curve {k}: time vector must have at least 2 points")

        # Resample onto the shared [0, 1] union grid unless the caller
        # already provides the resampled representation.
        if curves is None or times is None:
            norm_times = [
                _normalize_01(np.asarray(t, dtype=np.float64)) for t in raw_times
            ]
            grid = _build_union_grid(norm_times, step)
            # NOTE(review): np.interp assumes each normalised time vector is
            # increasing; unsorted raw_times would resample silently wrong —
            # confirm upstream callers always pass sorted times.
            rows = []
            for t01, od_vec in zip(norm_times, raw_curves):
                rows.append(
                    np.interp(grid, t01, np.asarray(od_vec, dtype=np.float64))
                )
            times = grid
            curves = np.stack(rows)

        self.raw_curves: list[np.ndarray] = list(raw_curves)
        self.raw_times: list[np.ndarray] = list(raw_times)
        self.labels: list[str] = list(labels)
        self.curves: np.ndarray = curves
        self.times: np.ndarray = times
        self.step: float = step
        self.clusters: Optional[np.ndarray] = clusters
        self.centroids: Optional[np.ndarray] = centroids
        self.wcss: Optional[float] = wcss

    def __getitem__(self, labels: list[str]) -> "IrregularGrowthData":
        """Return a copy restricted to *labels* (raw and resampled data kept).

        Raises
        ------
        KeyError
            If any requested label is absent.
        """
        absent = [lb for lb in labels if lb not in self.labels]
        if absent:
            raise KeyError(f"labels not found in IrregularGrowthData: {absent}")
        sel = [self.labels.index(lb) for lb in labels]
        return IrregularGrowthData(
            raw_curves=[self.raw_curves[k] for k in sel],
            raw_times=[self.raw_times[k] for k in sel],
            labels=list(labels),
            curves=self.curves[sel, :],
            times=self.times,
            step=self.step,
        )
_SMOOTH_METHODS = {"lowess", "rolling_avg", "gaussian", "boxcar", "none"} _NEG_METHODS = {"remove", "thr_correction", "blank_correction"} _SCATTER_METHODS = {"interpolation", "exp_fit"}
@dataclass
class FitOptions:
    """All configuration for preprocessing and fitting.

    Every field mirrors the Julia FitOptions struct exactly — the Julia docs
    serve as the authoritative reference for field semantics.

    Raises
    ------
    ValueError
        From __post_init__ on any invalid method name, inconsistent flag
        combination, or malformed opt_params entry.
    """

    # Class-level validation constants. They carry no type annotation, so
    # the dataclass machinery does not treat them as fields; hoisting them
    # here compiles the regex once instead of on every instantiation.
    _KEY_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
    _ALLOWED_OPT_TYPES = (int, float, bool)

    # --- preprocessing ---
    smooth: bool = False
    smooth_method: str = "lowess"
    smooth_pt_avg: int = 7
    boxcar_window: int = 5
    lowess_frac: float = 0.05
    gaussian_h_mult: float = 2.0
    gaussian_time_grid: Optional[np.ndarray] = None
    average_replicates: bool = False
    blank_subtraction: bool = False
    blank_value: float = 0.0
    blank_from_labels: bool = False
    correct_negatives: bool = False
    negative_method: str = "remove"
    negative_threshold: float = 0.01
    scattering_correction: bool = False
    calibration_file: str = ""
    scattering_method: str = "interpolation"

    # --- stationary phase ---
    cut_stationary_phase: bool = False
    stationary_percentile_thr: float = 0.05
    stationary_pt_smooth_derivative: int = 10
    stationary_win_size: int = 5
    stationary_thr_od: float = 0.02

    # --- clustering ---
    cluster: bool = False
    n_clusters: int = 3
    cluster_trend_test: bool = True
    cluster_prescreen_constant: bool = False
    cluster_tol_const: float = 1.5
    cluster_q_low: float = 0.05
    cluster_q_high: float = 0.95
    cluster_exp_prototype: bool = False
    kmeans_n_init: int = 10
    kmeans_max_iters: int = 300
    kmeans_tol: float = 1e-6
    kmeans_seed: int = 0

    # --- fitting ---
    loss: str = "RE"
    multistart: bool = False
    n_restart: int = 50
    aic_correction: bool = True
    pt_smooth_derivative: int = 7
    opt_params: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Enum-like string fields must come from the module-level whitelists.
        if self.smooth_method not in _SMOOTH_METHODS:
            raise ValueError(
                f"smooth_method must be one of {_SMOOTH_METHODS}, got {self.smooth_method!r}"
            )
        if self.negative_method not in _NEG_METHODS:
            raise ValueError(
                f"negative_method must be one of {_NEG_METHODS}, got {self.negative_method!r}"
            )
        if self.scattering_method not in _SCATTER_METHODS:
            raise ValueError(
                f"scattering_method must be one of {_SCATTER_METHODS}, got {self.scattering_method!r}"
            )
        # Cross-field consistency.
        if self.cluster and self.n_clusters < 2:
            raise ValueError(
                f"n_clusters must be >= 2 when cluster=True, got {self.n_clusters}"
            )
        if self.scattering_correction and not self.calibration_file:
            raise ValueError(
                "calibration_file must be non-empty when scattering_correction=True"
            )
        # opt_params entries: the error message indicates keys must be valid
        # Julia identifiers and values plain scalars.
        for k, v in self.opt_params.items():
            if not isinstance(k, str) or not FitOptions._KEY_RE.match(k):
                raise ValueError(
                    f"opt_params key {k!r} is not a valid Julia identifier"
                )
            if not isinstance(v, FitOptions._ALLOWED_OPT_TYPES):
                raise ValueError(
                    f"opt_params[{k!r}] has type {type(v).__name__}; "
                    f"only int, float, bool are allowed"
                )
@dataclass
class ModelSpec:
    """Which models to fit and their initial parameters.

    Parameters
    ----------
    models:
        List of AbstractGrowthModel instances.
    params:
        Initial parameter guess per model (empty list [] for
        LogLinModel/DDDEModel).
    lower:
        Per-model lower bounds; None slot means unconstrained for that model.
    upper:
        Per-model upper bounds; None slot means unconstrained for that model.
    """

    models: list[Any]
    params: list[list[float]]
    lower: Optional[list[Optional[list[float]]]] = None
    upper: Optional[list[Optional[list[float]]]] = None

    def __post_init__(self) -> None:
        # Every per-model list must line up one-to-one with `models`.
        n_models = len(self.models)
        if len(self.params) != n_models:
            raise ValueError(
                f"models and params must have the same length, "
                f"got {len(self.models)} and {len(self.params)}"
            )
        for side, bounds in (("lower", self.lower), ("upper", self.upper)):
            if bounds is not None and len(bounds) != n_models:
                raise ValueError(
                    f"{side} bounds length {len(bounds)} != models length {n_models}"
                )
@dataclass(frozen=True)
class CurveFitResult:
    """Fitting result for a single growth curve (immutable)."""

    label: str                  # identifier of the fitted curve
    best_model: str             # name of the winning model
    best_params: list[float]    # fitted parameters of the winning model
    param_names: list[str]      # parameter names, parallel to best_params
    best_aic: float             # AIC of the winning model
    fitted_curve: np.ndarray    # model prediction on `times`
    times: np.ndarray           # time grid the fit was evaluated on
    loss: float                 # final loss value of the winning fit
    all_results: list[dict]     # per-model fit details for every candidate
@dataclass(frozen=True)
class GrowthFitResults:
    """Top-level result returned by fit().

    Acts as a read-only sequence of per-curve results: supports len(),
    iteration, and integer indexing.
    """

    data: Union[GrowthData, IrregularGrowthData]  # the input data that was fitted
    results: list[CurveFitResult]                 # one entry per fitted curve

    def to_dataframe(self) -> pd.DataFrame:
        """Summary table: one row per curve with best model, AIC, params.

        Parameter columns are padded to the widest model; empty slots are
        filled with None.

        Raises
        ------
        ValueError
            If a result's label is not present in ``data.labels``.
        """
        n_max = max((len(r.best_params) for r in self.results), default=0)
        # Build the label -> row-index map once (setdefault keeps the first
        # occurrence, matching list.index semantics) instead of calling
        # list.index per result, which is quadratic in the number of curves.
        label_pos: dict[str, int] = {}
        for i, lb in enumerate(self.data.labels):
            label_pos.setdefault(lb, i)
        rows = []
        for r in self.results:
            idx = label_pos.get(r.label)
            if idx is None:
                raise ValueError(f"label {r.label!r} not found in data.labels")
            cluster = (
                int(self.data.clusters[idx])
                if self.data.clusters is not None
                else None
            )
            row: dict = {
                "label": r.label,
                "cluster": cluster,
                "best_model": r.best_model,
                "aic": r.best_aic,
                "loss": r.loss,
            }
            for k in range(n_max):
                row[f"param_{k + 1}"] = (
                    r.best_params[k] if k < len(r.best_params) else None
                )
            rows.append(row)
        return pd.DataFrame(rows)

    def __iter__(self) -> "Iterator[CurveFitResult]":
        """Iterate over the per-curve results."""
        return iter(self.results)

    def __len__(self) -> int:
        """Number of fitted curves."""
        return len(self.results)

    def __getitem__(self, i: int) -> CurveFitResult:
        """Positional access; only plain int indices are accepted."""
        if not isinstance(i, int):
            raise TypeError(f"index must be int, got {type(i).__name__!r}")
        return self.results[i]