bayesianbandits.EmpiricalBayesDirichletClassifier
- class bayesianbandits.EmpiricalBayesDirichletClassifier(alphas: dict[int | str, float], *, n_eb_iter: int = 10, eb_tol: float = 0.0001, learning_rate: float = 1.0, random_state: int | None | Generator = None)
Bases: DirichletClassifier
Dirichlet-Multinomial classifier with empirical Bayes prior tuning.
Extends DirichletClassifier with automatic optimization of the Dirichlet prior concentration parameters via Minka’s fixed-point iteration [1] for the Dirichlet-Multinomial marginal likelihood.
During fit, the prior is iteratively updated to maximize the marginal likelihood across all groups. During partial_fit, a single Minka step is performed using the current group posteriors.
The prior is decomposed as α_k = s · m_k, where s is the scalar concentration (prior strength) and m_k is the base measure (prior shape). The m and s updates are alternated for faster convergence.
When learning_rate < 1, exponential forgetting is applied. Stabilized forgetting re-injects the EB-tuned prior after each decay step, ensuring the prior contribution converges to the tuned value rather than zero.
- Parameters:
alphas (dict of {int or str: float}) – Initial prior concentration parameters for each class.
n_eb_iter (int, default 10) – Maximum number of EB iterations during fit.
eb_tol (float, default 1e-4) – Convergence tolerance on change in log marginal likelihood.
learning_rate (float, default 1.0) – Decay rate for sequential updates.
random_state (int, np.random.Generator, or None, default None) – Controls RNG for sample.
- log_evidence_
Log marginal likelihood at convergence.
- Type:
float
- n_eb_iterations_
Number of EB iterations in last fit.
- Type:
int
- eb_converged_
Whether the EB loop converged within eb_tol.
- Type:
bool
See also
DirichletClassifier: Base estimator without empirical Bayes tuning.
EmpiricalBayesNormalRegressor: EB tuning for Normal regression.
Notes
Minka’s fixed-point iteration
Given G groups with effective counts c_gk (derived from the difference between posterior and prior concentrations), the per-component update is [1]:
\[\alpha_k^{\text{new}} = \alpha_k \cdot \frac{\sum_g [\psi(c_{gk} + \alpha_k) - \psi(\alpha_k)]}{\sum_g [\psi(c_g + \alpha_0) - \psi(\alpha_0)]}\]
where \(\psi\) is the digamma function, \(\alpha_0 = \sum_k \alpha_k\), and \(c_g = \sum_k c_{gk}\).
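The per-component update above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: `digamma` is a small hand-rolled approximation (in practice `scipy.special.digamma` would be the natural choice), `minka_step` and its argument names are hypothetical, and the alternating m/s scheme described earlier is not reproduced here.

```python
import math


def digamma(x: float) -> float:
    """Approximate psi(x) for x > 0 (recurrence plus asymptotic series)."""
    result = 0.0
    # Shift x upward with psi(x) = psi(x + 1) - 1/x until the series is accurate.
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
    return result + math.log(x) - 0.5 / x - series


def minka_step(counts: list[list[float]], alpha: list[float]) -> list[float]:
    """One fixed-point update of alpha given per-group effective counts c_gk."""
    alpha0 = sum(alpha)
    # Denominator: sum over groups of psi(c_g + alpha_0) - psi(alpha_0).
    denom = sum(digamma(sum(row) + alpha0) - digamma(alpha0) for row in counts)
    new_alpha = []
    for k, a_k in enumerate(alpha):
        # Numerator: sum over groups of psi(c_gk + alpha_k) - psi(alpha_k).
        numer = sum(digamma(row[k] + a_k) - digamma(a_k) for row in counts)
        new_alpha.append(a_k * numer / denom)
    return new_alpha


counts = [[3.0, 1.0], [1.0, 3.0]]  # two groups, two classes (illustrative data)
alpha = minka_step(counts, [1.0, 1.0])
```

The sketch assumes at least one group has a nonzero total count; otherwise the denominator is zero and the update is undefined.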
Stabilized forgetting
Every time learning_rate \(\gamma < 1\) causes decay (both in partial_fit and in decay), the prior is re-injected:
\[\boldsymbol{\alpha}_g \leftarrow \gamma^n \boldsymbol{\alpha}_g + (1 - \gamma^n) \boldsymbol{\alpha}^{\text{prior}}\]
This ensures effective counts are \(\gamma^n \cdot \text{counts}\) and the prior contribution never vanishes.
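A small numeric sketch of the re-injection formula may help. The function name `stabilized_decay` and the example values are hypothetical; the code only evaluates the formula above and does not call the library.

```python
def stabilized_decay(alpha_g, alpha_prior, gamma, n=1):
    """Decay a group's posterior concentrations toward the prior."""
    w = gamma ** n
    return [w * a + (1.0 - w) * p for a, p in zip(alpha_g, alpha_prior)]


prior = [1.0, 1.0]
posterior = [5.0, 3.0]  # prior plus observed counts [4, 2]
decayed = stabilized_decay(posterior, prior, gamma=0.5)

# Effective counts are the posterior minus the prior.
effective = [a - p for a, p in zip(decayed, prior)]
```

Here `decayed` is `[3.0, 2.0]`, so the effective counts are `[2.0, 1.0]`, exactly half of the original `[4, 2]`, while the prior's contribution of `[1.0, 1.0]` is left intact.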
References
Examples
Basic classification with EB-tuned prior:
>>> import numpy as np
>>> from bayesianbandits import EmpiricalBayesDirichletClassifier
>>> X = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]).reshape(-1, 1)
>>> y = np.array([1, 1, 2, 2, 2, 1, 1, 1, 2])
>>> clf = EmpiricalBayesDirichletClassifier(
...     {1: 1, 2: 1}, random_state=0
... )
>>> clf.fit(X, y)
EmpiricalBayesDirichletClassifier(alphas={1: ..., 2: ...}, random_state=0)
Misspecified priors converge to the true shape. Here the true DGP is Dir(3, 1) (75/25 split) but the initial prior heavily favors class 2:
>>> rng = np.random.default_rng(42)
>>> true_alpha = np.array([3.0, 1.0])
>>> clf = EmpiricalBayesDirichletClassifier(
...     {1: 1.0, 2: 5.0}, random_state=0
... )
>>> for g in range(200):
...     theta = rng.dirichlet(true_alpha)
...     obs = rng.choice([1, 2], size=rng.poisson(10) + 1, p=theta)
...     clf.partial_fit(np.full((len(obs), 1), g), obs)
>>> # Base measure recovers true shape: m ≈ [0.75, 0.25]
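The base measure mentioned in the last comment is the prior normalized by its total concentration, following the decomposition α_k = s · m_k. The sketch below shows that arithmetic on plain dictionaries; `decompose` is a hypothetical helper, and it does not read anything from a fitted classifier, since this page does not show how the tuned prior is exposed.

```python
def decompose(alphas: dict) -> tuple[float, dict]:
    """Split concentrations into scalar strength s and base measure m."""
    s = sum(alphas.values())
    m = {k: a / s for k, a in alphas.items()}
    return s, m


# True data-generating prior Dir(3, 1): strength 4, shape 75/25.
s_true, m_true = decompose({1: 3.0, 2: 1.0})

# Misspecified initial prior from the example: strength 6, shape about 17/83.
s_init, m_init = decompose({1: 1.0, 2: 5.0})
```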
- __init__(alphas: dict[int | str, float], *, n_eb_iter: int = 10, eb_tol: float = 0.0001, learning_rate: float = 1.0, random_state: int | None | Generator = None) -> None
- decay(X: NDArray[Any], *, decay_rate: float | None = None) -> None
Decay with stabilized prior re-injection.
Applies exponential forgetting and re-injects the EB-tuned prior so that the prior contribution converges to prior_ rather than zero. This ensures that effective counts (known_alphas - prior) remain correct after decay.
- Parameters:
X (array-like of shape (n_samples, 1)) – Used to identify which groups to decay.
decay_rate (float, default None) – Decay factor in (0, 1]. If None, uses self.learning_rate.
- fit(X: NDArray[Any], y: NDArray[Any], sample_weight: NDArray[Any] | None = None) -> Self
Fit with empirical Bayes hyperparameter tuning.
Iteratively optimizes the Dirichlet prior concentration parameters by maximizing the Dirichlet-Multinomial marginal likelihood using Minka’s fixed-point iteration, then performs a final posterior update with the converged prior.
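As a rough sketch of the objective and stopping rule described here, the code below evaluates the Dirichlet-Multinomial log marginal likelihood and repeats an update until it changes by less than a tolerance. The names `log_evidence` and `run_eb` are hypothetical, the `step` callable stands in for a Minka update, and the returned tuple only loosely mirrors `log_evidence_`, `n_eb_iterations_`, and `eb_converged_`; the library's internals may differ.

```python
import math


def log_evidence(counts, alpha):
    """Dirichlet-Multinomial log marginal likelihood, summed over groups.

    The multinomial coefficient is omitted because it does not depend on alpha.
    """
    alpha0 = sum(alpha)
    total = 0.0
    for row in counts:
        total += math.lgamma(alpha0) - math.lgamma(sum(row) + alpha0)
        total += sum(math.lgamma(c + a) - math.lgamma(a) for c, a in zip(row, alpha))
    return total


def run_eb(counts, alpha, step, n_eb_iter=10, eb_tol=1e-4):
    """Apply `step` until the log evidence changes by less than eb_tol."""
    previous = log_evidence(counts, alpha)
    for iteration in range(1, n_eb_iter + 1):
        alpha = step(counts, alpha)
        current = log_evidence(counts, alpha)
        if abs(current - previous) < eb_tol:
            return alpha, current, iteration, True
        previous = current
    return alpha, previous, n_eb_iter, False


# One observation of the first class under a uniform prior has probability 1/2,
# so the log evidence equals -log(2).
single = log_evidence([[1.0, 0.0]], [1.0, 1.0])
```

Any function with the signature `step(counts, alpha) -> alpha` can be passed to `run_eb`, which keeps the convergence logic separate from the update rule.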
- Parameters:
X (array-like of shape (n_samples, 1)) – Training data.
y (array-like of shape (n_samples,)) – Class labels.
sample_weight (array-like of shape (n_samples,), optional) – Individual weights for each sample.
- Returns:
self – Fitted estimator with tuned prior.
- Return type:
Self
- partial_fit(X: NDArray[Any], y: NDArray[Any], sample_weight: NDArray[Any] | None = None) -> Self
Incrementally update and retune hyperparameters.
Performs a posterior update (via the base class), then runs one Minka fixed-point step to adjust the prior. All group posteriors are corrected to reflect the new prior.
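One plausible reading of the correction step is that each group's effective counts (posterior minus prior) are preserved while the old prior is swapped for the retuned one. The sketch below illustrates that reading only; `reprior` and the example values are hypothetical and are not taken from the library's source.

```python
def reprior(posteriors: dict, old_prior: list, new_prior: list) -> dict:
    """Swap the prior under each group's posterior, keeping effective counts."""
    return {
        g: [a - o + n for a, o, n in zip(alpha_g, old_prior, new_prior)]
        for g, alpha_g in posteriors.items()
    }


# Posteriors built from the old prior [1, 1] plus counts [3, 1] and [0, 5].
posteriors = {0: [4.0, 2.0], 1: [1.0, 6.0]}
updated = reprior(posteriors, old_prior=[1.0, 1.0], new_prior=[1.5, 0.5])

# Effective counts measured against the new prior are unchanged.
counts_after = {g: [a - p for a, p in zip(v, [1.5, 0.5])] for g, v in updated.items()}
```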
- Parameters:
X (array-like of shape (n_samples, 1)) – Training data.
y (array-like of shape (n_samples,)) – Class labels.
sample_weight (array-like of shape (n_samples,), optional) – Individual weights for each sample.
- Returns:
self – Updated estimator with retuned prior.
- Return type:
Self