A Simple Industrial Example: Real-Time Bidding

A Jupyter notebook investigating how to apply reinforcement learning to the industrial problem of real-time bidding.

Unlike sponsored advertising, where advertisers set fixed bids, in real-time bidding (RTB) you can set a bid for every individual impression. When a user visits a website that supports ads, this triggers an auction where advertisers bid for an impression. Advertisers must submit bids within a fixed time window; 100 ms is a common limit.

The advertising platform provides contextual and behavioral information to the advertiser for evaluation. The advertiser uses an automated algorithm to decide how much to bid based upon the context. In the long term, the platform’s products must deliver a satisfying experience to maintain the advertising revenue stream. But advertisers want to maximize some key performance indicator (KPI), for example, the number of impressions or the click-through rate (CTR), for the least cost.

RTB presents a clear action (the bid), state (the information provided by the platform) and agent (the bidding algorithm). Both platforms and advertisers can use RL to optimize for their definition of reward.

To quickly demonstrate this idea, below I present some code to simulate a bidding environment. First let me install the dependencies.

A note on usage

Note that this notebook might not work on your machine because simple_rl forces TkAgg on some machines. See https://github.com/david-abel/simple_rl/issues/40

Also, Pygame is notoriously picky and expects a lot of compiler- and system-related libraries to be installed.

I managed to get this working on the following notebook:

docker run -it -p 8888:8888 jupyter/scipy-notebook:54462805efcb

This code is untested on any other notebook.

TODO: Migrate away from simple_rl and pygame.

TODO: Create dedicated Q-learning and SARSA notebooks.

!pip install pygame==1.9.6 pandas==1.0.5 matplotlib==3.2.1 gym==0.17.3 gym-display-advertising==0.0.1 > /dev/null
# GitHub no longer serves the unauthenticated git:// protocol, so install over HTTPS.
!pip install --upgrade git+https://github.com/david-abel/simple_rl.git@77c0d6b910efbe8bdd5f4f87337c5bc4aed0d79c > /dev/null
import matplotlib
matplotlib.use("agg", force=True)
  Running command git clone -q git://github.com/david-abel/simple_rl.git /tmp/pip-req-build-g820zz6v

Global Defines

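First is a small set of globals that set some of the options for the algorithms: eps is the exploration rate of the epsilon-greedy policy, gam is the discount factor, alph is the learning rate, and n_episodes is the number of training episodes in each run.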

eps = 0.05
gam = 0.99
alph = 0.0001
n_episodes = 200
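
To make the role of these numbers concrete, here is a minimal sketch of tabular Q-learning with an epsilon-greedy policy. This is not simple_rl's implementation, which may differ in its details; the names choose_action and q_update are made up for this illustration, and the reward of 1.0 in the usage line is an assumed value.

```python
import random
from collections import defaultdict

eps, gam, alph = 0.05, 0.99, 0.0001  # same values as the globals above

# Q[state][action] -> estimated return; unseen entries default to 0.0
Q = defaultdict(lambda: defaultdict(float))

def choose_action(state, actions, rng=random):
    """Epsilon-greedy: explore with probability eps, otherwise act greedily."""
    if rng.random() < eps:
        return rng.choice(list(actions))
    return max(actions, key=lambda a: Q[state][a])

def q_update(state, action, reward, next_state, actions):
    """One Q-learning step: move Q(s, a) a fraction alph towards the TD target."""
    best_next = max(Q[next_state][a] for a in actions)
    td_target = reward + gam * best_next
    Q[state][action] += alph * (td_target - Q[state][action])

# Illustrative transition with an assumed reward of 1.0.
actions = range(7)
q_update(state=(23.0,), action=3, reward=1.0, next_state=(24.0,), actions=actions)
print(Q[(23.0,)][3])
```

With a learning rate this small, one update moves an entry only a tiny distance from zero, so the Q-values stay small even after 200 episodes.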

Helper Functions

This section contains quite a lot of code: helper functions that each handle one part of the problem. For example, I need to discretize the state space so that it works with the tabular value-based algorithms.

Then there is some helper code to perform the training iteration. Loops in loops.

import numpy as np

def state_mapping(obs):
    """
    Tabular methods need a finite set of states, but the observation is
    real-valued. Keep only the first element, scale it by 100, and round, so
    that nearby values share a state. This is a simple discretization
    (equivalent to _tile coding_ with a single tiling).
    """
    return tuple(np.round(100 * obs[0:1]))

def run_episode(env, agent, learning=True):
    episode_reward = 0
    observation = env.reset()
    episode_over = False
    reward = 0
    action_buffer = []
    while not episode_over:
        if hasattr(agent, "q_func"):
            action = agent.act(
                state_mapping(observation),
                reward,
                learning=learning)
        else:
            action = agent.act(
                state_mapping(observation),
                reward)
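        # Despite the name, this buffer logs observation[0] (treated as the bid
        # in the outputs below), not the chosen action.
        action_buffer.append(observation[0])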
        observation, reward, episode_over, _ = env.step(action)
        episode_reward += reward
    agent.end_of_episode()
    return episode_reward, action_buffer

def train_agent(env, agent_func, n_repeats):
    train_rewards_buffer = np.zeros((n_episodes, n_repeats))
    train_bid_buffer = np.zeros((n_episodes, n_repeats, env.batch_size))
    for instance in range(n_repeats):
        agent = agent_func(range(env.action_space.n))
        for episode in range(n_episodes):
            episode_reward, action_buffer = run_episode(env, agent)
            train_rewards_buffer[episode, instance] = episode_reward
            train_bid_buffer[episode, instance, :] = np.pad(
                action_buffer, (0, env.batch_size - len(action_buffer)),
                'constant', constant_values=0)

    if hasattr(agent, "q_func"):
        print_arbitrary_policy(agent.q_func)

    # Test
    agent.epsilon = 0
    episode_reward, test_bid_buffer = run_episode(env, agent, learning=False)
    print(
        "{}: {}".format(
            "TEST",
            episode_reward))
    print(train_rewards_buffer.transpose())
    print(test_bid_buffer)
    return train_rewards_buffer, train_bid_buffer.mean(axis=2)
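
To see what the discretization in state_mapping does, here is a small standalone sketch. The observations are hypothetical (I am assuming a two-element array purely for illustration); the point is that only the first element matters and that values closer together than about 0.01 collapse into the same state.

```python
import numpy as np

def state_mapping(obs):
    # Same rule as the helper above: keep the first element, scale by 100, round.
    return tuple(np.round(100 * obs[0:1]))

# Hypothetical observations; the second element is ignored by state_mapping.
observations = (
    np.array([0.2277, 0.5]),
    np.array([0.2312, 0.9]),
    np.array([0.3036, 0.1]),
)
for obs in observations:
    print(obs[0], "->", state_mapping(obs))
```

The first two observations land in state 23 and the third in state 30, so the Q-table needs roughly one row per 0.01 step of the first observation element.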

Agents

Below I define the agents, again using simple_rl.

from simple_rl.agents import RandomAgent, DelayedQAgent, DoubleQAgent, QLearningAgent

def q_agent(actions):
    return QLearningAgent(
        actions,
        gamma=gam,
        epsilon=eps,
        alpha=alph,
    )

def random_agent(actions):
    return RandomAgent(actions)


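def sarsa_agent(actions):
    # NOTE: SARSAAgent is not imported anywhere in this notebook, so calling
    # this factory raises NameError. It is unused in the experiments below.
    return SARSAAgent(actions, 999, gamma=gam, epsilon=eps, alpha=alph)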
Warning: Tensorflow not installed.
import pandas as pd

def average(data):
    return pd.DataFrame(data.mean(axis=1))

def save(df, path):
    df.to_json(path)

def print_arbitrary_policy(Q):
    for state in sorted(Q.keys(), key=lambda x: (x is None, x)):
        values = Q[state].items()
        print("{}: {}".format(state, sorted(values)))

Running the Experiment

Finally I’m going to run the experiments. I also print a lot of debugging output so you can see the raw Q-values. I encourage you to inspect these and investigate how they change during learning.

import gym
import gym_display_advertising

name = "bidding_rl_delta_q_learning"
env_name = "StaticDisplayAdvertising-v0"
num_repeats = 10
agent = q_agent
print("Starting {}".format(name))
env = gym.make(env_name)
rewards_buffer, bid_buffer = train_agent(env, agent, num_repeats)
save(average(rewards_buffer), name + ".json")
save(average(bid_buffer), name + "_bid.json")
print("Stopping {}".format(name))
Starting bidding_rl_delta_q_learning
(0.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 0.0)]
(1.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 0.0)]
(2.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 4.242504202164232e-23), (6, 0.0)]
(3.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 4.2933234764231805e-19)]
(4.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 1.772487965584318e-17)]
(5.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 1.7705427281303726e-14)]
(6.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 1.8529846316373888e-13), (6, 0.0)]
(7.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 1.0510479070676586e-09)]
(8.0,): [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0), (4, 0), (5, 0.0), (6, 2.172726597811258e-09)]
(9.0,): [(0, 0.0), (1, 5.482986220213001e-14), (2, 0.0), (3, 0), (4, 0), (5, 0.0), (6, 0)]
(10.0,): [(0, 0.0), (1, 0), (2, 0.0), (3, 2.0510620725986405e-10), (4, 1.353739784534251e-10), (5, 9.666476344842309e-14), (6, 3.5228978855195713e-06)]
(11.0,): [(0, 0.0), (1, 1.2786960520559075e-09), (2, 2.3298705463739654e-10), (3, 0.0), (4, 0.0), (5, 3.6547557057673268e-09), (6, 0.0)]
(12.0,): [(0, 1.9476574163515049e-19), (1, 0.0), (2, 0.0), (3, 0), (4, 0), (5, 0.0), (6, 8.142038370515376e-06)]
(13.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 2.0251998549689234e-05)]
(14.0,): [(0, 0.0), (1, 7.626101192585984e-09), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0.0020515585317716803)]
(15.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 5.960236557456779e-11), (5, 0), (6, 0.004230750675800596)]
(16.0,): [(0, 0), (1, 1.2481636164728035e-06), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(17.0,): [(0, 0), (1, 2.9761825088658193e-12), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0.0029804345086767493)]
(18.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 1.463837626597533e-06), (5, 0), (6, 0.003275058306453641)]
(19.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0.00521805728006848), (5, 0), (6, 0)]
(20.0,): [(0, 4.477270619984004e-10), (1, 1.7642889740359994e-11), (2, 7.922017898594021e-08), (3, 0), (4, 0.010122313476386771), (5, 0), (6, 0)]
(21.0,): [(0, 3.259999349262237e-13), (1, 3.3739862891922365e-07), (2, 0), (3, 0.00010044982208921056), (4, 0.00010146830758872914), (5, 0.03909301596102957), (6, 0.00010045497397094526)]
(22.0,): [(0, 1.2232799134732166e-14), (1, 6.051420612839036e-07), (2, 0.0003047164627964476), (3, 1.2106312120994877e-06), (4, 0.00019364974423658043), (5, 0.02250500317268519), (6, 0)]
(23.0,): [(0, 7.807943887659747e-09), (1, 0.003969841217733571), (2, 0.004345326691962774), (3, 0.6047474755920988), (4, 0.003941487213921284), (5, 0.004772772250477408), (6, 0.0037113460954242283)]
(24.0,): [(0, 0), (1, 0.03136436008885931), (2, 0.0005158851036657193), (3, 0.00010407500666456972), (4, 0.00010014889592816461), (5, 0.0005701547066181823), (6, 0.00030055503865361827)]
(25.0,): [(0, 4.824067503902554e-10), (1, 0), (2, 0), (3, 0.00010098656169669431), (4, 0), (5, 0.010075164933337087), (6, 0.0001006236578970593)]
(26.0,): [(0, 2.9631391908283284e-08), (1, 0.0029916011595607197), (2, 0.0030396607284266432), (3, 0.2695398129016886), (4, 0.0016529380430939447), (5, 0.002259901472329447), (6, 0.0023038675221761507)]
(27.0,): [(0, 0), (1, 0.00010287784946467247), (2, 0.013969945177847403), (3, 0), (4, 0), (5, 0), (6, 0)]
(28.0,): [(0, 1.0729875481445314e-06), (1, 0.0007035103786949147), (2, 0.0010571398403581302), (3, 0.11521207331252095), (4, 0.0006035333103714293), (5, 0.0005025422460264521), (6, 0.0011998239489186856)]
(29.0,): [(0, 6.037928531643457e-08), (1, 0.0001061553383822366), (2, 0.00010508700265439482), (3, 0), (4, 0.00010323677873689857), (5, 0.007507574143751669), (6, 0.00010001979998058969)]
(30.0,): [(0, 1.4716146049329251e-06), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(31.0,): [(0, 3.1561492795990113e-07), (1, 0.00010160361003385946), (2, 0.00010000004075494548), (3, 0.02049682000209946), (4, 0.00010014025481895665), (5, 0.0004021757485606404), (6, 0)]
(32.0,): [(0, 1.0650065993861843e-10), (1, 0), (2, 0), (3, 0), (4, 0.00788617410256639), (5, 2.771418633225326e-07), (6, 0.00010004024555876554)]
(33.0,): [(0, 8.151549667265099e-08), (1, 0), (2, 0.00010028691317591659), (3, 0), (4, 0), (5, 0.00010106910404517729), (6, 0.0024498980467474864)]
(34.0,): [(0, 0), (1, 0), (2, 1.3050876692978922e-06), (3, 0), (4, 0.007101677896291413), (5, 0.00010010918048742164), (6, 0)]
(35.0,): [(0, 2.451120704908705e-07), (1, 0), (2, 0.00010017029483980501), (3, 0), (4, 0.00010049496510376419), (5, 0.008490158149654823), (6, 0)]
(36.0,): [(0, 6.567692121006412e-11), (1, 0), (2, 0), (3, 0), (4, 0.00010008937965828845), (5, 0), (6, 0.005123266777145842)]
(37.0,): [(0, 5.810036548531017e-07), (1, 1.5838306332200946e-07), (2, 2.385479154410431e-07), (3, 0.016197757106961797), (4, 0.0003013076659964926), (5, 0), (6, 0.00030476006015034077)]
(38.0,): [(0, 0), (1, 0), (2, 0.003810268339562596), (3, 0), (4, 0), (5, 0.000200039497980204), (6, 0)]
(39.0,): [(0, 2.3790790932105233e-07), (1, 0.00010029693904353532), (2, 1.514513143407515e-06), (3, 0.00020045512693988513), (4, 5.921706847117824e-12), (5, 0.00020005936647055013), (6, 0.007492606919845531)]
(40.0,): [(0, 0), (1, 0.00020067625715711592), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(41.0,): [(0, 0.0003065146984979159), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0.00010001980491482094), (6, 0)]
(42.0,): [(0, 0), (1, 0.0011029264320872303), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0.0001)]
(43.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0.00039999933126617975), (5, 0), (6, 0)]
(44.0,): [(0, 0), (1, 0), (2, 0), (3, 0.0), (4, 0), (5, 0), (6, 0.0)]
(45.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0.0005003449929652591), (5, 0), (6, 0)]
(46.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0.00029995962552051086)]
(47.0,): [(0, 0.0013086929853777316), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(48.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0.0005089147848249216), (5, 0), (6, 0)]
(49.0,): [(0, 0), (1, 0), (2, 0.0003000831602492754), (3, 0), (4, 0), (5, 0), (6, 0)]
(50.0,): [(0, 0.000201748139585933), (1, 0.00010008892903573365), (2, 0.00010012270524045234), (3, 0.02968803751179277), (4, 0.00021000374203084465), (5, 0.0004100313029946266), (6, 0.00029996951469281194)]
(51.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0.0003093607229629698), (5, 0), (6, 0)]
(52.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0.00010463146705253542), (5, 0), (6, 0)]
(53.0,): [(0, 0.00013934022208115043), (1, 0.0002001185046602837), (2, 0.0002052948401226299), (3, 0.027891724665681707), (4, 1.801305697954006e-06), (5, 0.0003015844501583491), (6, 0.000199989799990201)]
(54.0,): [(0, 0.00301959714952006), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0.0001), (6, 0.0001)]
(55.0,): [(0, 0.000201518959821658), (1, 0.00010333514579968351), (2, 5.488166800135822e-11), (3, 0.026391245788168996), (4, 0.00020046500647215646), (5, 1.9798020969174002e-08), (6, 0.00029997990587762456)]
(56.0,): [(0, 0.0001096067897458676), (1, 0), (2, 0), (3, 0.0011999922000286), (4, 0), (5, 0), (6, 0)]
(57.0,): [(0, 0.0002010690706982891), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(58.0,): [(0, 0.006901427128154502), (1, 0), (2, 1.0195549966700805e-06), (3, 0), (4, 0.00019999979995672415), (5, 0.0001), (6, 0.0)]
(59.0,): [(0, 0), (1, 0.0009161625838874803), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(60.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0.0), (6, 9.999005043731106e-05)]
(61.0,): [(0, 0), (1, 0), (2, 0.00020083122292570215), (3, 0), (4, 0), (5, 0), (6, 0)]
(63.0,): [(0, 0.00010026715077907572), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(64.0,): [(0, 0.00010035611024677772), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(65.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0.0), (5, 0), (6, 0)]
(66.0,): [(0, 3.964404221961604e-08), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(68.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0.0001), (5, 0), (6, 0)]
(69.0,): [(0, 0), (1, 0), (2, 0), (3, 0.0), (4, 0.00010000993941617127), (5, 0), (6, 0.00010018691984168522)]
(72.0,): [(0, 0.00020083584022336943), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(74.0,): [(0, 0.00010113839635357177), (1, 0), (2, 0), (3, 0.0015999784001519996), (4, 0), (5, 0), (6, 0)]
(75.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0.00020016673372120304)]
(78.0,): [(0, 3.0676766129829237e-07), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(79.0,): [(0, 0), (1, 0), (2, 1.9801916044305277e-08), (3, 0), (4, 0), (5, 0), (6, 0)]
(80.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0.0003999993999320951), (6, 0)]
(81.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0.00010000990098010001), (6, 0)]
(82.0,): [(0, 0.0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0.0002001775798835915)]
(83.0,): [(0, 0.00010002013340764707), (1, 0), (2, 0.0), (3, 0), (4, 0.0), (5, 0), (6, 0)]
(84.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 4.949796129459117e-08), (5, 0), (6, 0)]
(87.0,): [(0, 0), (1, 0.0), (2, 0), (3, 0), (4, 7.037842640281326e-11), (5, 0), (6, 0)]
(88.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0.00029999890607222815), (6, 0)]
(89.0,): [(0, 0), (1, 0.0002000295989802981), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(90.0,): [(0, 0), (1, 0), (2, 0), (3, 0.0), (4, 0), (5, 0), (6, 0.0001019217670339601)]
(91.0,): [(0, 0), (1, 0.0), (2, 0), (3, 0.0), (4, 0), (5, 0.00010190143813354317), (6, 0)]
(92.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 1.2623115100275525e-06), (6, 0)]
(93.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0.00010175431265480809)]
(94.0,): [(0, 2.7797285001434695e-07), (1, 1.9601990491266733e-12), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(97.0,): [(0, 0.000300009252722441), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0)]
(98.0,): [(0, 0), (1, 0.0), (2, 9.999005127860985e-05), (3, 0), (4, 0), (5, 0), (6, 0)]
(99.0,): [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 5.920577659957571e-07), (6, 0)]
(100.0,): [(0, 2.939153922439222e-06), (1, 0.0013920637451417853), (2, 3.9549982904505874e-08), (3, 1.9792952132822057e-08), (4, 1.7761533668439547e-07), (5, 1.680183255990612e-07), (6, 1.3809160805883612e-07)]
TEST: 87
[[41. 61. 43. ... 55. 57. 45.]
 [43.  9. 37. ... 73. 17. 74.]
 [53. 29. 54. ... 70. 71. 71.]
 ...
 [56. 72. 77. ... 77. 65. 74.]
 [12. 66.  0. ... 34. 26. 27.]
 [87. 84. 17. ... 76. 76. 75.]]
[0.3036406555630294, 0.1518203277815147, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 
0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202, 0.22773049167227202]
Stopping bidding_rl_delta_q_learning
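
One way to read the Q-table above is to extract the greedy action for each state. The sketch below does this for three rows copied from the output (rounded to four decimals). The helper name greedy_policy is mine, and it assumes the table is a dict of dicts, which is what print_arbitrary_policy iterates over.

```python
# Excerpt of the Q-table printed above, rounded to four decimals.
q_excerpt = {
    (23.0,): {0: 0.0, 1: 0.0040, 2: 0.0043, 3: 0.6047, 4: 0.0039, 5: 0.0048, 6: 0.0037},
    (26.0,): {0: 0.0, 1: 0.0030, 2: 0.0030, 3: 0.2695, 4: 0.0017, 5: 0.0023, 6: 0.0023},
    (28.0,): {0: 0.0, 1: 0.0007, 2: 0.0011, 3: 0.1152, 4: 0.0006, 5: 0.0005, 6: 0.0012},
}

def greedy_policy(q_table):
    """Map each state to the action with the largest Q-value."""
    return {state: max(values, key=values.get) for state, values in q_table.items()}

for state, action in sorted(greedy_policy(q_excerpt).items()):
    print(state, "->", action)
```

In all three rows action 3 dominates. The test episode's logged values settle at about 0.2277, which state_mapping turns into state 23, the row with the largest Q-value in the whole table.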

Next I run the same code again but this time with the random agent.

name = "bidding_rl_delta_random"
env_name = "StaticDisplayAdvertising-v0"
num_repeats = 10
agent = random_agent
print("Starting {}".format(name))
env = gym.make(env_name)
rewards_buffer, bid_buffer = train_agent(env, agent, num_repeats)
save(average(rewards_buffer), name + ".json")
save(average(bid_buffer), name + "_bid.json")
print("Stopping {}".format(name))
Starting bidding_rl_delta_random
TEST: 49
[[ 0.  0. 14. ...  9. 21.  1.]
 [ 0.  7. 11. ...  0.  0.  0.]
 [ 0. 13. 24. ... 14. 54.  0.]
 ...
 [27.  1.  9. ...  0.  0.  0.]
 [ 2.  0.  9. ...  3. 12.  4.]
 [24. 10.  0. ...  1.  0. 21.]]
[0.14806317863352023, 0.16286949649687227, 0.15472602167202865, 0.23208903250804297, 0.24369348413344513, 0.24369348413344513, 0.24369348413344513, 0.21932413572010062, 0.20835792893409558, 0.19794003248739078, 0.20783703411176035, 0.20783703411176035, 0.19744518240617231, 0.29616777360925844, 0.28135938492879553, 0.28135938492879553, 0.4220390773931933, 0.3798351696538739, 0.18991758482693696, 0.1804217055855901, 0.18944279086486962, 0.18944279086486962, 0.20838706995135658, 0.10419353497567829, 0.10940321172446221, 0.16410481758669332, 0.15589957670735866, 0.14030961903662278, 0.14732509998845392, 0.2209876499826809, 0.243086414980949, 0.23093209423190156, 0.24247869894349663, 0.2667265688378463, 0.4000898532567694, 0.3800853605939309, 0.39908962862362746, 0.19954481431181373, 0.2095220550274044, 0.18856984952466396, 0.16971286457219756, 0.2545692968582964, 0.280026226544126, 0.3080288491985386, 0.33883173411839257, 0.16941586705919628, 0.15247428035327665, 0.22871142052991497, 0.2401469915564107, 0.25215434113423124, 0.12607717056711562, 0.06303858528355781, 0.05988665601937992, 0.05689232321841092, 0.05973693937933147, 0.06272378634829805, 0.056451407713468245, 0.06209654848481507, 0.06209654848481507, 0.031048274242407536, 0.03260068795452791, 0.04890103193179187, 0.051346083528381464, 0.0770191252925722, 0.06931721276331498, 0.06238549148698349, 0.06238549148698349, 0.09357823723047524, 0.10293606095352277, 0.15440409143028416, 0.14668388685876993, 0.2200258302881549, 0.2200258302881549, 0.19802324725933942, 0.2079244096223064, 0.2079244096223064, 0.3118866144334596, 0.1559433072167298, 0.14034897649505684, 0.15438387414456253, 0.23157581121684379, 0.21999702065600157, 0.19799731859040143, 0.2078971845199215, 0.3118457767798823, 0.2962534879408881, 0.14812674397044406, 0.15553308116896628, 0.15553308116896628, 0.1633097352274146, 0.1796407087501561, 0.1796407087501561, 0.2694610631252341, 0.40419159468785115, 0.4446107541566363, 0.6669161312349545, 
0.73360774435845, 0.366803872179225, 0.3301234849613025, 0.29711113646517223]
Stopping bidding_rl_delta_random

Results

In the next two plots I present the sum of the rewards in an episode, over 200 episodes, averaged over 10 runs. You’d need to average over more runs to make the plots smoother.

You can see that the RL-based agent quickly learns where to position the bid amount in order to maximize the reward.

The second image shows how the bid amount changes over time.

%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd

data_files = [("Random", "bidding_rl_delta_random.json"),
              ("Q-Learning", "bidding_rl_delta_q_learning.json")]

fig, ax = plt.subplots()
for j, (name, data_file) in enumerate(data_files):
    df = pd.read_json(data_file)
    x = range(len(df))
    y = df.sort_index().values
    ax.plot(x,
            y,
            linestyle='solid',
            label=name)
ax.set_xlabel('Episode')
ax.set_ylabel('Averaged Sum of Rewards')
ax.legend(loc='lower right')
plt.show()
data_files = [("Random", "bidding_rl_delta_random_bid.json"),
              ("Q-Learning", "bidding_rl_delta_q_learning_bid.json")]

fig, ax = plt.subplots()
for j, (name, data_file) in enumerate(data_files):
    df = pd.read_json(data_file)
    x = range(len(df))
    y = df.sort_index().values
    ax.plot(x,
            y,
            linestyle='solid',
            label=name)
ax.set_xlabel('Episode')
ax.set_ylabel('Advertising bid')
ax.legend(loc='lower right')
plt.show()
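
If running more repeats is too slow, a different way to get smoother curves is a moving average over episodes. Note that this is not the same as averaging over more runs: it trades resolution along the episode axis for less noise. The sketch below uses a synthetic reward curve as a stand-in for the saved data; the function name moving_average and the window of 10 are arbitrary choices for illustration.

```python
import numpy as np

def moving_average(values, window):
    """Unweighted moving average; returns len(values) - window + 1 points."""
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

# Synthetic noisy curve standing in for the averaged rewards (not real results).
rng = np.random.default_rng(0)
rewards = np.linspace(20, 70, 200) + rng.normal(0, 10, size=200)

smoothed = moving_average(rewards, window=10)
print(len(rewards), len(smoothed))
```

To apply it to the real data, pass one column of the loaded DataFrame, for example df.sort_index().values[:, 0], and plot the result against range(len(smoothed)).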

Further Work

In the book I also include an example with real data. You need to download and process an external repository to get that data, so it’s a bit cumbersome to do here.

But the general idea is that we want a better simulation of real bids. There are many advanced ways of building such models, some of which I touch upon in other workshops. A simple solution is to assume that the state isn’t affected by your actions; then you can replay independent bid events and treat that as your simulation. Of course, this isn’t great, since you could do the same thing with pure ML. But I hope that you can see that RL is a general framework for learning online.
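
As a rough sketch of the replay idea, the class below steps through logged bid events and never lets the agent’s bid influence which event comes next. Everything here is an assumption made for illustration: the class name ReplayBiddingEnv, the event fields (features, market_price, clicked), the rule that a bid wins when it is at least the market price, and the reward of 1 for a won impression that was clicked. It is not the environment from gym-display-advertising or the example in the book.

```python
import random

class ReplayBiddingEnv:
    """Replays logged bid events; the agent's bids never change what comes next."""

    def __init__(self, events, seed=None):
        self.events = list(events)
        self.rng = random.Random(seed)
        self.order = []
        self.t = 0

    def reset(self):
        # Shuffle so each episode sees the independent events in a new order.
        self.order = self.rng.sample(self.events, len(self.events))
        self.t = 0
        return self.order[0]["features"]

    def step(self, bid):
        event = self.order[self.t]
        won = bid >= event["market_price"]  # assumed win rule
        reward = 1.0 if (won and event["clicked"]) else 0.0  # assumed reward
        self.t += 1
        done = self.t >= len(self.order)
        next_obs = None if done else self.order[self.t]["features"]
        return next_obs, reward, done, {"won": won}

# Hypothetical logged events, not real data.
events = [
    {"features": (0.1,), "market_price": 0.20, "clicked": True},
    {"features": (0.7,), "market_price": 0.50, "clicked": False},
    {"features": (0.4,), "market_price": 0.30, "clicked": True},
]
env = ReplayBiddingEnv(events, seed=0)
obs, total, done = env.reset(), 0.0, False
while not done:
    obs, reward, done, info = env.step(bid=0.35)
    total += reward
print(total)
```

Because the next observation never depends on the bid, each step is effectively an independent prediction problem, which is why this setup could also be handled with pure ML.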

TODO: In the future I will try and include some data in the package, which isn’t working right now for some reason.