🧬 Temporal Genomics — Hands-on Workshop (Python / Jupyter)¶
1️⃣ Workshop Overview:¶
In this workshop, you will learn to analyse temporal genomic data — data from the same population sampled at two different time points.
Explore allele frequency changes over time (one population: Has02 → Has23)
What you'll learn
- Load temporal genomic data for mutations of different functional effects
- Build and visualize 2D Site Frequency Spectra (2dSFS) with dadi across mutation effect classes
- Quantify frequency shifts between timepoints
- Compare mutation categories (high-, low-, moderate-, and synonymous-effect)
- Compute per-row conditional distributions $p(j \mid i)$
- Perform bootstrap-based confidence intervals, statistical tests (χ², Jensen–Shannon divergence), and permutation tests
- Fit a simple neutral demographic model and inspect residuals
- Interpret results and link them to evolutionary processes
This notebook is ready-to-run in a conda environment with dadi, numpy, matplotlib, pandas, scipy, and seaborn installed.
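One way to set up such an environment (a sketch only; channel availability may vary, and dadi can alternatively be installed via pip):

```shell
# Create and activate a workshop environment; package names are assumed
# to be available on conda-forge / bioconda.
conda create -n temporal-genomics -c conda-forge -c bioconda \
    python=3.11 dadi numpy pandas matplotlib scipy seaborn
conda activate temporal-genomics
```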
2️⃣ Load Required Libraries¶
Load essential Python packages for numerical analysis, statistics, and plotting.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import chisquare, chi2_contingency # Chi-square tests (chi2_contingency is used in the shape comparison below)
from scipy.spatial.distance import jensenshannon # Jensen–Shannon divergence
from matplotlib.colors import TwoSlopeNorm, SymLogNorm
from matplotlib.ticker import MaxNLocator
import dadi
3️⃣ Load temporal genomic data (dadi data dictionaries)¶
- We’ll use pre-annotated SNPs separated by functional effect (from snpEff).
- Replace the file paths below with your own dadi-formatted .vcf.data files exported from your pipeline (these are simple text-based dictionaries expected by dadi.Misc.make_data_dict).
We treat two timepoints for the same population: Has02 (older) and Has23 (newer).
Mutation effect classes (example file naming convention used here):
- snpEff_Ann_HighEff_...
- snpEff_Ann_LowEff_...
- snpEff_Ann_ModEff_...
- snpEff_Ann_SynVar_...
# === Replace these paths with your files ===
snp_syn = 'snpEff_Ann_SynVar_merged_with_header_ChromModified.vcf.data'
snp_high = 'snpEff_Ann_HighEff_merged_with_header_ChromModified.vcf.data'
snp_low = 'snpEff_Ann_LowEff_merged_with_header_ChromModified.vcf.data'
snp_mod = 'snpEff_Ann_ModEff_merged_with_header_ChromModified.vcf.data'
# Create dadi data dictionaries (these parse the simple 'vcf.data' files)
dd_syn = dadi.Misc.make_data_dict(snp_syn)
dd_high = dadi.Misc.make_data_dict(snp_high)
dd_low = dadi.Misc.make_data_dict(snp_low)
dd_mod = dadi.Misc.make_data_dict(snp_mod)
print('data dicts created: syn, high, low, mod')
data dicts created: syn, high, low, mod
4️⃣ Build 2D Site Frequency Spectra (2dSFS) for Has02 → Has23¶
- We generate polarized (ancestral/derived known) 2dSFS objects using dadi.Spectrum.from_data_dict.
- Set projections to sensible integer values for each time point (number of chromosomes retained after down-projection).
- The 2dSFS shows the joint distribution of derived allele frequencies in Has02 vs Has23.
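For intuition, down-projection replaces each site's observed count with its hypergeometric expectation in a smaller sample. A minimal numpy/scipy sketch (`project_count` is a hypothetical helper for illustration; dadi does this internally via `projections`):

```python
import numpy as np
from scipy.stats import hypergeom

def project_count(k, n_from, n_to):
    # Probability of observing j derived copies in a subsample of n_to
    # chromosomes, given k derived copies among n_from (sampling without
    # replacement), for j = 0..n_to.
    j = np.arange(n_to + 1)
    return hypergeom.pmf(j, n_from, k, n_to)

# e.g. a site with 5 derived copies among 40 chromosomes, projected to 24
w = project_count(5, 40, 24)
```

Each site thus contributes fractional weight to several SFS cells, which is why projected spectra contain non-integer totals.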
# === Build spectra (adjust projections to your dataset) ===
proj = [24, 40] # example projections: replace if needed
Has_syn = dadi.Spectrum.from_data_dict(dd_syn, pop_ids=['Has02','Has23'], projections=proj, polarized=True)
Has_high = dadi.Spectrum.from_data_dict(dd_high, pop_ids=['Has02','Has23'], projections=proj, polarized=True)
Has_low = dadi.Spectrum.from_data_dict(dd_low, pop_ids=['Has02','Has23'], projections=proj, polarized=True)
Has_mod = dadi.Spectrum.from_data_dict(dd_mod, pop_ids=['Has02','Has23'], projections=proj, polarized=True)
print('Spectra shapes (rows = i in Has02, cols = j in Has23):')
print('Has_high', Has_high.data.shape, 'total sites =', Has_high.data.sum())
print('Has_low ', Has_low.data.shape, 'total sites =', Has_low.data.sum())
print('Has_mod ', Has_mod.data.shape, 'total sites =', Has_mod.data.sum())
print('Has_syn ', Has_syn.data.shape, 'total sites =', Has_syn.data.sum())
Spectra shapes (rows = i in Has02, cols = j in Has23):
Has_high (25, 41) total sites = 25972.0
Has_low  (25, 41) total sites = 216518.0
Has_mod  (25, 41) total sites = 243912.0
Has_syn  (25, 41) total sites = 185316.0
5️⃣ Visualize the 2dSFS (all effect classes)¶
We use dadi.Plotting.plot_single_2d_sfs. The vmin/vmax are chosen robustly to avoid invalid colorbar errors. Each pixel represents the number of SNPs with a given derived allele count (i, j) in Has02 and Has23.
# === Visualize four 2dSFS side-by-side ===
all_specs = [Has_high, Has_low, Has_mod, Has_syn]
nonempty = [d for d in all_specs if getattr(d,'data',None) is not None and d.data.size>0]
if len(nonempty)==0:
print('No spectra with data found.')
else:
vmin = max(1, min([d.data.min() for d in nonempty]))
vmax = max([d.data.max() for d in nonempty])
fig, axes = plt.subplots(1, 4, figsize=(22,5))
cmap = 'turbo'
dadi.Plotting.plot_single_2d_sfs(Has_high, vmin=vmin, vmax=vmax, cmap=cmap, pop_ids=('Has02','Has23'), ax=axes[0])
dadi.Plotting.plot_single_2d_sfs(Has_low, vmin=vmin, vmax=vmax, cmap=cmap, pop_ids=('Has02','Has23'), ax=axes[1])
dadi.Plotting.plot_single_2d_sfs(Has_mod, vmin=vmin, vmax=vmax, cmap=cmap, pop_ids=('Has02','Has23'), ax=axes[2])
dadi.Plotting.plot_single_2d_sfs(Has_syn, vmin=vmin, vmax=vmax, cmap=cmap, pop_ids=('Has02','Has23'), ax=axes[3])
titles = ['High-effect', 'Low-effect', 'Moderate-effect', 'Synonymous']
for ax, t in zip(axes, titles):
ax.set_title(t, fontsize=14)
ax.xaxis.label.set_size(12)
ax.yaxis.label.set_size(12)
ax.tick_params(axis='both', which='major', labelsize=10)
plt.tight_layout()
plt.show()
6️⃣ Compare Allele Frequency Categories (Rare vs Moderate bins from the 2dSFS)¶
- We can compare bins of rare vs moderate frequency alleles to detect potential selection effects.
- We'll define a helper to extract counts in rare and moderate bins (you can adjust thresholds).
def extract_allele_bins(sfs, rare_max=2, moderate_min=3, moderate_max=8):
arr = sfs.data
rare = []
moderate = []
n1, n2 = arr.shape
for i in range(n1):
for j in range(n2):
if i==0 and j==0:
continue
count = arr[i,j]
if (i<=rare_max or j<=rare_max) and (i>0 or j>0):
rare.append(count)
elif (moderate_min <= i <= moderate_max) and (moderate_min <= j <= moderate_max):
moderate.append(count)
return np.array(rare), np.array(moderate)
rare_syn, mod_syn = extract_allele_bins(Has_syn)
rare_high, mod_high = extract_allele_bins(Has_high)
rare_low, mod_low = extract_allele_bins(Has_low)
rare_mod, mod_mod = extract_allele_bins(Has_mod)
print('Example counts lengths: rare_high=', len(rare_high), 'mod_high=', len(mod_high))
Example counts lengths: rare_high= 188 mod_high= 36
6.1. Boxplots of rare/moderate bin distributions¶
data = [rare_syn, rare_high, rare_low, rare_mod, mod_syn, mod_high, mod_low, mod_mod]
labels = ['Rare Syn','Rare High','Rare Low','Rare Mod','Mod Syn','Mod High','Mod Low','Mod Mod']
plt.figure(figsize=(12,6))
# avoid empty groups
data_nonempty = [d[d>0] if d.size>0 else np.array([0]) for d in data]
plt.boxplot(data_nonempty, labels=labels, showfliers=False)
plt.ylabel('Site counts in bin')
plt.title('Distribution of rare and moderate allele-count bins (Has population)')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
6.2. Density overlays (rare bins): High vs Low effect¶
We can visualise the shape of these distributions using Kernel Density Estimates (KDE).
import seaborn as sns
plt.figure(figsize=(8,5))
if np.any(rare_high>0):
sns.kdeplot(rare_high[rare_high>0], label='Rare High effect', color='red', lw=2)
if np.any(rare_low>0):
sns.kdeplot(rare_low[rare_low>0], label='Rare Low effect', color='blue', lw=2)
plt.xlabel('Allele count per bin')
plt.ylabel('Density')
plt.title('Distribution of rare allele counts (Has population)')
plt.legend()
plt.tight_layout()
plt.show()
7️⃣ Statistical Metrics per Row of the 2dSFS¶
- We now compare rows (fixed Has02 frequency, varying Has23) between mutation classes
- Per-row normalized conditional distributions $p(j \mid i)$
We compute row-wise normalized probabilities and show panel plots for rows with sufficient data. This helps visualize how the distribution of derived allele counts in Has23 shifts relative to Has02 for each starting count i.
We calculate:
- χ² statistic comparing histograms
- Jensen–Shannon divergence (JSD)
- Mean shifts and fraction of mass left of the diagonal
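The row-normalization idea can be seen on a toy count matrix before the full helpers (hypothetical values, not workshop data):

```python
import numpy as np

# Toy 2dSFS-like count matrix: rows = i (Has02), cols = j (Has23)
counts = np.array([[0.0, 4.0, 1.0, 0.0],
                   [2.0, 6.0, 2.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0]])  # an empty row stays all-zero
row_sums = counts.sum(axis=1, keepdims=True)
# p(j | i): each non-empty row becomes a probability distribution over j
p_j_given_i = np.divide(counts, row_sums,
                        out=np.zeros_like(counts), where=row_sums > 0)
```

Comparing these per-row distributions between mutation classes removes the effect of different total site counts per row.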
7.1. Helper functions for statistics and normalisation:¶
# -------------------------
# Helper functions
# -------------------------
def chisq_on_hist(c1, c2, min_per_bin=5):
c1 = np.asarray(c1, float)
c2 = np.asarray(c2, float)
mask = (c1 + c2) >= min_per_bin
if mask.sum() < 2:
return np.nan, np.nan
c1m, c2m = c1[mask], c2[mask]
total = (c1m.sum() + c2m.sum()) / 2.0
if c1m.sum() == 0 or c2m.sum() == 0:
return np.nan, np.nan
c1s = c1m * (total / c1m.sum())
c2s = c2m * (total / c2m.sum())
stat, p = chisquare(f_obs=c1s, f_exp=c2s)
return stat, p
def normalize_row(c):
c = np.asarray(c, float)
s = c.sum()
return c / s if s > 0 else np.zeros_like(c)
def row_metrics(sfs_high, sfs_low, row_index):
cH = sfs_high.data[row_index, :].astype(float)
cL = sfs_low.data[row_index, :].astype(float)
pH = normalize_row(cH)
pL = normalize_row(cL)
chi2, p_chi = chisq_on_hist(cH, cL)
try:
jsd = jensenshannon(pH, pL, base=2.0)
except Exception:
jsd = np.nan
j = np.arange(cH.size)
deltaH = ((j - row_index) * cH).sum() / cH.sum() if cH.sum() > 0 else np.nan
deltaL = ((j - row_index) * cL).sum() / cL.sum() if cL.sum() > 0 else np.nan
fracLeftH = cH[j < row_index].sum() / cH.sum() if cH.sum() > 0 else np.nan
fracLeftL = cL[j < row_index].sum() / cL.sum() if cL.sum() > 0 else np.nan
mean_j_high = (j * pH).sum() if pH.sum() > 0 else np.nan
mean_j_low = (j * pL).sum() if pL.sum() > 0 else np.nan
return dict(row=row_index,
n_high=int(cH.sum()), n_low=int(cL.sum()),
chi2=chi2, p_chi=p_chi, jsd=jsd,
delta_mean_high=deltaH, delta_mean_low=deltaL,
frac_left_high=fracLeftH, frac_left_low=fracLeftL,
mean_j_high=mean_j_high, mean_j_low=mean_j_low)
def bootstrap_row_diffs(cH, cL, row_index, B=1000, rng=None):
if rng is None:
rng = np.random.default_rng(42)
def _safe_mean_j(counts):
counts = np.asarray(counts, float)
tot = counts.sum()
if tot <= 0:
return np.nan
j = np.arange(counts.size)
return (j * (counts / tot)).sum()
def _safe_E_delta(counts, row_index):
counts = np.asarray(counts, float)
tot = counts.sum()
if tot <= 0:
return np.nan
j = np.arange(counts.size)
delta = j - row_index
return (delta * counts).sum() / tot
def _safe_frac_left(counts, row_index):
counts = np.asarray(counts, float)
tot = counts.sum()
if tot <= 0:
return np.nan
j = np.arange(counts.size)
return counts[j < row_index].sum() / tot
d_meanj = _safe_mean_j(cL) - _safe_mean_j(cH)
d_Ed = _safe_E_delta(cL, row_index) - _safe_E_delta(cH, row_index)
d_fl = _safe_frac_left(cL, row_index) - _safe_frac_left(cH, row_index)
bmj, bEd, bfl = [], [], []
for _ in range(B):
bH = rng.poisson(np.maximum(cH, 0))
bL = rng.poisson(np.maximum(cL, 0))
bmj.append(_safe_mean_j(bL) - _safe_mean_j(bH))
bEd.append(_safe_E_delta(bL, row_index) - _safe_E_delta(bH, row_index))
bfl.append(_safe_frac_left(bL, row_index) - _safe_frac_left(bH, row_index))
def ci(arr):
return np.nanpercentile(arr, 2.5), np.nanpercentile(arr, 97.5)
lo_mj, hi_mj = ci(bmj)
lo_Ed, hi_Ed = ci(bEd)
lo_fl, hi_fl = ci(bfl)
return dict(row=row_index,
d_meanj=d_meanj, lo_meanj=lo_mj, hi_meanj=hi_mj,
d_Ed=d_Ed, lo_Ed=lo_Ed, hi_Ed=hi_Ed,
d_fl=d_fl, lo_fl=lo_fl, hi_fl=hi_fl)
# =========================
# Per-row normalized p(j|i) panel plots
# =========================
def analyze_sfs_rows(
sfs_high, sfs_low,
pop_label="Pop",
start_row=1, end_row=None,
ncols=4,
ymax_prob=0.12,
mean_lines=True,
pop1_name=None, pop2_name=None,
min_row_sites=10,
adaptive_ylim=True,
show_equal_freq=True,
main_title=None,
subtitle=None,
top_pad=0.88,
title_y=0.93,
subtitle_y=0.90,
title_kwargs=None,
subtitle_kwargs=None,
subplot_title_fontsize=10
):
n_y, n_x = sfs_high.data.shape
assert sfs_low.data.shape == (n_y, n_x), "SFS shapes must match"
n1 = n_y - 1
n2 = n_x - 1
if pop1_name is None:
pop1_name = f"{pop_label}02"
if pop2_name is None:
pop2_name = f"{pop_label}23"
if end_row is None:
end_row = n_y - 1
rows = np.arange(start_row, end_row + 1, dtype=int)
tot_high = sfs_high.data.sum(axis=1)
tot_low = sfs_low.data.sum(axis=1)
metrics = [row_metrics(sfs_high, sfs_low, i) for i in rows]
df = pd.DataFrame(metrics).sort_values("row")
n_panels = len(rows)
nrows = int(np.ceil(n_panels / ncols))
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4 * ncols, 2.8 * nrows),
sharex=True, sharey=False, constrained_layout=False)
axes = np.atleast_1d(axes).ravel()
legend_shown = False
for ax, i in zip(axes, rows):
cH = sfs_high.data[i, :].astype(float)
cL = sfs_low.data[i, :].astype(float)
nH = int(tot_high[i])
nL = int(tot_low[i])
if (nH + nL) < min_row_sites or (cH.sum() == 0 and cL.sum() == 0):
ax.text(0.5, 0.5, f"Row {i}\nno/too few sites\n(high={nH}, low={nL})",
ha='center', va='center', transform=ax.transAxes)
ax.set_axis_off()
continue # Skip setting title and labels for this axis
pH = normalize_row(cH)
pL = normalize_row(cL)
x = np.arange(n_x)
ax.plot(x, pH, 'r-', lw=1.8, alpha=0.95, label='High (norm)' if not legend_shown else None)
ax.plot(x, pL, 'b--', lw=1.8, alpha=0.95, label='Low (norm)' if not legend_shown else None)
if show_equal_freq and n1 > 0:
j_eq = (n2 / n1) * i
ax.axvline(j_eq, color='k', ls=':', lw=1.2, alpha=0.8,
label='Equal frequency (p02=p23)' if not legend_shown else None)
if mean_lines:
mjh = df.loc[df.row == i, 'mean_j_high'].values[0]
mjl = df.loc[df.row == i, 'mean_j_low'].values[0]
if np.isfinite(mjh):
ax.axvline(mjh, color='red', ls='--', lw=1.2, alpha=0.95,
label='High mean j' if not legend_shown else None)
if np.isfinite(mjl):
ax.axvline(mjl, color='blue', ls='--', lw=1.2, alpha=0.95,
label='Low mean j' if not legend_shown else None)
if adaptive_ylim:
y_max = max(pH.max(), pL.max())
ax.set_ylim(0, max(y_max * 1.15, 0.02))
else:
ax.set_ylim(0, ymax_prob)
mrow = df[df.row == i].iloc[0]
ax.set_title(f"Row {i} | sites: high={nH}, low={nL}\nJSD={mrow.jsd:.3f} | χ² p={mrow.p_chi:.2g}",
fontsize=subplot_title_fontsize)
ax.set_xlabel(f"{pop2_name} derived allele count (j)")
ax.set_ylabel(f"p(j | {pop1_name} = {i})")
if not legend_shown:
ax.legend(loc="upper right", fontsize=8, frameon=False)
legend_shown = True
for k in range(len(rows), len(axes)):
fig.delaxes(axes[k])
fig.tight_layout(rect=[0, 0, 1, top_pad])
if main_title:
tkw = dict(fontsize=14, fontweight='bold')
if title_kwargs:
tkw.update(title_kwargs)
fig.text(0.5, title_y, main_title, ha='center', va='top', **tkw)
if subtitle:
skw = dict(fontsize=11, color='gray')
if subtitle_kwargs:
skw.update(subtitle_kwargs)
fig.text(0.5, subtitle_y, subtitle, ha='center', va='top', **skw)
plt.show()
return df
7.2. Row-Wise Normalized p(j|i) Plots¶
These plots show how the allele frequency in Has23 depends on the starting frequency in Has02.
# =========================
# Call the per-row panel plots:
# =========================
df_has_rows = analyze_sfs_rows(
Has_high, Has_low,
pop_label="Has",
start_row=1,
end_row=None,
ncols=4,
ymax_prob=0.12,
mean_lines=True,
pop1_name="Has02",
pop2_name="Has23",
min_row_sites=10,
adaptive_ylim=True,
show_equal_freq=True,
main_title="Has: normalized p(j|i) by row (high=red, low=blue)",
subtitle="Dotted = equal-frequency j=(n23/n02)·i; dashed = mean j per class",
top_pad=0.88,
title_y=0.93,
subtitle_y=0.90,
subplot_title_fontsize=10
)
df_has_rows.head()
| | row | n_high | n_low | chi2 | p_chi | jsd | delta_mean_high | delta_mean_low | frac_left_high | frac_left_low | mean_j_high | mean_j_low |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 1380 | 11143 | 68.109648 | 3.695666e-07 | 0.049086 | 1.086232 | 1.161088 | 0.489855 | 0.492955 | 2.086232 | 2.161088 |
| 1 | 2 | 980 | 7962 | 191.153944 | 6.062290e-29 | 0.107862 | 1.243878 | 1.430168 | 0.450000 | 0.423888 | 3.243878 | 3.430168 |
| 2 | 3 | 684 | 5534 | 75.323687 | 3.321765e-07 | 0.077494 | 2.238304 | 2.159740 | 0.358187 | 0.346043 | 5.238304 | 5.159740 |
| 3 | 4 | 510 | 4398 | 106.428199 | 1.070263e-11 | 0.093171 | 2.262745 | 2.552751 | 0.364706 | 0.337426 | 6.262745 | 6.552751 |
| 4 | 5 | 448 | 3546 | 156.261520 | 3.714352e-20 | 0.130380 | 2.095982 | 2.798928 | 0.348214 | 0.322053 | 7.095982 | 7.798928 |
8️⃣ Bootstrapped Confidence Intervals¶
Summary plots with bootstrap confidence intervals: estimate uncertainty in per-row frequency shifts using Poisson bootstrapping.
- We compute differences between low- and high-effect classes per-row (mean shifts) with bootstrap CI via Poisson resampling.
- This gives a compact summary and uncertainty.
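The Poisson bootstrap in miniature, on a single toy count row (hypothetical numbers): each cell count is treated as a Poisson observation and resampled, and the statistic is recomputed per replicate.

```python
import numpy as np

rng = np.random.default_rng(0)
counts = np.array([50.0, 30.0, 15.0, 5.0])  # toy SFS row: sites per count j
j = np.arange(counts.size)

def mean_j(c):
    # Mean derived allele count implied by a row of site counts
    return (j * c).sum() / c.sum()

# Poisson-resample every cell, recompute the mean, repeat
boots = np.array([mean_j(rng.poisson(counts)) for _ in range(2000)])
lo, hi = np.percentile(boots, [2.5, 97.5])
```

The 2.5th and 97.5th percentiles of the replicates give the 95% CI used in the summary plots.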
def summary_plot_with_ci(sfs_high, sfs_low, rows=None, B=1000, seed=123,
title="Summary with 95% bootstrap CI",
pop1_name="Pop02", pop2_name="Pop23"):
rng = np.random.default_rng(seed)
n_y, _ = sfs_high.data.shape
if rows is None:
rows = np.arange(1, n_y)
rows = np.array(sorted(rows))
recs = []
for i in rows:
cH = sfs_high.data[i, :].astype(float)
cL = sfs_low.data[i, :].astype(float)
recs.append(bootstrap_row_diffs(cH, cL, i, B=B, rng=rng))
df = pd.DataFrame(recs)
fig, axs = plt.subplots(1, 3, figsize=(15, 4), constrained_layout=True)
axs[0].plot(df.row, df.d_meanj, 'k-o', lw=1.5, ms=4)
axs[0].fill_between(df.row, df.lo_meanj, df.hi_meanj, color='k', alpha=0.15)
axs[0].axhline(0, color='gray', ls=':')
axs[0].set_xlabel(f"Row ({pop1_name} i)")
axs[0].set_ylabel("Δ mean j (low − high)")
axs[0].set_title(f"Mean {pop2_name} shift")
axs[1].plot(df.row, df.d_Ed, 'm-o', lw=1.5, ms=4)
axs[1].fill_between(df.row, df.lo_Ed, df.hi_Ed, color='m', alpha=0.15)
axs[1].axhline(0, color='gray', ls=':')
axs[1].set_xlabel(f"Row ({pop1_name} i)")
axs[1].set_ylabel("Δ E[Δ=j−i] (low − high)")
axs[1].set_title("Shift relative to diagonal")
axs[2].plot(df.row, df.d_fl, 'b-o', lw=1.5, ms=4)
axs[2].fill_between(df.row, df.lo_fl, df.hi_fl, color='b', alpha=0.15)
axs[2].axhline(0, color='gray', ls=':')
axs[2].set_xlabel(f"Row ({pop1_name} i)")
axs[2].set_ylabel("Δ P(Δ<0) (low − high)")
axs[2].set_title("Mass left of j=i")
fig.suptitle(title, fontsize=13)
plt.show()
return df
# =========================
# Call the summary plots:
# =========================
df_has_summary = summary_plot_with_ci(
Has_high, Has_low,
title="Haselbrunn: summary with 95% bootstrap CI (low − high)",
pop1_name="Has02",
pop2_name="Has23"
)
df_has_summary.head()
| | row | d_meanj | lo_meanj | hi_meanj | d_Ed | lo_Ed | hi_Ed | d_fl | lo_fl | hi_fl |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.074856 | -0.098706 | 0.257185 | 0.074856 | -0.098706 | 0.257185 | 0.003100 | -0.023350 | 0.030916 |
| 1 | 2 | 0.186291 | -0.077791 | 0.462287 | 0.186291 | -0.077791 | 0.462287 | -0.026112 | -0.060767 | 0.006557 |
| 2 | 3 | -0.078564 | -0.450153 | 0.291594 | -0.078564 | -0.450153 | 0.291594 | -0.012144 | -0.051817 | 0.024340 |
| 3 | 4 | 0.290006 | -0.170093 | 0.745835 | 0.290006 | -0.170093 | 0.745835 | -0.027280 | -0.071217 | 0.017402 |
| 4 | 5 | 0.702946 | 0.131769 | 1.210426 | 0.702946 | 0.131769 | 1.210426 | -0.026161 | -0.072121 | 0.022054 |
9️⃣ Fit a simple neutral model (split with symmetric migration) to the low-effect 2dSFS¶
This step fits a pragmatic 4-parameter model using dadi.Inference.optimize_log to the low-effect class (used as a neutral baseline), then computes Anscombe residuals to compare observed vs model expectations.
# =========================
# --- Fit neutral demographic model (split with symmetric migration) ---
# =========================
def fit_split_mig_sym_on_low(data_low, maxiter=5):
ns = data_low.sample_sizes
max_n = int(np.max(ns))
pts_l = [max_n + 10, max_n + 20, max_n + 30]
def split_mig_sym(params, ns, pts):
return dadi.Demographics2D.split_mig(params, ns, pts)
func_ex = dadi.Numerics.make_extrap_log_func(split_mig_sym)
p0 = np.array([1.0, 1.0, 0.02, 1.0])
lower = [1e-3, 1e-3, 1e-5, 0.0]
upper = [50.0, 50.0, 5.0, 50.0]
p_opt, ll_opt = None, -np.inf
for _ in range(maxiter):
p_start = dadi.Misc.perturb_params(p0, fold=1, lower_bound=lower, upper_bound=upper)
try:
p_try = dadi.Inference.optimize_log(
p_start, data_low, func_ex, pts_l,
lower_bound=lower, upper_bound=upper,
verbose=0, maxiter=300
)
mod_try = func_ex(p_try, ns, pts_l)
ll_try = dadi.Inference.ll_multinom(mod_try, data_low)
if ll_try > ll_opt:
p_opt, ll_opt = p_try, ll_try
except Exception:
continue
if p_opt is None:
p_opt = p0
model = func_ex(p_opt, ns, pts_l)
ll_opt = dadi.Inference.ll_multinom(model, data_low)
else:
model = func_ex(p_opt, ns, pts_l)
theta = dadi.Inference.optimal_sfs_scaling(model, data_low)
model = model * theta
return p_opt, model, ll_opt, theta, pts_l
9.1. Compute Anscombe & residual heatmaps (observed vs model)¶
# Compute Anscombe residuals
def poisson_anscombe_residual_array(model_sfs, data_sfs):
O = np.asarray(data_sfs.data, float)
M = np.asarray(model_sfs.data, float)
R = 2.0 * (np.sqrt(O + 3.0/8.0) - np.sqrt(M + 3.0/8.0))
R[~np.isfinite(R)] = 0.0
return R
# Plot residual heatmaps
def plot_residual_heatmaps(obs_low, obs_high, model_neutral, pop_label="Pop",
cmap="RdBu_r", vclip_percent=96, scale="symlog",
linthresh=1.0, exclude_edges=True, edge_pad=3,
corner_crop=2, use_mad=True, sigma=3.5):
R_low = poisson_anscombe_residual_array(model_neutral, obs_low)
R_high = poisson_anscombe_residual_array(model_neutral, obs_high)
n_y, n_x = R_low.shape
n1 = n_y - 1
n2 = n_x - 1
def _vals_for_scale(R):
V = R.copy()
if exclude_edges:
V[:edge_pad, :] = np.nan
V[-edge_pad:, :] = np.nan
V[:, :edge_pad] = np.nan
V[:, -edge_pad:] = np.nan
if corner_crop and corner_crop > 0:
cc = corner_crop
V[:cc, :cc] = np.nan
V[-cc:, -cc:] = np.nan
return V.ravel()
vals = np.concatenate([_vals_for_scale(R_low), _vals_for_scale(R_high)])
vals = vals[np.isfinite(vals)]
if vals.size == 0:
vmax = 1.0
else:
perc_v = np.nanpercentile(np.abs(vals), vclip_percent)
if use_mad:
mad = np.nanmedian(np.abs(vals - np.nanmedian(vals)))
robust_sigma = 1.4826 * mad
v_mad = max(sigma * robust_sigma, 1e-6)
vmax = min(max(v_mad, 1e-6), perc_v)
else:
vmax = perc_v
if vmax <= 0:
vmax = np.nanmax(np.abs(vals)) or 1.0
norm = SymLogNorm(linthresh=linthresh, vmin=-vmax, vmax=vmax, base=10) if scale == "symlog" else TwoSlopeNorm(vmin=-vmax, vcenter=0.0, vmax=vmax)
extent = [0, n2, 0, n1]
fig, axs = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)
im0 = axs[0].imshow(R_low, origin="lower", cmap=cmap, norm=norm, extent=extent, interpolation="nearest")
axs[0].set_title(f"{pop_label} low-effect residuals")
axs[0].set_xlabel(f"{pop_label}23 (j)")
axs[0].set_ylabel(f"{pop_label}02 (i)")
im1 = axs[1].imshow(R_high, origin="lower", cmap=cmap, norm=norm, extent=extent, interpolation="nearest")
axs[1].set_title(f"{pop_label} high-effect residuals")
axs[1].set_xlabel(f"{pop_label}23 (j)")
axs[1].set_ylabel(f"{pop_label}02 (i)")
for ax in axs:
ax.plot([0, n2], [0, n1], color="k", lw=0.9, ls=":")
ax.set_xlim(0, n2)
ax.set_ylim(0, n1)
ax.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=6))
ax.yaxis.set_major_locator(MaxNLocator(integer=True, nbins=6))
cbar = fig.colorbar(im1, ax=axs, fraction=0.025, pad=0.02)
cbar.set_label("Anscombe residual")
plt.show()
# Compute neutral-corrected row-wise effect sizes
def residual_shift_by_row(obs, model):
O = obs.data.astype(float)
M = model.data.astype(float)
n_y, n_x = O.shape
rows = np.arange(n_y)
mean_j_ratio, E_delta_ratio = [], []
for i in rows:
o = O[i, :]
m = M[i, :]
w = (o + 1.0) / (m + 1.0)
w[~np.isfinite(w)] = 0.0
if w.sum() == 0:
mean_j_ratio.append(np.nan)
E_delta_ratio.append(np.nan)
continue
j = np.arange(n_x)
mean_j_ratio.append((j * w).sum() / w.sum())
E_delta_ratio.append(((j - i) * w).sum() / w.sum())
return pd.DataFrame({"row": rows, "mean_j_ratio": mean_j_ratio, "E_delta_ratio": E_delta_ratio})
# --- Main function to fit neutral model and plot residuals and neutral-corrected effect sizes ---
def compare_residual_shifts(pop_label, data_low, data_high):
assert np.allclose(data_low.sample_sizes, data_high.sample_sizes), f"{pop_label}: low/high SFS must have same projections"
print(f"\nFitting neutral model to {pop_label} low-effect 2dSFS...")
p_opt, model, ll, theta, pts_l = fit_split_mig_sym_on_low(data_low)
print(f"{pop_label} best params [nu1, nu2, T, m]: {np.round(p_opt,4)}, logL={ll:.2f}, theta={theta:.3f}")
plot_residual_heatmaps(data_low, data_high, model_neutral=model, pop_label=pop_label,
cmap="RdBu_r", vclip_percent=96, scale="symlog", linthresh=1.0,
exclude_edges=True, edge_pad=3, corner_crop=2, use_mad=True, sigma=3.5)
df_low = residual_shift_by_row(data_low, model)
df_high = residual_shift_by_row(data_high, model)
df = df_low.merge(df_high, on="row", suffixes=("_low", "_high"))
fig, axs = plt.subplots(1, 2, figsize=(11, 4), constrained_layout=True)
axs[0].plot(df.row, df.mean_j_ratio_low - df.mean_j_ratio_high, 'k-o', lw=1.5, ms=4)
axs[0].axhline(0, color='gray', ls=':')
axs[0].set_xlabel(f"Row ({pop_label}02 i)")
axs[0].set_ylabel("Δ mean j (ratio, low − high)")
axs[0].set_title(f"{pop_label}: neutral-corrected mean {pop_label}23 shift")
axs[1].plot(df.row, df.E_delta_ratio_low - df.E_delta_ratio_high, 'm-o', lw=1.5, ms=4)
axs[1].axhline(0, color='gray', ls=':')
axs[1].set_xlabel(f"Row ({pop_label}02 i)")
axs[1].set_ylabel("Δ E[Δ=j−i] (ratio, low − high)")
axs[1].set_title(f"{pop_label}: neutral-corrected shift relative to diagonal")
plt.show()
return p_opt, model, df
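A quick standalone sanity check on toy numbers: for moderate counts the Anscombe residual closely tracks the familiar Pearson residual (O − M)/√M, which is why the heatmaps can be read as approximate z-scores.

```python
import numpy as np

O, M = 110.0, 100.0  # toy observed and model-expected counts for one SFS cell
anscombe = 2.0 * (np.sqrt(O + 3.0 / 8.0) - np.sqrt(M + 3.0 / 8.0))
pearson = (O - M) / np.sqrt(M)
```

The Anscombe form is preferred here because it is better behaved than the Pearson residual for the small counts common in SFS corners.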
9.2. Call the neutral fits and residual diagnostics:¶
has_params, has_model, has_resid_df = compare_residual_shifts("Has", Has_low, Has_high)
Fitting neutral model to Has low-effect 2dSFS...
Has best params [nu1, nu2, T, m]: [1.3810e-01 6.9900e-02 6.4000e-03 4.9897e+01], logL=-7472.98, theta=18768.029
🔟 Subsampling Low-Effect Mutations for Comparison¶
To assess whether high-effect mutations differ significantly from low-effect ones, we:
- Subsample the low-effect mutations to match the number of high-effect mutations.
- Calculate the mean derived allele count for each subsample.
- Compare the observed mean of high-effect mutations to the distribution of subsampled low-effect means.
This approach evaluates whether the mean effect of high-impact variants exceeds what would be expected by chance from low-effect variants.
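The three steps above, sketched on simulated toy counts (hypothetical values, not the workshop spectra):

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical per-site derived allele counts at the second timepoint
low_counts = rng.poisson(3.0, size=500)   # many low-effect sites
high_counts = rng.poisson(3.5, size=40)   # fewer high-effect sites

obs_mean = high_counts.mean()
# Null: repeatedly subsample low-effect sites down to the high-effect size
null_means = np.array([
    rng.choice(low_counts, size=high_counts.size, replace=False).mean()
    for _ in range(2000)
])
center = null_means.mean()
# Two-sided empirical p-value with a +1 pseudocount
p = (np.sum(np.abs(null_means - center) >= abs(obs_mean - center)) + 1) / 2001
```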
Utility function: normalise a row to probabilities. Convert a row of allele counts into probabilities so that each row sums to 1.
def normalize_row(c):
c = np.asarray(c, float)
s = c.sum()
return c / s if s > 0 else np.zeros_like(c)
Generate a null distribution via subsampling. Subsample low-effect mutations to match the count of high-effect ones and compute the mean distribution for each replicate.
def subsample_low_effect_means(sfs_low, sfs_high, row_index, n_reps=2000, rng=None):
"""
For a given row (i = derived allele count in pop1),
- Expand low-effect mutation spectrum into a list of allele counts.
- Subsample the same number as high-effect alleles.
- Compute the mean allele count in pop2 for each replicate.
Returns:
rep_means = distribution of means (null)
mean_of_means = mean of null distribution
"""
if rng is None:
rng = np.random.default_rng(42)
cH = sfs_high.data[row_index].astype(int)
cL = sfs_low.data[row_index].astype(int)
nH, nL = cH.sum(), cL.sum()
if nH == 0 or nL == 0 or nL < nH:
return None, None
alleles = []
for j_value, freq in enumerate(cL):
alleles.extend([j_value] * freq)
rep_means = np.array([
rng.choice(alleles, size=nH, replace=False).mean()
for _ in range(n_reps)
])
return rep_means, rep_means.mean()
Empirical p-value calculation. Estimate how extreme the observed mean is relative to the null distribution.
def calculate_empirical_p_value(observed_stat, null_dist):
"""
Compute two-sided empirical p-value centered around null mean.
+1 pseudocount avoids zero p-values.
"""
n = len(null_dist)
mu0 = np.mean(null_dist)
extreme = np.sum(np.abs(null_dist - mu0) >= np.abs(observed_stat - mu0))
return (extreme + 1) / (n + 1)
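To see the behaviour on synthetic data (the logic of calculate_empirical_p_value is redefined here so the snippet runs standalone):

```python
import numpy as np

rng = np.random.default_rng(1)
null = rng.normal(0.0, 1.0, size=999)  # synthetic null distribution

def emp_p(obs, null_dist):
    # Two-sided empirical p-value centered on the null mean,
    # with a +1 pseudocount so p is never exactly zero
    mu0 = null_dist.mean()
    extreme = np.sum(np.abs(null_dist - mu0) >= abs(obs - mu0))
    return (extreme + 1) / (len(null_dist) + 1)

p_typical = emp_p(0.0, null)  # near the null centre: large p
p_extreme = emp_p(6.0, null)  # six sigma out: smallest attainable p
```

Note the smallest attainable p-value is 1/(n+1), so the resolution of the test is limited by the number of null replicates.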
Plot mean comparison. Visualize observed high-effect means against null distributions from low-effect subsampling.
def plot_mean_comparison(sfs_high, sfs_low,
n_reps=2000,
pop1_name="Has02", pop2_name="Has23",
bins=30):
"""
Create grid panels (one per SFS row) showing:
- Blue: null distribution from low-effect subsamples
- Red: observed high-effect mean
- Dashed blue: null mean
- Title: permutation p-value
"""
n_rows = sfs_high.data.shape[0]
rows = np.arange(1, n_rows)
rng = np.random.default_rng(42)
fig, axes = plt.subplots(6, 4, figsize=(20, 15))
axes = axes.flatten()
fig.suptitle(
f"Permutation Test: High vs. Low Effect Mutations ({pop1_name} → {pop2_name})",
fontsize=20
)
for idx, i in enumerate(rows):
ax = axes[idx]
rep_means, mean_of_means = subsample_low_effect_means(
sfs_low, sfs_high, i, n_reps=n_reps, rng=rng
)
if rep_means is None:
ax.text(0.5, 0.5, f"Row {i}\n(no data)", ha="center", va="center", transform=ax.transAxes)
ax.set_axis_off()
continue
cH = sfs_high.data[i].astype(float)
pH = normalize_row(cH)
j_vals = np.arange(len(pH))
mean_high = float(np.dot(j_vals, pH))
p_perm = calculate_empirical_p_value(mean_high, rep_means)
ax.hist(rep_means, bins=bins, density=True, color="blue", alpha=0.6, edgecolor="black")
ax.axvline(mean_high, color="red", linestyle="--", lw=2, label="High-effect mean")
ax.axvline(mean_of_means, color="darkblue", linestyle=":", lw=2, label="Low-effect mean")
ax.set_title(f"Row {i} | Perm. p={p_perm:.3f}", fontsize=10)
ax.set_xlabel(f"Mean derived allele count in {pop2_name}")
ax.set_ylabel("Density")
if idx == 0:
ax.legend(fontsize=8, loc="upper left", frameon=True)
for j in range(len(rows), len(axes)):
fig.delaxes(axes[j])
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
Example usage. Run the comparison for the two timepoints (Has02 → Has23) of the population.
plot_mean_comparison(Has_high, Has_low, n_reps=2000, pop1_name="Has02", pop2_name="Has23")
10.1. Comparison: High vs. Low Effect Mutations¶
This section compares site frequency spectra (SFS) between high- and low-effect mutations using subsampling, permutation tests, Jensen–Shannon divergence, and χ² statistics.
Normalize a row of counts. Convert a row of allele counts into probabilities that sum to 1.
# ----------------------------
# Normalize a count vector into probabilities
# ----------------------------
# === Take a row of SFS counts and turns it into a probability distribution ===
def normalize_row(c):
c = np.asarray(c, float)
s = c.sum()
return c / s if s > 0 else np.zeros_like(c)
# === Build a null distribution of mean allele counts by repeatedly subsampling low-effect alleles to match the size of the high-effect sample ===
def subsample_low_effect_means(sfs_low, sfs_high, row_index, n_reps=2000, rng=None):
# Set random number generator
if rng is None:
rng = np.random.default_rng(42)
# Extract row counts for high- and low-effect mutations
cH = sfs_high.data[row_index].astype(int)
cL = sfs_low.data[row_index].astype(int)
# Total counts in that row
nH, nL = cH.sum(), cL.sum()
# Skip if too few counts
if nH == 0 or nL == 0 or nL < nH:
return None, None
# Expand low-effect SFS row into individual allele counts
alleles = []
for count_val, freq in enumerate(cL):
if freq > 0:
alleles.extend([count_val] * int(freq))
# Subsample repeatedly to match high-effect sample size
rep_means = np.empty(n_reps, dtype=float)
for r in range(n_reps):
samp = rng.choice(alleles, size=nH, replace=False)
rep_means[r] = samp.mean()
return rep_means, float(rep_means.mean())
# === Compute two-sided permutation p-value for whether the high-effect mean is unusual compared to the null ===
def empirical_p_value_mean(observed_mean, null_means):
# Compare observed mean to distribution of null means
center = np.mean(null_means)
obs_dev = abs(observed_mean - center)
null_devs = np.abs(null_means - center)
n = len(null_means)
extreme = np.sum(null_devs >= obs_dev)
return (extreme + 1) / (n + 1) # add +1 pseudocount
# === Return two metrics: JSD (continuous measure of divergence between SFS shapes). Chi² p-value (significance of shape differences) ===
def jsd_and_chi2_from_sfs(cH, cL):
# Ignore positions where both are zero
mask = (cH + cL) > 0
cH_ = cH[mask].astype(float)
cL_ = cL[mask].astype(float)
# Normalize to probability distributions
pH = normalize_row(cH_)
pL = normalize_row(cL_)
# Jensen-Shannon divergence (distribution similarity)
try:
jsd = float(jensenshannon(pH, pL, base=2.0))
except Exception:
jsd = np.nan
# Chi-square test (compare raw counts)
try:
table = np.vstack([cH_, cL_])
chi2, p_chi, dof, expected = chi2_contingency(table, correction=False)
p_chi = float(p_chi)
except Exception:
p_chi = np.nan
return jsd, p_chi
# === Prepare figure with multiple subplots (each row of SFS = one subplot) ===
def plot_comprehensive_comparison(sfs_high, sfs_low,
n_reps=2000,
pop1_name="Has02", pop2_name="Has23",
bins_low=10):
# Number of rows in SFS (pop1 allele counts)
n_rows = sfs_high.data.shape[0]
rows = np.arange(1, n_rows) # skip trivial row 0
rng = np.random.default_rng(42)
# Create subplot grid
fig, axes = plt.subplots(6, 4, figsize=(24, 18))
fig.suptitle(
f"Row-wise SFS Shape and Mean Comparison: High vs Low-effect ({pop1_name} → {pop2_name})",
fontsize=22, y=0.995
)
axes = axes.flatten()
# === For each row, build the null distribution of low-effect means --- #
for idx, i in enumerate(rows):
ax = axes[idx]
# Get null distribution of subsampled low-effect means
rep_means, mean_of_means = subsample_low_effect_means(
sfs_low, sfs_high, i, n_reps=n_reps, rng=rng
)
if rep_means is None:
ax.text(0.5, 0.5, f"Row {i}\n(no data)", ha='center', va='center',
transform=ax.transAxes)
ax.set_axis_off()
continue
# === Prepare the two distributions, calculates means, JSD, Chi², and permutation p-value ===
# Extract SFS counts
cH = sfs_high.data[i].astype(float)
cL = sfs_low.data[i].astype(float)
# Normalize to probability distributions
pH = normalize_row(cH)
pL = normalize_row(cL)
x = np.arange(len(pH))
# Mean of high-effect SFS
mean_high = float(np.dot(x, pH))
# Compute statistics
jsd, p_chi = jsd_and_chi2_from_sfs(cH, cL)
p_perm = empirical_p_value_mean(mean_high, rep_means)
# === Plot red = high-effect SFS, blue = low-effect SFS; histogram = subsampled low-effect means ===
# Plot SFS distributions
ax.plot(x, pH, color='red', lw=2, label='High-effect SFS')
ax.plot(x, pL, color='blue', lw=2, linestyle=':', label='Low-effect SFS')
ax.axvline(mean_high, color='red', linestyle='--', lw=2, label='High-effect mean')
# Second y-axis for histogram of subsampled low-effect means
ax2 = ax.twinx()
ax2.hist(rep_means, bins=bins_low, density=True, color='blue',
alpha=0.15, edgecolor='black', linewidth=0.5,
label=f'Subsampled means ({n_reps} reps)')
ax2.axvline(mean_of_means, color='blue', linestyle='--', lw=2,
label='Mean of subsampled means')
# === Subplot title show all test results for that row ===
# Annotate with results
ax.set_title(
f"Row {i} | JSD={jsd:.3f}, χ² p={p_chi:.2e} (compare SFS shape)\n"
f"Permutation p={p_perm:.3f} (compare means)",
fontsize=12
)
# === Clean up figure and display ===
for j in range(len(rows), len(axes)):
fig.delaxes(axes[j]) # remove unused subplots
plt.tight_layout(rect=[0, 0.02, 1, 0.96])
plt.show()
10.2. Example usage¶
plot_comprehensive_comparison(Has_high, Has_low, n_reps=2000,
pop1_name="Has02", pop2_name="Has23")
1️⃣1️⃣ Discussion, interpretation & exercises¶
Questions to think about
- Which effect classes show an excess of rare derived alleles in the recent timepoint (Has23)?
- Where do model residuals show strong deviations from the neutral expectation?
- How would selection vs migration shape the residual patterns?
Exercises
- Change projections to smaller/larger values and observe smoothing of the 2dSFS.
- Increase the bootstrap replicate count B and inspect CI stability.
- Modify the neutral model (allow asymmetric migration) and compare fits.
- Repeat the pipeline for another population or time series.