Source code for mioXpektron.baseline.baseline_base

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import ast
import inspect
import io
import re
import warnings

import numpy as np
import polars as pl
from scipy import signal, ndimage

# pybaselines
try:
    from pybaselines import Baseline
except Exception as e:  # pragma: no cover
    raise ImportError("pybaselines is required. pip install pybaselines") from e

# ---------------------------------------------------------------------
# Universal baseline wrapper (ToF‑SIMS friendly)
# ---------------------------------------------------------------------

[docs] def baseline_method_names() -> List[str]: """Return a sorted list of available baseline algorithms. Based on `pybaselines.Baseline` public callables, plus two custom filters ("median_filter", "adaptive_window") and a 'poly' alias. A few methods that are not 1‑D safe or impractically slow are removed. """ bl = Baseline() skip = {"pentapy_solver", "banded_solver"} methods = { name for name in dir(bl) if (not name.startswith("_") and name not in skip and callable(getattr(bl, name))) } # extras methods.update({"median_filter", "adaptive_window", "poly"}) # remove rarely stable / not applicable methods remove = {"collab_pls", "interp_pts", "cwt_br"} return sorted([m for m in methods if m not in remove])
[docs] def small_param_grid_preset(n_points: Optional[int] = None) -> Dict[str, List[Dict]]: """A compact parameter grid for common methods. Keys must match `pybaselines.Baseline` method names (plus 'poly' and our two filters). Parameters ---------- n_points : int, optional Number of data points in spectrum. If provided, window_size will be calculated adaptively as a percentage of data size. If None, uses moderate defaults suitable for ~100K point spectra. Returns ------- dict Parameter grid with method names as keys Notes ----- Window sizes are calculated as: - Small: 0.05% of data (min 51) - Medium: 0.10% of data (min 101) - Large: 0.20% of data (min 501) This adaptive scaling ensures that filter methods perform consistently across datasets of different sizes. Fixed window sizes work poorly: - For 10K points: window=101 is 1.0% (OK) - For 1M points: window=101 is 0.01% (too small, causes jagged baselines) Examples -------- >>> # Auto-scale for 938K point spectrum >>> grid = small_param_grid_preset(n_points=938000) >>> grid['median_filter'] [{'window_size': 469}, {'window_size': 938}, {'window_size': 1876}] >>> # Use defaults for unknown size >>> grid = small_param_grid_preset() >>> grid['median_filter'] [{'window_size': 501}, {'window_size': 1001}, {'window_size': 2001}] """ # Calculate adaptive window sizes for filter methods if n_points is None: # Default: Conservative values for moderate-sized data (better than old 101/301) window_sizes = [501, 1001, 2001] else: # Adaptive: Scale with data size (0.05%, 0.10%, 0.20%) ws_small = max(51, int(n_points * 0.0005)) # 0.05% of data ws_medium = max(101, int(n_points * 0.001)) # 0.10% of data ws_large = max(501, int(n_points * 0.002)) # 0.20% of data # Ensure odd numbers for symmetric windows ws_small = ws_small if ws_small % 2 == 1 else ws_small + 1 ws_medium = ws_medium if ws_medium % 2 == 1 else ws_medium + 1 ws_large = ws_large if ws_large % 2 == 1 else ws_large + 1 window_sizes = [ws_small, ws_medium, ws_large] return { # Whittaker/penalized-spline family "asls": [{"lam": 1e5, "p": 0.01}, {"lam": 1e6, "p": 0.001}], "iasls": [{"lam": 1e5, "p": 0.01}, {"lam": 1e6, "p": 0.001}], "arpls": [{"lam": 1e6}, {"lam": 1e7}], "drpls": [{"lam": 1e6}, {"lam": 1e7}], "airpls": [{"lam": 1e5}, {"lam": 1e6}], # adaptive iteratively reweighted "aspls": [{"lam": 1e5, "p": 0.01}, {"lam": 1e6, "p": 0.001}], # Polynomial family "modpoly": [{"poly_order": 2}, {"poly_order": 3}, {"poly_order": 4}], "imodpoly": [{"poly_order": 2}, {"poly_order": 3}, {"poly_order": 4}], "poly": [{"poly_order": 2}, {"poly_order": 3}, {"poly_order": 4}], # Simple filters (ADAPTIVE window sizes based on data size) "median_filter": [{"window_size": ws} for ws in window_sizes], "adaptive_window": [{"window_size": ws} for ws in window_sizes], }
def _split_method_and_inline_kwargs(method: str) -> Tuple[str, Dict[str, object]]: """Parse evaluator labels like ``aspls(lam=1000000.0)`` into a method spec.""" text = str(method).strip() match = re.fullmatch(r"(?P<name>[^()]+?)\((?P<args>.*)\)", text) if match is None: return text, {} base_name = match.group("name").strip() args_text = match.group("args").strip() if not args_text: return base_name, {} parsed: Dict[str, object] = {} for item in args_text.split(","): if "=" not in item: return text, {} key, raw_value = item.split("=", 1) key = key.strip() raw_value = raw_value.strip() try: value = ast.literal_eval(raw_value) except (SyntaxError, ValueError): lowered = raw_value.lower() if lowered == "true": value = True elif lowered == "false": value = False elif lowered == "none": value = None else: value = raw_value parsed[key] = value return base_name, parsed
[docs] def baseline_correction( intensities: Union[np.ndarray, List[float]], method: str = "airpls", window_size: int = 101, poly_order: int = 4, clip_negative: bool = True, return_baseline: bool = False, **kwargs, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """Baseline-correct a 1‑D spectrum with `pybaselines` or custom filters. Parameters ---------- intensities : array-like Raw y values. method : str Algorithm name; see :func:`baseline_method_names`. window_size : int Kernel width for the two custom filters. poly_order : int Polynomial order for the 'poly' alias. clip_negative : bool If True, negative corrected values are set to 0. return_baseline : bool If True, also return the estimated baseline. **kwargs : Forwarded to the chosen algorithm (e.g. lam=1e6, p=0.01). Returns ------- corrected or (corrected, baseline) """ y = np.asarray(intensities, dtype=float).ravel() if y.ndim != 1: raise ValueError("intensities must be 1‑D") # pragma: no cover parsed_method, inline_kwargs = _split_method_and_inline_kwargs(str(method)) call_overrides = dict(inline_kwargs) call_overrides.update(kwargs) method_lower = parsed_method.lower() poly_like = {"poly", "modpoly", "imodpoly"} needs_rescale = method_lower in poly_like scale = 1.0 y_for_baseline = y if needs_rescale: finite_mask = np.isfinite(y) if finite_mask.any(): max_mag = float(np.max(np.abs(y[finite_mask]))) if max_mag > 0: scale = max_mag y_for_baseline = y / scale # If no finite values are present, leave the original array; downstream # algorithms will fail fast and the caller already handles exceptions. bl = Baseline() _skip = {"pentapy_solver", "banded_solver"} dispatch = {} for name in dir(bl): if name.startswith("_") or name in _skip: continue attr = getattr(bl, name) if callable(attr): dispatch[name] = attr # convenience alias dispatch["poly"] = lambda arr, *, poly_order=poly_order, **k: bl.poly(arr, poly_order=poly_order, **k) # custom filters if method_lower == "median_filter": baseline = signal.medfilt(y, kernel_size=int(call_overrides.get("window_size", window_size))) elif method_lower == "adaptive_window": baseline = ndimage.minimum_filter1d(y, size=int(call_overrides.get("window_size", window_size))) else: func = dispatch.get(method_lower) or dispatch.get(parsed_method) if func is None: raise ValueError(f"Unknown baseline method: {method}") input_y = y_for_baseline if needs_rescale else y call_kwargs = call_overrides if call_overrides: try: sig = inspect.signature(func) except (TypeError, ValueError): # pragma: no cover - builtins without introspection pass else: has_var_kwargs = any( param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values() ) if not has_var_kwargs: valid_names = { name for name, param in sig.parameters.items() if param.kind in ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, ) } valid_names.discard("self") valid_names.discard("y") if valid_names: filtered = {k: v for k, v in call_overrides.items() if k in valid_names} else: filtered = {} dropped = set(call_overrides) - set(filtered) if dropped: warnings.warn( f"Ignoring unsupported baseline parameters {sorted(dropped)} for method '{method_lower}'", UserWarning, ) call_kwargs = filtered with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) with np.errstate(divide="ignore", invalid="ignore", over="ignore", under="ignore"): result = func(input_y, **call_kwargs) # pybaselines returns (baseline, params) if isinstance(result, tuple): baseline, _ = result else: # pragma: no cover - defensive against unexpected return types baseline = result if needs_rescale: baseline = baseline * scale if not np.all(np.isfinite(baseline)): raise ValueError("Baseline estimation produced non-finite values") corrected = y - baseline if clip_negative: corrected[corrected < 0] = 0.0 return (corrected, baseline) if return_baseline else corrected
# --------------------------------------------------------------------- # I/O utilities # --------------------------------------------------------------------- COL_ALIASES = { "channel": {"channel", "chan", "ch", "index", "idx"}, "mz": {"m/z", "mz", "mass", "moverz", "m_over_z"}, "intensity": {"intensity", "counts", "signal", "y", "ion_counts"}, } _WHITESPACE_SPLIT = re.compile(r"\s+") def _standardize_columns(df: pl.DataFrame) -> pl.DataFrame: rename = {} lower = {c: str(c).strip().lower() for c in df.columns} for std, aliases in COL_ALIASES.items(): for original, lowered in lower.items(): if lowered in aliases: rename[original] = std break if rename: df = df.rename(rename) # Sanity checks & fallbacks. Channel is optional; fabricate if missing to preserve order. if "mz" not in df.columns or "intensity" not in df.columns: raise KeyError("Input table must include 'm/z' (or alias) and 'intensity' columns.") if "channel" not in df.columns: df = df.with_row_index(name="channel", offset=1) df = df.select(["channel", "mz", "intensity"]) df = df.with_columns( pl.col("channel").cast(pl.Int64), pl.col("mz").cast(pl.Float64), pl.col("intensity").cast(pl.Float64), ) return df def _read_with_separator(path: Path, sep: str) -> Optional[pl.DataFrame]: try: df = pl.read_csv( path, separator=sep, comment_prefix="#", infer_schema_length=4096, ignore_errors=False, ) except Exception: return None return df if df.width >= 2 else None def _read_whitespace_table(path: Path) -> pl.DataFrame: lines: List[str] = [] with path.open("r", encoding="utf-8", errors="ignore") as handle: for raw in handle: stripped = raw.strip() if not stripped or stripped.startswith("#"): continue lines.append(_WHITESPACE_SPLIT.sub(",", stripped)) if not lines: raise ValueError(f"No tabular data found in {path}.") buffer = io.StringIO("\n".join(lines)) return pl.read_csv(buffer, separator=",")
[docs] def read_spectrum_table(path: Union[str, Path]) -> pl.DataFrame: """Read a ToF‑SIMS table from CSV/TSV/TXT with flexible separators.""" path = Path(path) df: Optional[pl.DataFrame] = None for sep in (",", "\t", ";"): df = _read_with_separator(path, sep) if df is not None: break if df is None: df = _read_whitespace_table(path) df = _standardize_columns(df) df = df.sort("channel") return df
# --------------------------------------------------------------------- # Metrics # --------------------------------------------------------------------- def _noise_mask_from_quantile(y_raw: np.ndarray, q: float = 0.2) -> np.ndarray: """Boolean mask selecting the lowest-q quantile as baseline-only region.""" finite = np.isfinite(y_raw) if not finite.any(): raise ValueError("No finite values in spectrum.") thresh = np.nanquantile(y_raw[finite], q) return (y_raw <= thresh) & finite
[docs] @dataclass class MetricResult: rfzn: float # residual flat-zone noise (RMS) nar: float # negative area ratio snr: float # median SNR of top-K peaks bbi: float # baseline bias index (median y_corr in baseline zones) br: float # baseline roughness (RMS of baseline second derivative in baseline zones) nbc: float # negative bin fraction
[docs] def compute_metrics( y_corr: np.ndarray, y_raw: np.ndarray, baseline: Optional[np.ndarray], x: Optional[np.ndarray], noise_mask: Optional[np.ndarray] = None, topk: int = 5, raw_noise_quantile: float = 0.2, ) -> MetricResult: """Compute RFZN, NAR, SNR, BBI, BR, NBC for a single corrected spectrum. Notes ----- * RFZN: RMS of y_corr in baseline-only region (noise_mask). If mask is not supplied, it is derived from the **raw** intensities (bottom-q). * NAR: sum(-y_corr[y<0]) / sum(|y_corr|); lower is better. * SNR: median prominence of top-K peaks divided by noise sigma. * BBI: median(y_corr[noise_mask]); lower magnitude is better. * BR: RMS of d²(baseline)/dx² in baseline-only regions; requires `baseline` and `x`. * NBC: fraction of points where y_corr < 0 (before clipping). """ y_corr = np.asarray(y_corr, dtype=float).ravel() y_raw = np.asarray(y_raw, dtype=float).ravel() assert y_corr.shape == y_raw.shape if noise_mask is None: noise_mask = _noise_mask_from_quantile(y_raw, raw_noise_quantile) # RFZN sigma_noise = float(np.sqrt(np.mean(y_corr[noise_mask] ** 2))) if noise_mask.any() else float("nan") # NAR denom = float(np.sum(np.abs(y_corr))) or np.nan neg_area = float(np.sum(-y_corr[y_corr < 0.0])) nar = neg_area / denom if denom and denom > 0 else float("nan") # lower is better # SNR via prominent peaks sigma_raw = float(np.sqrt(np.mean((y_raw[noise_mask]) ** 2))) if noise_mask.any() else 0.0 prom_thr = max(3.0 * sigma_raw, 0.0) peaks, props = signal.find_peaks(y_corr, prominence=prom_thr) if peaks.size == 0: peak_heights = np.array([np.nanmax(y_corr)]) else: peak_heights = np.asarray(props.get("prominences", y_corr[peaks]), dtype=float) if peak_heights.size > 1: top = np.sort(peak_heights)[-min(topk, peak_heights.size):] peak_stat = float(np.median(top)) else: peak_stat = float(peak_heights.item()) snr = peak_stat / sigma_noise if (sigma_noise and sigma_noise > 0) else float("nan") # BBI bbi = float(np.nanmedian(y_corr[noise_mask])) if noise_mask.any() else float("nan") # BR (needs baseline & x) if baseline is None or x is None: br = float("nan") else: x = np.asarray(x, dtype=float).ravel() bl = np.asarray(baseline, dtype=float).ravel() if x.shape != bl.shape: br = float("nan") else: # second derivative w.r.t. x (handles nonuniform spacing) d1 = np.gradient(bl, x, edge_order=2) d2 = np.gradient(d1, x, edge_order=2) br = float(np.sqrt(np.mean((d2[noise_mask]) ** 2))) if noise_mask.any() else float("nan") # NBC nbc = float(np.mean(y_corr < 0.0)) return MetricResult(rfzn=sigma_noise, nar=nar, snr=snr, bbi=bbi, br=br, nbc=nbc)