EmpiricalBayesDirichletClassifier
Tunes the Dirichlet prior concentration parameters by maximizing the Dirichlet-Multinomial marginal likelihood with Minka’s fixed-point iteration [1], and uses stabilized forgetting to prevent prior collapse under exponential decay.
Builds on the posterior update from Intercept-Only Models, which is inherited unchanged. Read that page first.
Symbols

| Symbol | Meaning |
|---|---|
| \(\boldsymbol{\alpha}\) | Dirichlet prior concentration parameters, length \(K\) |
| \(s\) | Scalar concentration \(s = \sum_k \alpha_k\) |
| \(\mathbf{m}\) | Base measure \(m_k = \alpha_k / s\), \(\sum_k m_k = 1\) |
| \(c_{gk}\) | Effective count of class \(k\) in group \(g\) |
| \(c_g\) | Total effective counts in group \(g\): \(c_g = \sum_k c_{gk}\) |
| \(G\) | Number of groups (unique values of the first feature) |
| \(K\) | Number of classes |
| \(\psi\) | Digamma function |
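Dirichlet-Multinomial marginal likelihood
The log marginal likelihood (evidence) for \(G\) groups with effective counts \(c_{gk}\) under prior \(\boldsymbol{\alpha}\) is, up to a term that does not depend on \(\boldsymbol{\alpha}\):
\[
\log p(\{c_{gk}\} \mid \boldsymbol{\alpha}) = \sum_{g=1}^{G} \left[ \log\Gamma(s) - \log\Gamma(c_g + s) + \sum_{k=1}^{K} \left( \log\Gamma(c_{gk} + \alpha_k) - \log\Gamma(\alpha_k) \right) \right]
\]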
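A minimal NumPy sketch of the evidence follows. It assumes the effective counts are held in a (G, K) array and the prior in a length-K array. The name dm_log_evidence is hypothetical and not part of the library. math.lgamma supplies \(\log\Gamma\), and the multinomial coefficient is dropped because it does not depend on \(\boldsymbol{\alpha}\).

```python
import math
import numpy as np

def dm_log_evidence(counts: np.ndarray, alpha: np.ndarray) -> float:
    """Dirichlet-Multinomial log evidence summed over groups (hypothetical helper).

    counts: (G, K) array of effective counts c_gk (may be fractional).
    alpha:  (K,) array of prior concentrations.
    The multinomial coefficient is omitted because it does not depend on alpha.
    """
    lgamma = np.vectorize(math.lgamma)  # elementwise log-Gamma
    s = alpha.sum()
    c_g = counts.sum(axis=1)  # per-group totals
    per_group = (
        lgamma(s)
        - lgamma(c_g + s)
        + (lgamma(counts + alpha) - lgamma(alpha)).sum(axis=1)
    )
    return float(per_group.sum())

# One group holding a single observation of class 1, uniform prior (1, 1)
print(dm_log_evidence(np.array([[1.0, 0.0]]), np.array([1.0, 1.0])))
```

For that single observation the evidence is \(\log(1/2)\), which is the prior predictive probability of class 1 under a uniform prior.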
(s, m) decomposition
The prior is decomposed as \(\alpha_k = s \cdot m_k\) where \(s\) controls overall prior strength and \(\mathbf{m}\) controls prior shape. Optimizing \(s\) and \(\mathbf{m}\) in alternation converges faster than optimizing \(\boldsymbol{\alpha}\) directly [1].
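m update (base measure)
Fixing \(s\), the per-component update is:
\[
\alpha_k^{\text{new}} = \alpha_k \, \frac{\sum_{g=1}^{G} \left[ \psi(c_{gk} + \alpha_k) - \psi(\alpha_k) \right]}{\sum_{g=1}^{G} \left[ \psi(c_g + s) - \psi(s) \right]}, \qquad \alpha_k = s \, m_k
\]
Then renormalize: \(m_k^{\text{new}} = \alpha_k^{\text{new}} / \sum_j \alpha_j^{\text{new}}\), keeping \(s\) fixed so that \(\alpha_k = s \, m_k^{\text{new}}\).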
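The base-measure step can be sketched in NumPy as follows. The names m_update and digamma are hypothetical. The short digamma helper (recurrence plus asymptotic series) stands in for scipy.special.digamma so that the snippet does not depend on SciPy. Counts are assumed to be a (G, K) float array, and the data are made up.

```python
import numpy as np

def digamma(x):
    """Digamma for x > 0, absolute error below ~1e-9 (hypothetical helper).

    Shifts x up by 6 using psi(x) = psi(x + 1) - 1/x, then applies the
    asymptotic series. scipy.special.digamma is the usual choice.
    """
    x = np.array(x, dtype=float)  # copy, so the caller's array is untouched
    result = np.zeros_like(x)
    for _ in range(6):
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = np.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))
    return result + series

def m_update(counts, alpha):
    """One base-measure step with s fixed; returns the new alpha = s * m_new."""
    s = alpha.sum()
    c_g = counts.sum(axis=1)
    num = (digamma(counts + alpha) - digamma(alpha)).sum(axis=0)  # one entry per class
    den = (digamma(c_g + s) - digamma(s)).sum()  # shared by every class
    alpha_new = alpha * num / den  # per-component fixed-point update
    m_new = alpha_new / alpha_new.sum()  # renormalize to a base measure
    return s * m_new

counts = np.array([[8, 2], [6, 4], [9, 1], [3, 7], [7, 3]], dtype=float)
alpha = np.array([1.0, 5.0])  # misspecified start: favors class 2
print(m_update(counts, alpha))
```

The denominator is the same for every class, so it cancels in the renormalization. It is kept here only to mirror the per-component form.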
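s update (concentration)
Fixing \(\mathbf{m}\), the scalar update is:
\[
s^{\text{new}} = s \, \frac{\sum_{g=1}^{G} \sum_{k=1}^{K} m_k \left[ \psi(c_{gk} + s \, m_k) - \psi(s \, m_k) \right]}{\sum_{g=1}^{G} \left[ \psi(c_g + s) - \psi(s) \right]}
\]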
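A matching sketch of the concentration step follows, under similar assumptions: s_update and digamma are hypothetical helpers (digamma stands in for scipy.special.digamma), and counts form a (G, K) float array. Only the overall scale of \(\boldsymbol{\alpha}\) changes, so the base measure is preserved exactly.

```python
import numpy as np

def digamma(x):
    """Digamma for x > 0 via recurrence plus asymptotic series (hypothetical helper)."""
    x = np.array(x, dtype=float)  # copy, so the caller's array is untouched
    result = np.zeros_like(x)
    for _ in range(6):  # psi(x) = psi(x + 1) - 1/x
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = np.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))
    return result + series

def s_update(counts, alpha):
    """One concentration step with m fixed; returns s_new * m."""
    s = alpha.sum()
    m = alpha / s
    c_g = counts.sum(axis=1)
    num = (m * (digamma(counts + alpha) - digamma(alpha))).sum()  # sum over g and k
    den = (digamma(c_g + s) - digamma(s)).sum()
    return (s * num / den) * m

counts = np.array([[8, 2], [6, 4], [9, 1], [3, 7], [7, 3]], dtype=float)
alpha = np.array([4.0, 2.0])
new = s_update(counts, alpha)
print(new.sum(), new / new.sum())  # new s, unchanged m
```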
Each (m, s) step is a lower-bound maximization, so the log marginal likelihood is non-decreasing across iterations.
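The sketch below shows the monotonicity claim in action. It alternates the two updates on a small made-up count matrix (five groups of ten observations, starting from a deliberately misspecified prior) and records the log evidence after every step. All function names are hypothetical illustrations, not the library's internals. The digamma helper stands in for scipy.special.digamma.

```python
import math
import numpy as np

def digamma(x):
    """Digamma for x > 0 via recurrence plus asymptotic series (hypothetical helper)."""
    x = np.array(x, dtype=float)
    result = np.zeros_like(x)
    for _ in range(6):  # psi(x) = psi(x + 1) - 1/x
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + np.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))

def log_evidence(counts, alpha):
    """Dirichlet-Multinomial log evidence, up to an alpha-free constant."""
    lgamma = np.vectorize(math.lgamma)
    s = alpha.sum()
    c_g = counts.sum(axis=1)
    return float((lgamma(s) - lgamma(c_g + s)).sum()
                 + (lgamma(counts + alpha) - lgamma(alpha)).sum())

def minka_step(counts, alpha):
    """One alternating step: m update with s fixed, then s update with m fixed."""
    c_g = counts.sum(axis=1)
    s = alpha.sum()
    den = (digamma(c_g + s) - digamma(s)).sum()  # s is unchanged by the m step
    # m update: per-component fixed point, then renormalize at the same s
    raw = alpha * (digamma(counts + alpha) - digamma(alpha)).sum(axis=0) / den
    m = raw / raw.sum()
    alpha = s * m
    # s update: rescale the whole vector, leaving m unchanged
    num = (m * (digamma(counts + alpha) - digamma(alpha))).sum()
    return (s * num / den) * m

counts = np.array([[8, 2], [6, 4], [9, 1], [3, 7], [7, 3]], dtype=float)
alpha = np.array([1.0, 5.0])  # misspecified start
history = [log_evidence(counts, alpha)]
for _ in range(200):
    alpha = minka_step(counts, alpha)
    history.append(log_evidence(counts, alpha))

m_hat = alpha / alpha.sum()
print("m:", m_hat, "s:", alpha.sum())
```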
Counts extraction
The effective counts are derived from the difference between each group's posterior concentrations \(\boldsymbol{\alpha}_g\) and the prior concentrations:
\[
c_{gk} = \alpha_{gk} - \alpha_k
\]
With stabilized forgetting (see below), this gives the decayed counts exactly.
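The following is a sketch of the extraction. It assumes, hypothetically, that group posteriors are kept in a dict that maps each group value to its concentration vector. The names effective_counts, group_a, and group_b are illustrative. The fractional entry shows that effective counts need not be integers once decay has been applied.

```python
import numpy as np

def effective_counts(posteriors, prior):
    """Stack c_gk = alpha_gk - alpha_k into a (G, K) array, one row per group."""
    return np.vstack([post - prior for post in posteriors.values()])

prior = np.array([2.0, 1.0])  # illustrative EB-tuned prior
posteriors = {
    "group_a": np.array([7.0, 3.0]),
    "group_b": np.array([2.5, 4.0]),  # fractional after decay
}
counts = effective_counts(posteriors, prior)
print(counts)
```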
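Stabilized forgetting
The problem. Uniform decay \(\boldsymbol{\alpha}_g \leftarrow \gamma \, \boldsymbol{\alpha}_g\) shrinks the prior along with the data, so repeated decay with \(\gamma < 1\) drives all concentrations to zero.
The solution. Every decay (both explicit via decay() and implicit via learning_rate in partial_fit) re-injects the EB-tuned prior:
\[
\boldsymbol{\alpha}_g \leftarrow \gamma^{n} \, \boldsymbol{\alpha}_g + \left(1 - \gamma^{n}\right) \boldsymbol{\alpha}
\]
where \(n\) is the number of observations. This maintains the invariant
\[
\alpha_{gk} = \alpha_k + c_{gk} \quad \text{for all } k,
\]
because substituting it into the update gives \(\alpha_{gk} \leftarrow \alpha_k + \gamma^{n} c_{gk}\). The effective counts are therefore correctly decayed, and the prior contribution never vanishes.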
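The difference between plain and stabilized decay can be checked numerically. The sketch below uses hypothetical helper names and illustrative numbers. It decays one group's posterior 100 times with \(\gamma = 0.9\), treating each decay as covering one observation.

```python
import numpy as np

def naive_decay(alpha_g, gamma, n):
    """Plain exponential forgetting: scales prior and data together."""
    return gamma**n * alpha_g

def stabilized_decay(alpha_g, prior, gamma, n):
    """Decay, then re-inject the prior mass that the decay removed."""
    return gamma**n * alpha_g + (1 - gamma**n) * prior

prior = np.array([2.0, 1.0])   # illustrative EB-tuned prior
counts = np.array([6.0, 3.0])  # effective counts of one group
gamma = 0.9

alpha_naive = prior + counts
alpha_stab = prior + counts
for _ in range(100):  # 100 decays, one observation each
    alpha_naive = naive_decay(alpha_naive, gamma, n=1)
    alpha_stab = stabilized_decay(alpha_stab, prior, gamma, n=1)

print(alpha_naive)         # collapses toward zero
print(alpha_stab - prior)  # matches gamma**100 * counts up to rounding
```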
fit vs. partial_fit
fit iterates the Minka step to convergence on fixed data (up to n_eb_iter times, stopping once \(|\Delta \log p| < \texttt{eb\_tol}\)), then refits the posterior with the converged prior. partial_fit runs one step per call and corrects every group posterior by the prior change \(\boldsymbol{\alpha}^{\text{new}} - \boldsymbol{\alpha}^{\text{old}}\), so each posterior remains the new prior plus that group's effective counts.
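The two code paths can be sketched as follows. fit_prior and shift_posteriors are hypothetical helpers that mirror the description, and n_eb_iter and eb_tol play the roles of the parameters of the same names. The digamma (a stand-in for scipy.special.digamma), evidence, and Minka-step helpers are included so that the block runs on its own. The data are made up.

```python
import math
import numpy as np

def digamma(x):
    """Digamma for x > 0 via recurrence plus asymptotic series (hypothetical helper)."""
    x = np.array(x, dtype=float)
    result = np.zeros_like(x)
    for _ in range(6):  # psi(x) = psi(x + 1) - 1/x
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + np.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))

def log_evidence(counts, alpha):
    """Dirichlet-Multinomial log evidence, up to an alpha-free constant."""
    lgamma = np.vectorize(math.lgamma)
    s, c_g = alpha.sum(), counts.sum(axis=1)
    return float((lgamma(s) - lgamma(c_g + s)).sum()
                 + (lgamma(counts + alpha) - lgamma(alpha)).sum())

def minka_step(counts, alpha):
    """One (m, s) step: base measure with s fixed, then concentration with m fixed."""
    c_g, s = counts.sum(axis=1), alpha.sum()
    den = (digamma(c_g + s) - digamma(s)).sum()
    raw = alpha * (digamma(counts + alpha) - digamma(alpha)).sum(axis=0) / den
    m = raw / raw.sum()
    alpha = s * m
    num = (m * (digamma(counts + alpha) - digamma(alpha))).sum()
    return (s * num / den) * m

def fit_prior(counts, alpha, n_eb_iter=10, eb_tol=1e-4):
    """fit-style EB: up to n_eb_iter steps, stopping once |delta log p| < eb_tol."""
    prev = log_evidence(counts, alpha)
    for _ in range(n_eb_iter):
        alpha = minka_step(counts, alpha)
        cur = log_evidence(counts, alpha)
        if abs(cur - prev) < eb_tol:
            break
        prev = cur
    return alpha

def shift_posteriors(posteriors, alpha_old, alpha_new):
    """partial_fit-style correction: add the prior change to every group posterior."""
    return posteriors + (alpha_new - alpha_old)

counts = np.array([[8, 2], [6, 4], [9, 1], [3, 7], [7, 3]], dtype=float)
alpha0 = np.array([1.0, 5.0])

# fit: converge on fixed data, then refit the posteriors with the tuned prior
alpha_fit = fit_prior(counts, alpha0)
posteriors_fit = alpha_fit + counts

# partial_fit: one step, then shift the existing posteriors by the prior change
posteriors = alpha0 + counts
alpha1 = minka_step(counts, alpha0)
posteriors = shift_posteriors(posteriors, alpha0, alpha1)
```

Shifting by the prior change, rather than recomputing each posterior, is what keeps the per-call cost of partial_fit independent of past data in this sketch.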
Hyperparameter semantics

| Parameter | Controls | Practical guidance |
|---|---|---|
| alphas | Initial prior concentrations. EB tunes these. | Start uniform (the same concentration for every class). |
| n_eb_iter | Maximum EB iterations during fit. | 10 (default). Set to 0 to disable EB during fit. |
| eb_tol | Convergence tolerance on log evidence change. | 1e-4 (default). |
| learning_rate | Decay factor \(\gamma\). See Choosing and Tuning a Decay Rate. | 1.0 (default) for stationary environments. |
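Robustness to misspecified priors
The base measure \(\mathbf{m}\) converges to the true shape regardless of initialization. The scalar concentration \(s\) converges more slowly (a known property of Dirichlet maximum-likelihood estimation), but \(\mathbf{m}\) sets the relative class weights, which is what matters for predictions. In the example below, the initial prior favors class 2 five to one, while the data come from a prior with a 75/25 split favoring class 1.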
```python
import numpy as np
from bayesianbandits import EmpiricalBayesDirichletClassifier

rng = np.random.default_rng(42)
true_alpha = np.array([3.0, 1.0])  # 75/25 split

# Wrong initial prior: heavily favors class 2
clf = EmpiricalBayesDirichletClassifier(
    {1: 1.0, 2: 5.0}, random_state=0
)

# Each group draws its own class probabilities from the true prior,
# then contributes a Poisson-sized batch of observations keyed by group g.
for g in range(200):
    theta = rng.dirichlet(true_alpha)
    obs = rng.choice([1, 2], size=rng.poisson(10) + 1, p=theta)
    clf.partial_fit(np.full((len(obs), 1), g), obs)

# Recovered base measure: m ≈ {1: 0.75, 2: 0.25}
s = sum(clf.alphas.values())
m = {k: v / s for k, v in clf.alphas.items()}
print(m)
```