"""
flat_window_suggester_polars.py
--------------------------------
Small application to discover common "flat" m/z windows across a set of
ToF‑SIMS spectra. Inputs are 3‑column tables: Channel, m/z, intensity
(case‑insensitive).
Key changes vs. the original:
- Replaced all pandas operations with Polars (Rust/Arrow backend).
- Added support for providing an explicit *list of file paths or glob patterns*,
so data can be spread across many folders (no need for a single root dir).
- Kept the numerical core (NumPy/SciPy) for smoothing & derivatives.
What it does (unchanged conceptually)
-------------------------------------
1) Per spectrum:
- Smooth intensities (Savitzky–Golay) and compute 1st/2nd derivatives.
- Flag baseline‑candidate points where simultaneously:
y_raw <= q_y quantile AND
|dy/dx| <= q_g quantile AND
|d²y/dx²| <= q_c quantile
- Merge contiguous candidate points into segments; keep segments that
satisfy minimum width & minimum number of points.
2) Across all spectra:
- Discretize the global m/z range into bins (width = bin_width).
- For each file, mark bins covered by any of its segments.
- Compute the coverage fraction per bin (#files covering bin / #files total).
- Extract contiguous regions whose coverage ≥ coverage_threshold.
- Rank regions by mean coverage (then by width) and return top_k windows.
Outputs
-------
- out_dir / per_file_segments.csv (Polars CSV)
- out_dir / flat_windows_suggestions.csv (Polars CSV with coverage stats)
- out_dir / flat_windows.json (list[[lo, hi], ...])
- out_dir / coverage_curve.(png|pdf) (plot of coverage vs m/z)
"""
from __future__ import annotations
import json
import re
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
logger = logging.getLogger(__name__)
from joblib import Parallel, delayed
from scipy.signal import medfilt, savgol_filter
from tqdm import tqdm
OUTPUT_DIR = Path("output_files")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# ---------------------------------------------------------------------
# Column aliases & helpers
# ---------------------------------------------------------------------
COL_ALIASES = {
"channel": {"channel", "chan", "index", "idx", "Channel", "CHANNEL", "Index"},
"mz": {"m/z", "mz", "mass", "moverz", "m_over_z", "Mass", "MZ"},
"intensity": {"intensity", "counts", "signal", "y", "ion_counts", "Intensity", "COUNTS", "INTENSITY"},
}
def _standardize_columns_pl(df: pl.DataFrame) -> pl.DataFrame:
"""Rename columns to ['channel','mz','intensity'] if aliases are found.
If 'channel' missing, insert a 1-based index. Returns only those 3 columns.
"""
# Map existing cols to lowercase for alias matching
lower = {c: str(c).strip().lower() for c in df.columns}
rename: Dict[str, str] = {}
for std, aliases in COL_ALIASES.items():
for c, cl in lower.items():
if cl in aliases:
rename[c] = std
break
if rename:
df = df.rename(rename)
cols = set(df.columns)
if "mz" not in cols or "intensity" not in cols:
raise KeyError("Input must include 'm/z' (or alias) and 'intensity' columns.")
if "channel" not in cols:
df = df.with_columns(pl.arange(1, df.height + 1).alias("channel"))
# Coerce dtypes and filter invalid rows
def _safe_cast(col: str, dtype) -> pl.Expr:
try:
return pl.col(col).cast(dtype, strict=False)
except TypeError:
# Older Polars versions may not accept strict=; fall back to cast
return pl.col(col).cast(dtype)
df = df.with_columns(
_safe_cast("channel", pl.Int64),
_safe_cast("mz", pl.Float64),
_safe_cast("intensity", pl.Float64),
)
# Keep only required columns; drop NA/inf in mz/intensity
df = df.select("channel", "mz", "intensity")
# Filter NaNs/Infs
df = df.filter(
pl.col("mz").is_finite() & pl.col("intensity").is_finite()
)
return df
def _read_csv_with_fallbacks(path: Path) -> Optional[pl.DataFrame]:
"""Try to read CSV/TSV (and simple whitespace-delimited) with Polars + fallback.
Returns a Polars DataFrame or None.
"""
# First, attempt common separators with Polars
for sep in (",", "\t"):
try:
df = pl.read_csv(
path,
separator=sep,
comment_prefix="#",
infer_schema_length=2048,
ignore_errors=False,
try_parse_dates=False,
)
if df.width >= 2 and df.height >= 1:
return df
except Exception:
pass
# Fallback for simple whitespace-delimited tables via NumPy
# (handles arbitrary runs of whitespace; ignores '#...' comments)
try:
# Read header to get names
with open(path, "r", encoding="utf-8", errors="ignore") as fh:
header = None
for line in fh:
s = line.strip()
if not s or s.startswith("#"):
continue
header = re.split(r"\s+", s)
break
if header is None:
return None
arr = np.genfromtxt(
path,
comments="#",
dtype=None,
names=True,
encoding="utf-8",
autostrip=True,
)
if arr.size == 0:
return None
# Build DataFrame from structured array
cols = {}
names = arr.dtype.names if arr.dtype.names is not None else header
for name in names:
try:
cols[name] = arr[name]
except Exception:
pass
if not cols:
return None
return pl.DataFrame(cols)
except Exception:
return None
[docs]
def read_spectrum_table(path: Union[str, Path]) -> pl.DataFrame:
"""Robust reader that returns Polars DataFrame with standardized columns.
Tries comma, tab, then whitespace-delimited tables (with '#' comments).
"""
path = Path(path)
df = _read_csv_with_fallbacks(path)
if df is None:
raise ValueError(f"Could not parse table at: {path}")
# Sort by m/z (derivatives use x order)
df = _standardize_columns_pl(df).sort("mz")
return df
# ---------------------------------------------------------------------
# Smoothing / flat-segment detection (NumPy/SciPy)
# ---------------------------------------------------------------------
[docs]
@dataclass
class FlatParams:
y_quantile: float = 0.20
grad_quantile: float = 0.40
curv_quantile: float = 0.40
savgol_window: int = 11
savgol_poly: int = 2
min_width: float = 0.2
min_points: int = 20
def _adaptive_savgol(y: np.ndarray, window: int, poly: int) -> np.ndarray:
n = y.size
if n < 5:
return y.copy()
# window must be odd and <= n
if window > n:
window = n if n % 2 == 1 else n - 1
if window < poly + 2:
window = poly + 3 # ensure > poly and odd
if window % 2 == 0:
window += 1
try:
return savgol_filter(y, window_length=window, polyorder=poly, mode="interp")
except Exception:
# fallback to a small median filter
k = min(max(5, window), n if n % 2 == 1 else n - 1)
return medfilt(y, kernel_size=k)
[docs]
def find_flat_segments(
x: np.ndarray, y: np.ndarray, p: FlatParams
) -> List[Tuple[float, float, int]]:
"""
Return list of (lo, hi, n_points) flat segments for one spectrum.
"""
x = np.asarray(x, float).ravel()
y = np.asarray(y, float).ravel()
assert x.shape == y.shape
# Collapse duplicate m/z values to avoid zero-spacing gradients.
uniq_x, inverse = np.unique(x, return_inverse=True)
if uniq_x.size != x.size:
accum = np.zeros_like(uniq_x, dtype=float)
counts = np.zeros_like(uniq_x, dtype=int)
np.add.at(accum, inverse, y)
np.add.at(counts, inverse, 1)
y = accum / np.maximum(counts, 1)
x = uniq_x
# Smooth intensity
y_s = _adaptive_savgol(y, p.savgol_window, p.savgol_poly)
# derivatives w.r.t m/z (handles nonuniform spacing)
dy = np.gradient(y_s, x, edge_order=2)
d2 = np.gradient(dy, x, edge_order=2)
# thresholds (lower quantiles → flat)
y_thr = np.nanquantile(y, p.y_quantile)
g_thr = np.nanquantile(np.abs(dy), p.grad_quantile)
c_thr = np.nanquantile(np.abs(d2), p.curv_quantile)
mask = (y <= y_thr) & (np.abs(dy) <= g_thr) & (np.abs(d2) <= c_thr)
# group contiguous True regions
segs: List[Tuple[float, float, int]] = []
if not mask.any():
return segs
idx = np.where(mask)[0]
# break into contiguous runs
runs = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1)
for r in runs:
if r.size < p.min_points:
continue
lo = float(x[r[0]]); hi = float(x[r[-1]])
if (hi - lo) >= p.min_width:
segs.append((lo, hi, int(r.size)))
return segs
# ---------------------------------------------------------------------
# Aggregate across files
# ---------------------------------------------------------------------
[docs]
@dataclass
class AggregateParams:
bin_width: float = 0.1
coverage_threshold: float = 0.5
top_k: int = 6
[docs]
def aggregate_common_windows(
segments_by_file: Dict[str, List[Tuple[float, float, int]]],
x_minmax: Tuple[float, float],
agg: AggregateParams,
) -> Tuple[List[Tuple[float, float]], pl.DataFrame]:
"""
Merge per-file segments into common windows via m/z bin coverage.
Returns (windows, coverage_table_df[polars]).
"""
files = list(segments_by_file.keys())
n_files = len(files)
x_min, x_max = x_minmax
if not np.isfinite(x_min) or not np.isfinite(x_max) or x_max <= x_min:
return [], pl.DataFrame({"mz_center": [], "coverage": []})
# Build bin edges and centers
edges = np.arange(x_min, x_max + agg.bin_width, agg.bin_width, dtype=float)
if edges.size < 2:
return [], pl.DataFrame({"mz_center": [], "coverage": []})
centers = 0.5 * (edges[:-1] + edges[1:])
cover_counts = np.zeros(centers.size, dtype=int)
# Vectorized: mark bins covered by segments for each file
for fname, segs in segments_by_file.items():
if not segs:
continue
# Stack all segment bounds and use broadcasting
seg_arr = np.array([(lo, hi) for lo, hi, _n in segs])
# centers (n_bins,) vs seg_arr (n_segs, 2): check if center is in any segment
in_any = np.any(
(centers[np.newaxis, :] >= seg_arr[:, 0:1]) &
(centers[np.newaxis, :] <= seg_arr[:, 1:2]),
axis=0
)
cover_counts += in_any.astype(int)
coverage = cover_counts / max(n_files, 1)
# Identify runs with coverage ≥ threshold
mask = coverage >= float(agg.coverage_threshold)
windows: List[Tuple[float, float]] = []
if mask.any():
idx = np.where(mask)[0]
runs = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1)
for r in runs:
lo = float(edges[r[0]])
hi = float(edges[r[-1] + 1])
if (hi - lo) > 0:
windows.append((lo, hi))
# Rank windows by mean coverage, pick top_k
stats: List[Tuple[float, float, float, float, float]] = []
for (lo, hi) in windows:
sel = (centers >= lo) & (centers <= hi)
cov_mean = float(np.mean(coverage[sel])) if np.any(sel) else 0.0
cov_min = float(np.min(coverage[sel])) if np.any(sel) else 0.0
stats.append((lo, hi, hi - lo, cov_mean, cov_min))
if stats:
df = pl.DataFrame(
stats,
schema=["lo", "hi", "width", "coverage_mean", "coverage_min"],
orient="row",
).sort(by=["coverage_mean", "width"], descending=[True, True])
if agg.top_k and agg.top_k > 0:
df = df.head(agg.top_k)
windows = [(float(r[0]), float(r[1])) for r in df.select(["lo", "hi"]).to_numpy()]
else:
df = pl.DataFrame(schema=["lo", "hi", "width", "coverage_mean", "coverage_min"])
# coverage table (for plotting/export)
coverage_df = pl.DataFrame({"mz_center": centers, "coverage": coverage})
return windows, coverage_df
# ---------------------------------------------------------------------
# Orchestration
# ---------------------------------------------------------------------
def _has_glob_chars(s: str) -> bool:
return any(ch in s for ch in "*?[")
[docs]
@dataclass
class ScanForFlatRegion:
files: List[Union[str, Path]] = field(default_factory=list)
out_dir: Union[str, Path] = "flat_windows_out"
n_jobs: int = -1
flat_params: FlatParams = field(default_factory=FlatParams)
agg_params: AggregateParams = field(default_factory=AggregateParams)
auto_tune: bool = False
def _expand_files(self) -> List[Path]:
paths: List[Path] = []
for item in self.files:
s = str(item).strip()
if not s:
continue
if _has_glob_chars(s):
for hit in Path().glob(s):
if hit.is_file():
paths.append(hit.resolve())
continue
p = Path(s)
if p.is_file():
paths.append(p.resolve())
elif p.is_dir():
for hit in p.rglob("*"):
if hit.is_file():
paths.append(hit.resolve())
else:
for hit in Path().glob(s):
if hit.is_file():
paths.append(hit.resolve())
uniq = sorted(set(paths))
if not uniq:
raise FileNotFoundError(
"No input files found; verify the provided file paths or glob patterns."
)
return uniq
[docs]
def run(self):
files = self._expand_files()
if self.auto_tune:
from ..adaptive import estimate_flat_params
from dataclasses import replace as _replace
overrides = estimate_flat_params([str(f) for f in files])
if overrides:
self.flat_params = _replace(self.flat_params, **overrides)
out_dir = OUTPUT_DIR / Path(self.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
def _worker(path: Path):
try:
df = read_spectrum_table(path)
x = df.get_column("mz").to_numpy()
y = df.get_column("intensity").to_numpy()
x = np.asarray(x, dtype=float)
y = np.asarray(y, dtype=float)
segs = find_flat_segments(x, y, self.flat_params)
return (
str(path),
segs,
(np.min(x) if x.size else np.nan),
(np.max(x) if x.size else np.nan),
)
except Exception:
return str(path), [], np.nan, np.nan
results = Parallel(n_jobs=self.n_jobs, prefer="threads")(
delayed(_worker)(p) for p in tqdm(files, desc="Processing spectra")
)
segments_by_file: Dict[str, List[Tuple[float, float, int]]] = {}
xmins: List[float] = []
xmaxs: List[float] = []
rows: List[Tuple[str, float, float, float, int]] = []
for fname, segs, xmin, xmax in results:
segments_by_file[fname] = segs
if np.isfinite(xmin):
xmins.append(float(xmin))
if np.isfinite(xmax):
xmaxs.append(float(xmax))
for lo, hi, n in segs:
rows.append((fname, lo, hi, hi - lo, n))
per_file_df = pl.DataFrame(
rows,
schema=["file", "lo", "hi", "width", "n_pts"],
orient="row",
).sort(by=["file", "lo"])
per_file_df.write_csv(out_dir / "per_file_segments.csv")
x_min = float(np.nanmin(xmins)) if xmins else np.nan
x_max = float(np.nanmax(xmaxs)) if xmaxs else np.nan
windows, coverage_df = aggregate_common_windows(
segments_by_file,
(x_min, x_max),
self.agg_params,
)
stats: List[dict] = []
for i, (lo, hi) in enumerate(windows, start=1):
sub = coverage_df.filter(
(pl.col("mz_center") >= lo) & (pl.col("mz_center") <= hi)
)
if sub.height:
stats.append(
{
"window_id": i,
"lo": lo,
"hi": hi,
"width": hi - lo,
"coverage_mean": float(sub["coverage"].mean()),
"coverage_min": float(sub["coverage"].min()),
}
)
sug_df = (
pl.DataFrame(
stats,
schema=[
"window_id",
"lo",
"hi",
"width",
"coverage_mean",
"coverage_min",
],
)
if stats
else pl.DataFrame(
schema=[
"window_id",
"lo",
"hi",
"width",
"coverage_mean",
"coverage_min",
]
)
)
sug_df.write_csv(out_dir / "flat_windows_suggestions.csv")
with open(out_dir / "flat_windows.json", "w", encoding="utf-8") as f:
json.dump([[float(lo), float(hi)] for (lo, hi) in windows], f, indent=2)
plt.figure(figsize=(10, 4.2))
x_cov = coverage_df["mz_center"].to_numpy() if coverage_df.height else np.array([])
y_cov = coverage_df["coverage"].to_numpy() if coverage_df.height else np.array([])
if x_cov.size:
plt.plot(x_cov, y_cov, lw=1)
for lo, hi in windows:
plt.axvspan(lo, hi, alpha=0.25)
plt.xlabel("m/z")
plt.ylabel("Coverage fraction")
plt.title("Flat-window coverage across spectra (suggested windows shaded)")
plt.tight_layout()
for ext in (".pdf", ".png"):
plt.savefig(out_dir / f"coverage_curve{ext}", dpi=300, bbox_inches="tight")
plt.close()
logger.info("Suggested flat windows (lo, hi):")
for i, (lo, hi) in enumerate(windows, start=1):
logger.info(f" {i:2d}. [{lo:.4f}, {hi:.4f}] (width={hi - lo:.4f})")
return windows