Agent API (bayesianbandits.api)

Fully typed API for Bayesian bandits.

This module contains the API for Bayesian bandits. It passes strict type checking with pyright and is recommended for use in production code. Beyond type safety, this API makes it much easier to add or remove arms, change the policy function, and serialize/deserialize bandits in live services.
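
The claim that arms and policies are easy to swap and serialize can be illustrated with a minimal sketch. Everything in it (BetaArm, ThompsonPolicy, GreedyPolicy, SketchAgent, add_arm, remove_arm) is hypothetical and not the library's API; it only assumes that an agent keeps its arms in a list and its policy as a plain attribute, so both can be changed in place and the whole object pickled between requests.

```python
import pickle
import random
from dataclasses import dataclass, field
from typing import Protocol


@dataclass
class BetaArm:
    """Hypothetical arm: Beta posterior over a Bernoulli reward rate."""
    action_token: str
    alpha: float = 1.0
    beta: float = 1.0

    def posterior_mean(self) -> float:
        return self.alpha / (self.alpha + self.beta)


class Policy(Protocol):
    def choose(self, arms: list[BetaArm], rng: random.Random) -> BetaArm: ...


class ThompsonPolicy:
    """Pick the arm with the largest single posterior draw."""
    def choose(self, arms: list[BetaArm], rng: random.Random) -> BetaArm:
        return max(arms, key=lambda a: rng.betavariate(a.alpha, a.beta))


class GreedyPolicy:
    """Pick the arm with the largest posterior mean."""
    def choose(self, arms: list[BetaArm], rng: random.Random) -> BetaArm:
        return max(arms, key=lambda a: a.posterior_mean())


@dataclass
class SketchAgent:
    """Hypothetical agent: arms and policy are plain, swappable attributes."""
    arms: list[BetaArm]
    policy: Policy
    rng: random.Random = field(default_factory=lambda: random.Random(0))

    def add_arm(self, arm: BetaArm) -> None:
        self.arms.append(arm)

    def remove_arm(self, token: str) -> None:
        self.arms = [a for a in self.arms if a.action_token != token]


agent = SketchAgent([BetaArm("a"), BetaArm("b", alpha=3.0)], ThompsonPolicy())
agent.add_arm(BetaArm("c"))
agent.remove_arm("a")
agent.policy = GreedyPolicy()  # swap the policy without rebuilding the agent
restored = pickle.loads(pickle.dumps(agent))  # e.g. persist between requests
print([a.action_token for a in restored.arms])  # ['b', 'c']
print(restored.policy.choose(restored.arms, restored.rng).action_token)  # b
```

Pickling works here only because every policy is a module-level class; a lambda stored as the policy would not pickle.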

It uses the same estimators and Arm class as the original API, but defines its own ContextualAgent and Agent classes, as well as a Policy type alias for the policy functions.

This API splits the Bandit class into two classes, ContextualAgent and Agent. This split enables safer typing, as the contextual bandit always takes a context matrix as input, while the non-contextual bandit does not.
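
The typing benefit of the split can be sketched with two hypothetical classes, AgentSketch and ContextualAgentSketch, which are not the library's classes. Because the non-contextual pull() accepts no arguments and the contextual pull(X) requires one, pyright rejects a call with the wrong shape before the code ever runs.

```python
from typing import Sequence


class AgentSketch:
    """Hypothetical non-contextual agent: pull() takes no input."""

    def __init__(self, means: dict[str, float]) -> None:
        self.means = means

    def pull(self) -> str:
        # Greedy on point estimates; there is no context to pass.
        return max(self.means, key=lambda a: self.means[a])


class ContextualAgentSketch:
    """Hypothetical contextual agent: pull() requires a context matrix."""

    def __init__(self, weights: dict[str, Sequence[float]]) -> None:
        self.weights = weights

    def score(self, action: str, row: Sequence[float]) -> float:
        # Linear score: dot product of the action's weights with one context row.
        return sum(w * x for w, x in zip(self.weights[action], row))

    def pull(self, X: Sequence[Sequence[float]]) -> list[str]:
        # One decision per row of X.
        return [max(self.weights, key=lambda a: self.score(a, row)) for row in X]


agent = AgentSketch({"a": 0.2, "b": 0.5})
print(agent.pull())  # b

ctx_agent = ContextualAgentSketch({"a": [1.0, 0.0], "b": [0.0, 1.0]})
print(ctx_agent.pull([[2.0, 1.0], [0.0, 3.0]]))  # ['a', 'b']
# agent.pull([[1.0]]) or ctx_agent.pull() would be flagged by pyright.
```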

Additionally, this API deprecates the delayed_reward decorator, because it modifies the function signatures of the pull and update methods. Instead, this API supports batch pulls and updates, and leaves it to the user to match each update with the pull it belongs to. Library users reported that they were already doing this bookkeeping themselves, so this API change should encourage better practices.
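
One way to do the pull/update bookkeeping yourself is to key each pulled action by a request id and look it up when the reward arrives. The sketch below assumes a hypothetical BatchAgentSketch with pull(n) and update(actions, rewards) methods (Beta-Bernoulli Thompson sampling); these names and signatures are illustrative only and are not the library's.

```python
import random


class BatchAgentSketch:
    """Hypothetical Beta-Bernoulli agent with batch pull/update (not the library's API)."""

    def __init__(self, actions: list[str], seed: int = 0) -> None:
        self.actions = actions
        self.rng = random.Random(seed)
        # Beta(1, 1) prior for each action's success rate.
        self.alpha = {a: 1.0 for a in actions}
        self.beta = {a: 1.0 for a in actions}

    def pull(self, n: int) -> list[str]:
        # Thompson sampling, one independent decision per batch element.
        return [
            max(self.actions, key=lambda a: self.rng.betavariate(self.alpha[a], self.beta[a]))
            for _ in range(n)
        ]

    def update(self, actions: list[str], rewards: list[int]) -> None:
        # The caller must align each reward with the action it belongs to.
        if len(actions) != len(rewards):
            raise ValueError("actions and rewards must have the same length")
        for a, r in zip(actions, rewards):
            self.alpha[a] += r
            self.beta[a] += 1 - r


agent = BatchAgentSketch(["a", "b"])

# Record which action each request received, keyed by a request id.
request_ids = ["r1", "r2", "r3"]
pending = dict(zip(request_ids, agent.pull(len(request_ids))))

# Later: rewards arrive out of order, and not every request has one yet.
arrived = {"r3": 1, "r1": 0}
agent.update([pending.pop(rid) for rid in arrived], list(arrived.values()))

print(sorted(pending))  # ['r2']
```

Requests still waiting for a reward stay in pending, so late rewards can be matched without any change to the agent's method signatures.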

Note

Migrating from the original API to this API should be straightforward: instantiate the ContextualAgent or Agent class with list(arms.values()) and the policy function of the original Bandit subclass.

Bandit Classes

ContextualAgent(arms, policy[, random_seed])

Agent for a contextual multi-armed bandit problem.

Agent(arms, policy[, random_seed])

Agent for a non-contextual multi-armed bandit problem.

Policy Functions

EpsilonGreedy([epsilon, samples])

Policy object for epsilon-greedy.

ThompsonSampling()

Policy object for Thompson sampling.

UpperConfidenceBound([alpha, samples])

Policy object for upper confidence bound.
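
The three policies can be sketched as decision rules over posterior samples. This reading of the parameters is an assumption based on their names: epsilon is the exploration probability, samples is the number of posterior draws used for an estimate, and alpha is the quantile used as the upper bound. The classes below (GaussianArm, EpsilonGreedySketch, ThompsonSamplingSketch, UpperConfidenceBoundSketch) are hypothetical, and their defaults are illustrative, not the library's.

```python
import random
import statistics
from typing import Sequence


class GaussianArm:
    """Hypothetical arm with a Normal posterior over its expected reward."""

    def __init__(self, token: str, mean: float, std: float) -> None:
        self.token, self.mean, self.std = token, mean, std

    def draw(self, rng: random.Random, size: int) -> list[float]:
        return [rng.gauss(self.mean, self.std) for _ in range(size)]


class EpsilonGreedySketch:
    def __init__(self, epsilon: float = 0.1, samples: int = 1000) -> None:
        self.epsilon, self.samples = epsilon, samples

    def choose(self, arms: Sequence[GaussianArm], rng: random.Random) -> GaussianArm:
        if rng.random() < self.epsilon:
            return rng.choice(arms)  # explore: uniform random arm
        # Exploit: highest Monte Carlo estimate of the posterior mean.
        return max(arms, key=lambda a: statistics.fmean(a.draw(rng, self.samples)))


class ThompsonSamplingSketch:
    def choose(self, arms: Sequence[GaussianArm], rng: random.Random) -> GaussianArm:
        # One posterior draw per arm; the largest draw wins.
        return max(arms, key=lambda a: a.draw(rng, 1)[0])


class UpperConfidenceBoundSketch:
    def __init__(self, alpha: float = 0.68, samples: int = 1000) -> None:
        self.alpha, self.samples = alpha, samples

    def choose(self, arms: Sequence[GaussianArm], rng: random.Random) -> GaussianArm:
        def upper_quantile(arm: GaussianArm) -> float:
            # Empirical alpha-quantile of the posterior samples.
            draws = sorted(arm.draw(rng, self.samples))
            return draws[min(int(self.alpha * self.samples), self.samples - 1)]

        return max(arms, key=upper_quantile)


rng = random.Random(0)
arms = [GaussianArm("wide", 0.0, 2.0), GaussianArm("narrow", 0.5, 0.01)]
print(EpsilonGreedySketch(epsilon=0.0).choose(arms, rng).token)  # narrow
print(UpperConfidenceBoundSketch(alpha=0.95).choose(arms, rng).token)  # wide
print(ThompsonSamplingSketch().choose(arms, rng).token)  # either, at random
```

The example arms show the difference in behavior: greedy exploitation prefers the arm with the higher mean, while the upper-quantile rule favors the arm whose posterior is still wide.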

Classes

PolicyProtocol(*args, **kwargs)
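
PolicyProtocol is listed without a description. If, as its name suggests, it is a typing.Protocol, then any object with the right method signature counts as a policy, with no inheritance required. A minimal sketch of that idea follows; PolicyLike, its choose method, Greedy, RandomChoice, and run are all hypothetical and do not reflect the library's actual method names.

```python
import random
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class PolicyLike(Protocol):
    """Hypothetical structural type for a policy; not the library's PolicyProtocol."""

    def choose(self, estimates: Sequence[float], rng: random.Random) -> int: ...


class Greedy:
    # No inheritance needed: matching the method signature is enough.
    def choose(self, estimates: Sequence[float], rng: random.Random) -> int:
        return max(range(len(estimates)), key=lambda i: estimates[i])


class RandomChoice:
    def choose(self, estimates: Sequence[float], rng: random.Random) -> int:
        return rng.randrange(len(estimates))


def run(policy: PolicyLike, estimates: Sequence[float], seed: int = 0) -> int:
    # pyright checks structurally that `policy` provides a compatible choose().
    return policy.choose(estimates, random.Random(seed))


print(run(Greedy(), [0.1, 0.9, 0.3]))  # 1
print(isinstance(RandomChoice(), PolicyLike))  # True
```

Note that the runtime isinstance check only verifies that a choose attribute exists; argument and return types are checked statically by pyright, not at runtime.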
