bayesianbandits.ContextualAgent

class bayesianbandits.ContextualAgent(arms: Sequence[Arm[ContextType, TokenType]], policy: PolicyProtocol[ContextType, TokenType], random_seed: int | None | Generator = None)

Bases: Generic[ContextType, TokenType]

Agent for a contextual multi-armed bandit problem.

At each round the agent observes a context \(x_t\), selects an arm \(a_t\) according to the configured policy, and later receives a reward \(r_t\). Each arm maintains an independent Bayesian learner, so the posterior for arm \(a\) is updated only when that arm is selected:

\[a_t = \pi\bigl(\{p(\theta_a \mid \mathcal{D}_a)\}_{a=1}^{K}, \; x_t\bigr), \qquad \mathcal{D}_{a_t} \leftarrow \mathcal{D}_{a_t} \cup \{(x_t, r_t)\}\]

where \(\pi\) is the policy (e.g. Thompson sampling, UCB, or \(\varepsilon\)-greedy) and \(K\) is the number of arms.
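The selection-and-update loop in the equation above can be illustrated with a toy, pure-Python sketch. This is not the library's implementation: `ScalarBayesRegressor` and `thompson_select` are hypothetical names, the learner is a one-feature conjugate Gaussian regression with known noise variance, and the policy is Thompson sampling. Each arm keeps only its own sufficient statistics, so only the pulled arm's posterior moves.

```python
from __future__ import annotations

import random
from dataclasses import dataclass


@dataclass
class ScalarBayesRegressor:
    """Hypothetical learner: r = theta * x + noise, with known noise variance.

    Prior theta ~ N(0, 1 / prior_precision). Only sufficient statistics of
    this arm's own data D_a are stored.
    """

    prior_precision: float = 1.0
    noise_var: float = 1.0
    xx: float = 0.0  # sum of x**2 over D_a
    xr: float = 0.0  # sum of x * r over D_a

    def posterior(self) -> tuple[float, float]:
        """Return the posterior (mean, variance) of theta."""
        precision = self.prior_precision + self.xx / self.noise_var
        mean = (self.xr / self.noise_var) / precision
        return mean, 1.0 / precision

    def update(self, x: float, r: float) -> None:
        """Append (x, r) to D_a by folding it into the sufficient statistics."""
        self.xx += x * x
        self.xr += x * r


def thompson_select(learners: dict[int, ScalarBayesRegressor], x: float,
                    rng: random.Random) -> int:
    """Thompson sampling: one posterior draw per arm, pick the best predicted reward."""
    scores = {}
    for token, learner in learners.items():
        mean, var = learner.posterior()
        theta = rng.gauss(mean, var ** 0.5)
        scores[token] = theta * x
    return max(scores, key=scores.get)


rng = random.Random(0)
true_theta = {0: 1.0, 1: 3.0}  # hypothetical ground truth for the simulation
learners = {0: ScalarBayesRegressor(), 1: ScalarBayesRegressor()}
for _ in range(200):
    x = rng.uniform(0.5, 1.5)
    a = thompson_select(learners, x, rng)
    r = true_theta[a] * x + rng.gauss(0.0, 1.0)
    learners[a].update(x, r)  # only the pulled arm's posterior changes
```

The policy function sees nothing but the per-arm posteriors and the context, mirroring the role of \(\pi\) above; swapping in a different rule only means replacing `thompson_select`.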

Parameters:
  • arms (Sequence[Arm[ContextType, TokenType]]) – Arms to choose from. Each arm must carry a fitted or unfitted learner and a unique action token that identifies it.

  • policy (PolicyProtocol[ContextType, TokenType]) – Policy object that implements arm selection given posteriors and context. Built-in options: ThompsonSampling, UpperConfidenceBound, EpsilonGreedy.

  • random_seed (int, np.random.Generator, or None, default None) – Controls the random number generator shared by the policy and all learners. Pass an int for results that are reproducible across runs.
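A stdlib sketch of how an `int | None | Generator` seed argument is commonly normalized into one shared generator. `as_generator` is a hypothetical helper, and `random.Random` stands in for `np.random.Generator`; NumPy's `np.random.default_rng` follows the same convention and returns a `Generator` argument unaltered. Whether `ContextualAgent` does exactly this internally is an assumption.

```python
from __future__ import annotations

import random


def as_generator(seed: int | None | random.Random = None) -> random.Random:
    """Hypothetical helper: turn a seed argument into a generator."""
    if isinstance(seed, random.Random):
        return seed  # reuse the caller's generator, so every consumer shares its stream
    return random.Random(seed)  # int -> reproducible stream; None -> seeded from OS entropy


# Two generators built from the same int seed produce identical sequences.
a, b = as_generator(42), as_generator(42)
same = [a.random() for _ in range(3)] == [b.random() for _ in range(3)]

# Passing one generator object to several consumers makes them share state.
shared = random.Random(7)
policy_rng = as_generator(shared)
learner_rng = as_generator(shared)
```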

See also

Agent

Non-contextual (intercept-only) agent.

LipschitzContextualAgent

Shared-learner agent with configurable design matrix; generalizes this agent.

Notes

Independent learners. Each arm owns a separate learner instance that is updated only with observations from that arm. This is the standard approach when arms have independent reward distributions [1]. For parameter sharing across arms, see LipschitzContextualAgent.

Batch contexts. Both pull() and update() accept matrices with multiple rows: pull() returns one decision per row, and update() applies one observation per row to the selected arm. This enables efficient batch serving, but with delayed feedback the caller must route each reward to the arm that was pulled for that row [2].
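Routing delayed rewards back to arms usually means grouping rows by the token pull() returned for them, then issuing one update per arm. `group_feedback` below is a hypothetical helper written with plain lists; the trailing comment shows how its output could feed the documented select_for_update() and update() calls.

```python
from collections import defaultdict


def group_feedback(contexts, tokens, rewards):
    """Hypothetical helper: split a batch of delayed feedback by the arm that produced it.

    contexts[i] is a context row, tokens[i] the action token pull() returned
    for it, and rewards[i] the reward observed later for that same row.
    """
    if not (len(contexts) == len(tokens) == len(rewards)):
        raise ValueError("contexts, tokens and rewards must have the same length")
    grouped = defaultdict(lambda: ([], []))  # token -> (context rows, rewards)
    for x, token, r in zip(contexts, tokens, rewards):
        rows, ys = grouped[token]
        rows.append(x)
        ys.append(r)
    return dict(grouped)


contexts = [[1.0, 15.0], [1.0, 3.0], [1.0, 8.0]]
tokens = [1, 0, 1]  # what pull() returned for each row
rewards = [100.0, 20.0, 55.0]
batches = group_feedback(contexts, tokens, rewards)
# Then, per arm: agent.select_for_update(token).update(np.array(rows), np.array(ys))
```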

Serialization. The agent (including all arm learners) is pickle-compatible, making it straightforward to persist to a database or message queue for use in live services.
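A minimal round-trip sketch with the standard `pickle` module. A plain dict with made-up per-arm statistics and a `random.Random` stands in for the agent here (a hypothetical substitute); with the real library the same `pickle.dumps(agent)` / `pickle.loads(blob)` pair would apply, given the pickle compatibility stated above. Generator state survives the round trip, so a restored copy continues the same random stream.

```python
import pickle
import random

# Hypothetical stand-in for agent state: per-arm statistics plus the shared generator.
state = {
    "arms": {0: {"xx": 2.0, "xr": 4.0}, 1: {"xx": 5.0, "xr": 15.0}},
    "rng": random.Random(0),
}
blob = pickle.dumps(state)  # bytes, suitable for a database column or queue message
restored = pickle.loads(blob)
```

Unpickling can execute arbitrary code, so only load blobs from a trusted store.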

References

Examples

Create an agent with two arms and pull for a single context:

>>> from bayesianbandits import Arm, NormalInverseGammaRegressor
>>> from bayesianbandits import ContextualAgent, ThompsonSampling
>>> arms = [
...     Arm(0, learner=NormalInverseGammaRegressor()),
...     Arm(1, learner=NormalInverseGammaRegressor()),
... ]
>>> agent = ContextualAgent(arms, ThompsonSampling(), random_seed=0)

The pull method takes a context matrix and returns one action token per row:

>>> import numpy as np
>>> X = np.array([[1.0, 15.0]])
>>> agent.pull(X)
[1]

By default the last pulled arm is queued for update. Call update with the same context and the observed reward:

>>> y = np.array([100.0])
>>> agent.update(X, y)
>>> agent.arm_to_update.learner.predict(X)
array([99.55947137])

For delayed rewards, explicitly select which arm to update using the fluent select_for_update() interface:

>>> agent.select_for_update(1).update(X, y)
>>> agent.arm_to_update is arms[1]
True
>>> agent.arm_to_update.learner.predict(X)
array([99.77924945])

__init__(arms: Sequence[Arm[ContextType, TokenType]], policy: PolicyProtocol[ContextType, TokenType], random_seed: int | None | Generator = None)

add_arm(arm: Arm[Any, TokenType]) → None

Add an arm to the bandit.

Parameters:

arm (Arm) – Arm to add to the bandit.

Raises:

ValueError – If the arm’s action token is already in the bandit.

arm(token: TokenType) → Arm[Any, TokenType]

Get an arm by its action token.

Parameters:

token (TokenType) – Action token of the arm to get.

Returns:

Arm with the action token.

Return type:

Arm

Raises:

KeyError – If the arm’s action token is not in the bandit.
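The error contract of add_arm(), arm(), and remove_arm() follows naturally from storing arms in a mapping keyed by action token. The sketch below is hypothetical (`ToyArm` and `ArmRegistry` are illustrative names, not library classes) and only reproduces the documented ValueError and KeyError behaviour.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyArm:
    token: int  # hypothetical stand-in for the arm's action token
    name: str = ""


class ArmRegistry:
    """Hypothetical token-keyed arm storage (not the library's code)."""

    def __init__(self) -> None:
        self._arms: dict[int, ToyArm] = {}

    def add_arm(self, arm: ToyArm) -> None:
        if arm.token in self._arms:
            raise ValueError(f"action token {arm.token!r} is already in the bandit")
        self._arms[arm.token] = arm

    def arm(self, token: int) -> ToyArm:
        return self._arms[token]  # unknown token -> KeyError

    def remove_arm(self, token: int) -> None:
        del self._arms[token]  # unknown token -> KeyError


registry = ArmRegistry()
registry.add_arm(ToyArm(0, "control"))
registry.add_arm(ToyArm(1, "treatment"))
```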

property arms: List[Arm[ContextType, TokenType]]

decay(X: ContextType, decay_rate: float | None = None) → None

Decay all arms of the bandit len(X) times.

Parameters:
  • X (ContextType) – Context matrix; every arm is decayed once per row of X.

  • decay_rate (Optional[float], default None) – Decay rate to apply. If None, each arm’s learner uses its own decay rate.
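A toy sketch of what "decay len(X) times" implies, assuming the learner forgets exponentially by multiplying its accumulated evidence (for example, the precision contributed by past data) by the rate once per step. `decay_evidence` and `learner_decay_rate` are hypothetical names, and the real learners may decay different quantities. The None fallback mirrors the documented default.

```python
from __future__ import annotations


def decay_evidence(evidence: float, n_rows: int, decay_rate: float | None = None,
                   learner_decay_rate: float = 0.99) -> float:
    """Hypothetical exponential forgetting, applied once per context row."""
    rate = learner_decay_rate if decay_rate is None else decay_rate  # None -> learner's own rate
    for _ in range(n_rows):  # one decay step per row of X
        evidence *= rate
    return evidence
```

Repeating the step n times is equivalent to a single multiplication by rate**n, so decaying with a three-row X shrinks the evidence three times as much (in log terms) as a one-row X.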

pull(X: ContextType) → List[TokenType]
pull(X: ContextType, *, top_k: int) → List[List[TokenType]]

Choose arm(s) and pull based on the context(s).

Parameters:
  • X (ContextType) – Context matrix to use for choosing arms.

  • top_k (int, optional) – Number of arms to select per context. If None (default), selects the single best arm per context. If given, selects the top k arms per context.

Returns:

If top_k is None: a list of action tokens, one per context. If top_k is an int: a list of lists of action tokens, one inner list per context.

Return type:

List[TokenType] or List[List[TokenType]]

Notes

When top_k is None, arm_to_update is set to the last selected arm. When top_k is specified, arm_to_update is not changed, so you must call select_for_update() before update() to specify which arm’s feedback you are providing.
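The bookkeeping rule above can be sketched in isolation. `pull_sketch` is a hypothetical function, and `score(token, x)` is an assumed stand-in for one policy draw per arm (for example, a Thompson sample of its expected reward); the sketch only shows how single-arm and top-k results differ and when arm_to_update changes.

```python
from __future__ import annotations

from collections.abc import Callable, Hashable, Sequence


def pull_sketch(
    contexts: Sequence[float],
    score: Callable[[Hashable, float], float],
    tokens: Sequence[Hashable],
    top_k: int | None = None,
    arm_to_update: Hashable | None = None,
):
    """Hypothetical sketch of pull() bookkeeping; returns (results, arm_to_update)."""
    results = []
    for x in contexts:
        # One score per arm for this context, highest first.
        ranked = sorted(tokens, key=lambda t: score(t, x), reverse=True)
        if top_k is None:
            results.append(ranked[0])
            arm_to_update = ranked[0]  # queue the last selected arm for update()
        else:
            results.append(ranked[:top_k])  # arm_to_update deliberately left as-is
    return results, arm_to_update
```

With a deterministic score such as `lambda t, x: t * x` and tokens `[0, 1, 2]`, two contexts give `([2, 2], 2)` without top_k, and `([[2, 1], [2, 1]], previous)` with `top_k=2`, where `previous` is whatever arm_to_update was before the call.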

remove_arm(token: TokenType) → None

Remove an arm from the bandit.

Parameters:

token (TokenType) – Action token of the arm to remove.

Raises:

KeyError – If the arm’s action token is not in the bandit.

property rng: Generator

select_for_update(token: TokenType) → Self

Set arm_to_update to the arm with the given action token and return self for chaining.

Parameters:

token (TokenType) – Action token of the arm to update.

Returns:

Self for chaining.

Return type:

Self

Raises:

KeyError – If the arm’s action token is not in the bandit.
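Chaining works because select_for_update() returns the object it was called on, so `.update(...)` runs against the arm that was just selected. A hypothetical minimal class (`FluentAgentSketch` is an illustrative name, storing raw observations instead of a learner) shows the pattern and the documented KeyError.

```python
class FluentAgentSketch:
    """Hypothetical sketch of the select_for_update() / update() chaining pattern."""

    def __init__(self, tokens):
        self.data = {t: [] for t in tokens}  # token -> list of (x, y) observations
        self.arm_to_update = None

    def select_for_update(self, token) -> "FluentAgentSketch":
        if token not in self.data:
            raise KeyError(token)
        self.arm_to_update = token
        return self  # returning self enables .select_for_update(t).update(...)

    def update(self, X, y) -> None:
        # Record one observation per row for the currently selected arm.
        self.data[self.arm_to_update].extend(zip(X, y))


sketch = FluentAgentSketch([0, 1])
sketch.select_for_update(1).update([[1.0, 15.0]], [100.0])
```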

update(X: ContextType, y: ndarray[tuple[int, ...], dtype[float64]], sample_weight: ndarray[tuple[int, ...], dtype[float64]] | None = None) → None

Update the arm_to_update with the context(s) and the reward(s).

Parameters:
  • X (ContextType) – Context matrix to use for updating the arm.

  • y (NDArray[np.float64]) – Reward(s) to use for updating the arm.

  • sample_weight (Optional[NDArray[np.float64]], default None) – Sample weights to use for updating the arm. If None, all samples are weighted equally.
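One common reading of sample_weight, assumed here rather than taken from the library: a weight scales an observation's contribution to the likelihood, so an integer weight k behaves like repeating that row k times. The sketch checks this with a conjugate Gaussian model of an unknown mean with known noise variance; `weighted_normal_posterior` is a hypothetical name.

```python
def weighted_normal_posterior(y, sample_weight=None, prior_mean=0.0,
                              prior_precision=1.0, noise_var=1.0):
    """Hypothetical learner: posterior (mean, precision) of an unknown mean.

    Each observation's likelihood precision 1 / noise_var is scaled by its weight.
    """
    if sample_weight is None:
        sample_weight = [1.0] * len(y)  # None: every sample weighted equally
    if len(sample_weight) != len(y):
        raise ValueError("sample_weight must have one entry per reward")
    precision = prior_precision + sum(sample_weight) / noise_var
    weighted_sum = sum(w * v for w, v in zip(sample_weight, y))
    mean = (prior_precision * prior_mean + weighted_sum / noise_var) / precision
    return mean, precision


weighted = weighted_normal_posterior([2.0, 5.0], sample_weight=[2.0, 1.0])
repeated = weighted_normal_posterior([2.0, 2.0, 5.0])  # first reward listed twice
```

Both calls yield a posterior mean of 2.25 with precision 4.0, which is the equivalence the weighting scheme is meant to guarantee.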