After my previous RASP experiments, I wondered: just how far can we take RASP? For instance, can we make a transformer that solves a tabular Q-learning environment?
Well, let's find out.
!pip install -qqq git+https://github.com/chalk-diagrams/chalk git+https://github.com/srush/RASPy
Same helper functions as last time:
from raspy import *
from raspy.rasp import Seq, Sel, SOp, Key, Query
from raspy.visualize import draw_all, draw, draw_sel
from chalk import *
from colour import Color
from raspy.visualize import word
import random
import matplotlib.pyplot as plt
def draw(c_inp=Color("white"), c_att=Color("white"), c_back=Color("white"), c_ffn=Color("white")):
    d = box("Input", c_inp).named("inp") / vstrut(1) / (rectangle(3, 4).fill_color(c_back).named("main") + (box("Feed Forward", c_ffn).named("ffn") / vstrut(1) / box("Attention", c_att).named("att")).center_xy()) / vstrut(1) / box("Final").named("final")
    return d.connect_outside("inp", "main").connect_outside("ffn", "att").connect_outside("main", "final")

def draw_att():
    d = rectangle(2.5, 2.5)
    d = d.beside(box2("key", green).rotate_by(0.25).named("key"), -unit_x)
    d = d.beside(box2("query", orange).named("query"), -unit_y)
    d = d.beside(box2("value", red).rotate_by(-0.25).named("value"), unit_x)
    d = d.beside(box2("output").named("output"), unit_y)
    d = d + rectangle(0.4, 0.4).fill_color(black).named("mid").translate(0, -0.5)
    d = d + rectangle(0.4, 0.4).fill_color(black).named("mid2").translate(0, 0.5)
    d = d.connect_perim("key", "mid", unit_x, -unit_x).connect_outside("query", "mid").connect_outside("mid", "mid2").connect_perim("value", "mid2", -unit_x, unit_x).connect_outside("mid2", "output")
    return d
before = key(indices) < query(indices)
def atoi(seq=tokens):
    return seq.map(lambda x: ord(x) - ord('0'))

def itatoi(seq=tokens):
    return seq.map(lambda x: int(x))

def ititoa(seq=tokens):
    return seq.map(lambda x: chr(x))
def cumsum(seq=tokens):
    x = (before | (key(indices) == query(indices))).value(seq)
    return x.name("cumsum")

def index(i, seq=tokens):
    x = (key(indices) == query(i)).value(seq)
    return x.name("index")

def shift(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices - i)).value(seq, default)
    return x.name("shift")

def shiftl(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices + i)).value(seq, default)
    return x.name("shiftl")

def shift_to_one(default="_", seq=tokens):
    x = (key(indices) == query(indices + 35)).value(seq, default)
    return x.name("shift_to_one")

def shiftls(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices + i - 1)).value(seq, default)
    return x.name("shiftls")

def shake_shift(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices - i + 3)).value(seq, default)
    x = (key(indices) == query(indices + i - 3)).value(x, default)
    return x.name("shakeshift")

def lfsr_shift(seq=tokens):
    x = (key(indices) == query(indices + 10)).value(seq, 0)
    return x.name("lfsr_shift")

def minimum(seq=tokens):
    sel1 = before & (key(seq) == query(seq))
    sel2 = key(seq) < query(seq)
    less = (sel1 | sel2).value(1)
    x = (key(less) == query(0)).value(seq)
    return x.name("min")
def first(q, seq=tokens):
    return minimum(where(seq == q, indices, 99))
def ralign(default="-", sop=tokens):
    c = (key(sop) == query("_")).value(1)
    x = (key(indices + c) == query(indices)).value(sop, default)
    return x.name("ralign")

def split(v, i, sop=tokens, default="0"):
    split_point = (key(sop) == query(v)).value(indices)
    if i:
        x = ralign(default, where(indices < split_point, sop, "_"))
        return x
    else:
        x = where(indices > split_point, sop, default)
        return x

def slide(match, seq=tokens):
    x = cumsum(match)
    y = ((key(x) == query(x + 1)) & (key(match) == query(True))).value(seq)
    seq = where(match, seq, y)
    return seq.name("slide")

def add(sop=tokens):
    # 0) Parse and add
    x = atoi(split("+", True, sop)) \
        + atoi(split("+", False, sop))

    # 1) Check for carries
    gets_carry = shift(-1, "0", where(x > 9, "1", where(x == 9, "<", "0")))

    # 2) In parallel, slide carries to their column
    gets_carry = atoi(slide(gets_carry != "<", gets_carry))

    # 3) Add in carries.
    return (x + gets_carry) % 10

def get_length(sop=tokens):
    length = (key(1) == query(1)).value(1)
    return length.name("length")

def indexof(v, sop=tokens):
    length = get_length(sop)
    replaced = where((sop == v), indices, length) # Replace everything but the token with the length, and the token with its index
    return minimum(replaced) # the minimum value is the first index!
Ok, let's define the actual game! The setup: a 2x3 world where the agent starts in the bottom-left corner, the goal is in the bottom-right corner, and the bottom-middle tile is an impassable wall. It looks something like this:
environment = [[0, 0, 0], [2, 1, 3]]
plt.imshow(environment)
[matplotlib imshow plot of the 2x3 environment]
Where the light green is the agent, the blue is the rock, and the yellow is the goal.
What do we need in order to do Q-learning on this? Well, first of all, we need a Q-table: that is, a table which represents our estimate of how good every action is given a certain state.
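Before building this out of attention operations, it helps to have the plain version in front of us. Here's a minimal sketch of ordinary tabular Q-learning in regular Python (the names n_states, choose_action, and q_update are mine, purely for illustration); the RASP construction later in this post does the same bookkeeping, just without loops:

import random

n_states, n_actions = 6, 4                  # one Q-table row per tile, four moves per tile
q_table = [[0.0] * n_actions for _ in range(n_states)]

def choose_action(state, epsilon=0.125):
    # Epsilon-greedy: mostly exploit the table, occasionally explore at random
    if random.random() < epsilon:
        return random.randrange(n_actions)
    return max(range(n_actions), key=lambda a: q_table[state][a])

def q_update(state, action, reward, next_state, lr=1.0, gamma=1.0):
    # Standard Bellman update for tabular Q-learning
    target = reward + gamma * max(q_table[next_state])
    q_table[state][action] += lr * (target - q_table[state][action])

The transformer version will keep its table inside one long token sequence rather than a 2D array (and, as we'll see, store the values with flipped sign), which is where most of the contortions below come from.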
So, let's create a string representing the entire game state:
= "000210400004000040000400004000040000" initial_game_state
This string might look a bit confusing at first glance. In reality, we just split it up using the 4 as a delimiter, which gives us the following sections: the first chunk, "000210", is the flattened game state (the 2 marks the agent and the 1 marks the wall), and each of the six remaining "0000" chunks is one row of the Q-table, holding the value estimates for the four actions in that state (all initialized to zero).
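To make the layout concrete, here's a quick look at the pieces in plain Python (ordinary string splitting, nothing RASP about it):

initial_game_state.split("4")
# ['000210', '0000', '0000', '0000', '0000', '0000', '0000']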
So that's it! Game state plus Q table. We're ready to go.
We're going to be implementing the simplest possible tabular Q-learning algorithm, using epsilon-greedy...oh wait a second, how on earth are we going to get random numbers inside a transformer???
Well, of course we could just generate a random string of numbers in Python and feed it in after the game state, so the transformer has a source of randomness. On the other hand, we could implement a decent pseudorandom number generator ourselves, like the Mersenne Twister. A good compromise between coolness and implementation difficulty is the linear-feedback shift register (LFSR): basically, you start with a register of bits (sixteen of them, say) that are each 0 or 1, shift them all one position to the right, and XOR a couple of "tap" positions together to re-seed the newly empty zeroth index. This doesn't actually give you random numbers, but the sequence of bits has a long enough cycle that it's random enough for our purposes. Ok, let's do it!
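Before wrestling it into RASP, here's the mechanism as a small plain-Python sketch, using the same 11-bit register with taps at positions 11 and 9 that we'll build below (lfsr_step is just an illustrative name):

def lfsr_step(bits):
    # XOR the two taps (positions 11 and 9, i.e. indices 10 and 8)...
    new_bit = bits[10] ^ bits[8]
    # ...then shift everything one position to the right and drop the new bit into index 0
    return [new_bit] + bits[:-1]

state = [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1]   # same seed we'll use later
bits_out = []
for _ in range(20):
    state = lfsr_step(state)
    bits_out.append(state[-1])               # read off one output bit per tick
print(''.join(str(b) for b in bits_out))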
The first thing we'll need to do is implement XOR:
def xor(comp, sop=tokens):
    result = where((sop == 7), sop, sop)
    sum = comp + result
    xor_result = sum % 2
    return xor_result.name("xor")
And now we're ready to implement the linear-feedback shift register:
def lfsr(sop=tokens):
    """
    Length-11 Linear-feedback shift register, with taps at positions 11 and 9.
    The period of the cycle is 2047, which is good enough for our purposes!
    """
    seed = itatoi(sop)

    right_shifted = shift(1, 0, seed)

    tap_11 = index(10, seed)
    tap_9 = index(8, seed)

    xored = xor(tap_11, tap_9)

    xor_result_only_in_first_index = lfsr_shift(xored)

    new_seed = xor_result_only_in_first_index + right_shifted

    return new_seed.name("lfsr")
Let's test it:
= "00010100001"
seed lfsr()(seed)
= ""
rng_bits for _ in range(100):
= ''.join(str(i) for i in lfsr()(seed).val)
seed += seed[9]
rng_bits
rng_bits
0001010001001000101011010100001100001001111001011100111001011110111001001010111011000010101110010000
Looks pretty random to me!!
Ok, now we're going to implement a few helper functions. They're not particularly interesting, but they'll be useful for our main game "loop" (hahaha, there are no loops in RASP!)
First things first, we're gonna have a simple code representing the game state:
def get_state_id(sop=tokens):
    """
    Assumes a string of the full game state
    """
    # In the flattened game state, the state id is just the index of the 2 (the agent)
    state_id = indexof(2, sop)

    return state_id

get_state_id()([0, 0, 0, 0, 2, 0, 0, 0, 0, 0])
def c_and(other_bit, sop=tokens):
    """
    Improved `and` calculation
    """
    sum = other_bit + sop
    x = where((sum == 2), sum, 1) - 1

    return x
Two functions that add a certain value to a certain index of a sequence (don't ask why we need two)
def add_to_index(index, value, sop=tokens):
    value_to_add = shift(get_length(sop), value, sop)

    value_to_add_in_correct_place = itatoi(shift(index, 0, shiftl(get_length(sop) - 1, 0, value_to_add)))

    return value_to_add_in_correct_place + sop

def ati2(index, value, sop=tokens):
    value_to_add = where(indices >= 0, value, sop)

    value_to_add_in_correct_place = itatoi(shift(index, 0, shiftl(get_length(sop) - 1, 0, value_to_add)))

    return value_to_add_in_correct_place + sop

ati2(3, 10)([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
def update_game_state(new_state, sop=tokens):
    # We want to find the current location of the agent (the 2), and set it to zero
    # Then, we want to set the value of the updated agent location index to 2
    current_agent_location = indexof(2, sop)
    x = add_to_index(new_state, 2, sop)
    x = add_to_index(current_agent_location, -2, x)

    return x

import sys
sys.setrecursionlimit(10000)
And with all that out of the way, we're ready to implement our Q-learning RL agent!!! First, let's define the "transition matrix" and the "reward function":
= "00130121122530333333"
transitions = "11111111111111110000" rewards
In the above, transitions[state * 4 + action] tells us which state we end up in, given the current state and the action we take, and rewards[state * 4 + action] similarly tells us what reward we get.
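To make the indexing concrete, here's the same lookup done in plain Python on the strings above (just a sanity check, not part of the transformer):

state, action = 0, 3                                   # top-left corner, fourth action
print(int(transitions[state * 4 + action]))            # next state: 3, i.e. the (1, 0) tile
print(int(rewards[state * 4 + action]))                # reward: 1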
And finally, the gorgeous joint environment update/training loop:
def step(next_states, rewards, sop=tokens):
    # Split up the game state into a sequence of ints, rather than characters
    x = atoi(sop)
    # x is now [0, 0, 0, 2, 1, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 5, 4, 1, 4, 1, 4]
    # However, actually using 4s as delimiters is a bit difficult--it would be
    # way nicer if each delimiter was unique, so we could easily access the
    # relevant part of the game state. So, let's add one million to each delimiter.
    # We don't want to add a super small number, lest it get confused with state
    # value estimates (normally, they would be negative, but we store them
    # using positive values here due to what I initially thought was a limitation
    # of RASPy but later turned out to be false, and is now an aesthetic choice.)
    divider_addition = [0, 0, 0, 0, 0, 0, 1000000, 0, 0, 0, 0, 1000001, 0, 0, 0, 0, 1000002, 0, 0, 0, 0, 1000003, 0, 0, 0, 0, 1000004, 0, 0, 0, 0, 1000005, 0, 0, 0, 0, 1000006, 0, 1000007, 0, 1000008, 0, 1000009]

    full_game_state = itatoi(x + divider_addition)
    # The full game state is now the eyewatering [0, 0, 0, 2, 1, 0, 1000004, 0, 0, 0, 0, 1000005, 0, 0, 0, 0, 1000006, 0, 0, 0, 0, 1000007, 0, 0, 0, 0, 1000008, 0, 0, 0, 0, 1000009, 0, 0, 0, 0, 1000010, 5, 1000011, 1, 1000012, 1, 1000013]

    # Ok, now it's time to initialize our RNG.
    seed = shift(-100, 0, full_game_state)
    seed_init = [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1]
    seed = seed + seed_init

    # Time: 22s
    # This is not a real loop, in the sense that there's no loop logic involved.
    # This just saves me the trouble of copying and pasting the layer definitions
    # many many times in a row, which is how I would've had to have done it otherwise.
    # for i in range(1):
    # First, let's get the current game state.
    # 0 is (0, 0)
    # 1 is (0, 1)
    # 2 is (0, 2)
    # 3 is (1, 0)
    # 5 is (1, 2)
    state_id = get_state_id(full_game_state)

    # Decide whether to take a random action or not
    # To do this, read off three bits off our LFSR and see if they're all 1
    # which gets us an epsilon of 0.125
    seed = lfsr(seed)
    rng1 = where(index(15, seed) == 1, 1, 0)
    seed = lfsr(seed)
    rng2 = where(index(15, seed) == 1, 1, 0)
    seed = lfsr(seed)
    rng3 = where(index(15, seed) == 1, 1, 0)

    rng_less_than_epsilon = c_and(rng1, c_and(rng2, rng3))

    # Figure out which action to take
    # First, generate a random action (we always do this, even if we don't end up taking it)
    seed = lfsr(seed)
    rng4 = where(index(15, seed) == 1, 1, 0)
    seed = lfsr(seed)
    rng5 = where(index(15, seed) == 1, 1, 0)
    random_action = rng4 * 2 + rng5 + 2 # [0/1] * 2 + [0/1] is uniform from {0, 1, 2, 3}

    # Time: 7m
    # Next, get the best action
    # First, extract the relevant section of the q-table
    state_lookup_index = indexof(state_id + 1000004, full_game_state) # this is why we did all of that delimiter nonsense
    action_0_value = index(state_lookup_index + 1, full_game_state)
    action_1_value = index(state_lookup_index + 2, full_game_state)
    action_2_value = index(state_lookup_index + 3, full_game_state)
    action_3_value = index(state_lookup_index + 4, full_game_state)

    # Then, get the best action to take (in this case, the one with minimum value)
    shifted_a0v = shift(-42, 999999999, action_0_value) # [val(a0), inf, inf, inf, ...]
    shifted_a1v = shift(1, 999999999, shift(-42, 999999999, action_1_value)) # [inf, val(a1), inf, inf, ...]
    shifted_a2v = shift(2, 999999999, shift(-42, 999999999, action_2_value)) # [inf, inf, val(a2), inf, ...]
    shifted_a3v = shift(3, 999999999, shift(-42, 999999999, action_3_value)) # [inf, inf, inf, val(a3), ...]

    action_sum = shifted_a0v + shifted_a1v + shifted_a2v + shifted_a3v
    # [3*inf + val(a0), 3*inf + val(a1), 3*inf + val(a2), 3*inf + val(a3), 4*inf, 4*inf, 4*inf, ...]
    # So the index of the minimum is the best action
    best_action = indexof(minimum(action_sum), action_sum) + 1

    shifted_best_action = shift(-35, 0, best_action) # [B, 0, 0, 0, ...]
    shifted_random_action = shift(1, 0, shift(-35, 0, random_action)) # [0, R, 0, 0, ...]

    combined_actions = shifted_best_action + shifted_random_action
    # [B, R, 0, 0, 0, ...]
    action = index(rng_less_than_epsilon, combined_actions)
    # Take the zeroth index if rng is not less than epsilon (e.g. best action B)
    # Take the first index if rng is less than epsilon (e.g. random action R)

    # Lookup index for next state and reward as described earlier
    lookup_index = state_id * 4 + action

    next_state = atoi(index(lookup_index, next_states))
    reward = atoi(index(lookup_index, rewards))

    # Get the best value of the next action
    next_state_lookup_index = indexof(next_state + 1000004, full_game_state)
    best_value_of_next_action = minimum(shift(-32, 999999, shift(32, 999999, shiftl(next_state_lookup_index + 1, 999999, full_game_state))))

    bellman_update = reward - best_value_of_next_action # assuming learning rate and gamma are both 1, lmao
    # also we store rewards as negative rewards for "simplicity"

    # Update the Q table
    full_game_state = ati2(state_lookup_index + action + 1, bellman_update, full_game_state)

    # Update the environment
    full_game_state = update_game_state(next_state, full_game_state)

    return full_game_state
step(transitions, rewards)(initial_game_state)