GridWorld MDP Tutorial

In this tutorial, we provide a simple example of how to define a Markov decision process (MDP) using the POMDPS.jl interface. We will then solve the MDP using value iteration and Monte Carlo tree search (MCTS). We will walk through constructing the MDP using the explicit interface which invovles defining a new type for the MDP and then extending different components of the POMDPs.jl interface for that type.

Dependencies

We need a few modules in order to run this example. All of the models can be added by running the following command in the Julia REPL:

using Pkg

Pkg.add("POMDPs")
Pkg.add("POMDPTools")
Pkg.add("DiscreteValueIteration")
Pkg.add("MCTS")

If you already had the models installed, it is prudent to update them to the latest version:

Pkg.update()

Now that we have the models installed, we can load them into our workspace:

using POMDPs
using POMDPTools
using DiscreteValueIteration
using MCTS

Problem Overview

In Grid World, we are trying to control an agent who has trouble moving in the desired direction. In our problem, we have four reward states within the a grid. Each position on the grid represents a state, and the positive reward states are terminal (the agent stops recieving reward after reaching them and performing an action from that state). The agent has four actions to choose from: up, down, left, right. The agent moves in the desired direction with a probability of $0.7$, and with a probability of $0.1$ in each of the remaining three directions. If the agent bumps into the outside wall, there is a penalty of $1$ (i.e. reward of $-1$). The problem has the following form:

Grid World

Defining the Grid World MDP Type

In POMDPs.jl, an MDP is defined by creating a subtype of the MDP abstract type. The types of the states and actions for the MDP are declared as parameters of the MDP type. For example, if our states and actions are both represented by integers, we can define our MDP type as follows:

struct MyMDP <: MDP{Int64, Int64} # MDP{StateType, ActionType}
    # fields go here
end

In our grid world problem, we will represent the states using a custom type that designates the x and y coordinate within the grid. The actions will by represented by a symbol.

GridWorldState

There are numerous ways to represent the state of the agent in a grid world. We will use a custom type that designates the x and y coordinate within the grid.

struct GridWorldState
    x::Int64
    y::Int64
end

To help us later, let's extend the == for our GridWorldStat:

function Base.:(==)(s1::GridWorldState, s2::GridWorldState)
    return s1.x == s2.x && s1.y == s2.y
end

GridWorld Actions

Since our action is the direction the agent chooses to go (i.e. up, down, left, right), we can use a Symbol to represent it. Note that in this case, we are not defining a custom type for our action, instead we represent it directly with a symbol. Our actions will be :up, :down, :left, and :right.

GridWorldMDP

Now that we have defined our types for states and actions, we can define our MDP type. We will call it GridWorldMDP and it will be a subtype of MDP{GridWorldState, Symbol}.

struct GridWorldMDP <: MDP{GridWorldState, Symbol}
    size_x::Int64 # x size of the grid
    size_y::Int64 # y size of the grid
    reward_states_values::Dict{GridWorldState, Float64} # Dictionary mapping reward states to their values
    hit_wall_reward::Float64 # reward for hitting a wall
    tprob::Float64 # probability of transitioning to the desired state
    discount_factor::Float64 # disocunt factor
end

We can define a constructor for our GridWorldMDP to make it easier to create instances of our MDP.

function GridWorldMDP(;
    size_x::Int64=10,
    size_y::Int64=10,
    reward_states_values::Dict{GridWorldState, Float64}=Dict(
        GridWorldState(4, 3) => -10.0,
        GridWorldState(4, 6) => -5.0,
        GridWorldState(9, 3) => 10.0,
        GridWorldState(8, 8) => 3.0),
    hit_wall_reward::Float64=-1.0,
    tprob::Float64=0.7,
    discount_factor::Float64=0.9)
    return GridWorldMDP(size_x, size_y, reward_states_values, hit_wall_reward, tprob, discount_factor)
end
Main.GridWorldMDP

To help us visualize our MDP, we can extend show for our GridWorldMDP type:

function Base.show(io::IO, mdp::GridWorldMDP)
    println(io, "Grid World MDP")
    println(io, "\tSize x: $(mdp.size_x)")
    println(io, "\tSize y: $(mdp.size_y)")
    println(io, "\tReward states:")
    for (key, value) in mdp.reward_states_values
        println(io, "\t\t$key => $value")
    end
    println(io, "\tHit wall reward: $(mdp.hit_wall_reward)")
    println(io, "\tTransition probability: $(mdp.tprob)")
    println(io, "\tDiscount: $(mdp.discount_factor)")
end

Now lets create an instance of our GridWorldMDP:

mdp = GridWorldMDP()
Grid World MDP
	Size x: 10
	Size y: 10
	Reward states:
		Main.GridWorldState(4, 6) => -5.0
		Main.GridWorldState(8, 8) => 3.0
		Main.GridWorldState(4, 3) => -10.0
		Main.GridWorldState(9, 3) => 10.0
	Hit wall reward: -1.0
	Transition probability: 0.7
	Discount: 0.9
Note

In this definition of the problem, our coordiates start in the bottom left of the grid. That is GridState(1, 1) is the bottom left of the grid and GridState(10, 10) would be on the right of the grid with a grid size of 10 by 10.

Grid World State Space

The state space in an MDP represents all the states in the problem. There are two primary functionalities that we want our spaces to support. We want to be able to iterate over the state space (for Value Iteration for example), and sometimes we want to be able to sample form the state space (used in some POMDP solvers). In this notebook, we will only look at iterable state spaces.

Since we can iterate over elements of an array, and our problem is small, we can store all of our states in an array. We also have a terminal state based on the definition of our problem. We can represent that as a location outside of the grid (i.e. (-1, -1)).

function POMDPs.states(mdp::GridWorldMDP)
    states_array = GridWorldState[]
    for x in 1:mdp.size_x
        for y in 1:mdp.size_y
            push!(states_array, GridWorldState(x, y))
        end
    end
    push!(states_array, GridWorldState(-1, -1)) # Adding the terminal state
    return states_array
end

Let's view some of the states in our state space:

@show states(mdp)[1:5]
5-element Vector{Main.GridWorldState}:
 Main.GridWorldState(1, 1)
 Main.GridWorldState(1, 2)
 Main.GridWorldState(1, 3)
 Main.GridWorldState(1, 4)
 Main.GridWorldState(1, 5)

We also need a other functions related to the state space.

# Check if a state is the terminal state
POMDPs.isterminal(mdp::GridWorldMDP, s::GridWorldState) = s == GridWorldState(-1, -1)

# Define the initial state distribution (always start in the bottom left)
POMDPs.initialstate(mdp::GridWorldMDP) = Deterministic(GridWorldState(1, 1))

# Function that returns the index of a state in the state space
function POMDPs.stateindex(mdp::GridWorldMDP, s::GridWorldState)
    if isterminal(mdp, s)
        return length(states(mdp))
    end

    @assert 1 <= s.x <= mdp.size_x "Invalid state"
    @assert 1 <= s.y <= mdp.size_y "Invalid state"

    si = (s.x - 1) * mdp.size_y + s.y
    return si
end

Large State Spaces

If your problem is very large we probably do not want to store all of our states in an array. We can create an iterator using indexing functions to help us out. One way of doing this is to define a function that returns a state from an index and then construct an iterator. This is an example of how we can do that for the Grid World problem.

Note

If you run this section, you will redefine the states(::GridWorldMDP) that we just defined in the previous section.

    # Define the length of the state space, number of grid locations plus the terminal state
    Base.length(mdp::GridWorldMDP) = mdp.size_x * mdp.size_y + 1

    # `states` now returns the mdp, which we will construct our iterator from
    POMDPs.states(mdp::GridWorldMDP) = mdp

    function Base.getindex(mdp::GridWorldMDP, si::Int) # Enables mdp[si]
        @assert si <= length(mdp) "Index out of bounds"
        @assert si > 0  "Index out of bounds"

        # First check if we are in the terminal state (which we define as the last state)
        if si == length(mdp)
            return GridWorldState(-1, -1)
        end

        # Otherwise, we need to calculate the x and y coordinates
        y = (si - 1) % mdp.size_y + 1
        x = div((si - 1), mdp.size_y) + 1
        return GridWorldState(x, y)
    end

    function Base.getindex(mdp::GridWorldMDP, si_range::UnitRange{Int}) # Enables mdp[1:5]
        return [getindex(mdp, si) for si in si_range]
    end

    Base.firstindex(mdp::GridWorldMDP) = 1 # Enables mdp[begin]
    Base.lastindex(mdp::GridWorldMDP) = length(mdp) # Enables mdp[end]

    # We can now construct an iterator
    function Base.iterate(mdp::GridWorldMDP, ii::Int=1)
        if ii > length(mdp)
            return nothing
        end
        s = getindex(mdp, ii)
        return (s, ii + 1)
    end

Similar to above, let's iterate over a few of the states in our state space:

@show states(mdp)[1:5]
@show mdp[begin]
@show mdp[end]
Main.GridWorldState(-1, -1)

Grid World Action Space

The action space is the set of all actions availiable to the agent. In the grid world problem the action space consists of up, down, left, and right. We can define the action space by implementing a new method of the actions function.

POMDPs.actions(mdp::GridWorldMDP) = [:up, :down, :left, :right]

Similar to the state space, we need a function that returns an index given an action.

function POMDPs.actionindex(mdp::GridWorldMDP, a::Symbol)
    @assert in(a, actions(mdp)) "Invalid action"
    return findfirst(x -> x == a, actions(mdp))
end

Grid World Transition Function

MDPs often define the transition function as $T(s^{\prime} \mid s, a)$, which is the probability of transitioning to state $s^{\prime}$ given that we are in state $s$ and take action $a$. For the POMDPs.jl interface, we define the transition function as a distribution over the next states. That is, we want $T(\cdot \mid s, a)$ which is a function that takes in a state and an action and returns a distribution over the next states.

For our grid world example, there are only a few states to which the agent can transition and thus only a few states with nonzero probaility in $T(\cdot \mid s, a)$. We can use the SparseCat distribution to represent this. The SparseCat distribution is a categorical distribution that only stores the nonzero probabilities. We can define our transition function as follows:

function POMDPs.transition(mdp::GridWorldMDP, s::GridWorldState, a::Symbol)
    # If we are in the terminal state, we stay in the terminal state
    if isterminal(mdp, s)
        return SparseCat([s], [1.0])
    end

    # If we are in a positive reward state, we transition to the terminal state
    if s in keys(mdp.reward_states_values) && mdp.reward_states_values[s] > 0
        return SparseCat([GridWorldState(-1, -1)], [1.0])
    end

    # Probability of going in a direction other than the desired direction
    tprob_other = (1 - mdp.tprob) / 3

    new_state_up = GridWorldState(s.x, min(s.y + 1, mdp.size_y))
    new_state_down = GridWorldState(s.x, max(s.y - 1, 1))
    new_state_left = GridWorldState(max(s.x - 1, 1), s.y)
    new_state_right = GridWorldState(min(s.x + 1, mdp.size_x), s.y)

    new_state_vector = [new_state_up, new_state_down, new_state_left, new_state_right]
    t_prob_vector = fill(tprob_other, 4)

    if a == :up
        t_prob_vector[1] = mdp.tprob
    elseif a == :down
        t_prob_vector[2] = mdp.tprob
    elseif a == :left
        t_prob_vector[3] = mdp.tprob
    elseif a == :right
        t_prob_vector[4] = mdp.tprob
    else
        error("Invalid action")
    end

    # Combine probabilities for states that are the same
    for i in 1:4
        for j in (i + 1):4
            if new_state_vector[i] == new_state_vector[j]
                t_prob_vector[i] += t_prob_vector[j]
                t_prob_vector[j] = 0.0
            end
        end
    end

    # Remove states with zero probability
    new_state_vector = new_state_vector[t_prob_vector .> 0]
    t_prob_vector = t_prob_vector[t_prob_vector .> 0]

    return SparseCat(new_state_vector, t_prob_vector)
end

Let's examline a few transitions:

@show transition(mdp, GridWorldState(1, 1), :up)
                                                                         SparseCat distribution           
                                                               ┌                                        ┐ 
   Main.GridWorldState(1, 2) ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.7   
   Main.GridWorldState(1, 1) ■■■■■■■■■■ 0.2                            
   Main.GridWorldState(2, 1) ■■■■■ 0.1                                 
                                                               └                                        ┘ 
@show transition(mdp, GridWorldState(1, 1), :left)
                                                                         SparseCat distribution           
                                                               ┌                                        ┐ 
   Main.GridWorldState(1, 2) ■■■■ 0.1                                  
   Main.GridWorldState(1, 1) ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.8   
   Main.GridWorldState(2, 1) ■■■■ 0.1                                  
                                                               └                                        ┘ 
@show transition(mdp, GridWorldState(9, 3), :right)
                                                                           SparseCat distribution           
                                                                 ┌                                        ┐ 
   Main.GridWorldState(-1, -1) ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 1   
                                                                 └                                        ┘ 
@show transition(mdp, GridWorldState(-1, -1), :down)
                                                                           SparseCat distribution           
                                                                 ┌                                        ┐ 
   Main.GridWorldState(-1, -1) ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 1   
                                                                 └                                        ┘ 

Grid World Reward Function

In our problem, we have a reward function that depends on the next state as well (i.e. if we hit a wall, we stay in the same state and get a reward of $-1$). We can still construct a reward function that only depends on the current state and action by using expectation over the next state. That is, we can define our reward function as $R(s, a) = \mathbb{E}_{s^{\prime} \sim T(\cdot \mid s, a)}[R(s, a, s^{\prime})]$.

# First, let's define the reward function given the state, action, and next state
function POMDPs.reward(mdp::GridWorldMDP, s::GridWorldState, a::Symbol, sp::GridWorldState)
    # If we are in the terminal state, we get a reward of 0
    if isterminal(mdp, s)
        return 0.0
    end

    # If we are in a positive reward state, we get the reward of that state
    # For a positive reward, we transition to the terminal state, so we don't have
    # to worry about the next state (i.g. hitting a wall)
    if s in keys(mdp.reward_states_values) && mdp.reward_states_values[s] > 0
        return mdp.reward_states_values[s]
    end

    # If we are in a negative reward state, we get the reward of that state
    # If the negative reward state is on the edge of the grid, we can also be in this state
    # and hit a wall, so we need to check for that
    r = 0.0
    if s in keys(mdp.reward_states_values) && mdp.reward_states_values[s] < 0
        r += mdp.reward_states_values[s]
    end

    # If we hit a wall, we get a reward of -1
    if s == sp
        r += mdp.hit_wall_reward
    end

    return r
end

# Now we can define the reward function given the state and action
function POMDPs.reward(mdp::GridWorldMDP, s::GridWorldState, a::Symbol)
    r = 0.0
    for (sp, p) in transition(mdp, s, a)
        r += p * reward(mdp, s, a, sp)
    end
    return r
end

Let's examine a few rewards:

@show reward(mdp, GridWorldState(1, 1), :up)
-0.20000000000000004
@show reward(mdp, GridWorldState(1, 1), :left)
-0.7999999999999999
@show reward(mdp, GridWorldState(9, 3), :right)
10.0
@show reward(mdp, GridWorldState(-1, -1), :down)
0.0
@show reward(mdp, GridWorldState(2, 3), :up)
0.0

Grid World Remaining Functions

We are almost done! We still need to define discount. Let's first use POMDPLinter to check if we have defined all the functions we need for DiscreteValueIteration:

using POMDPLinter

@show_requirements POMDPs.solve(ValueIterationSolver(), mdp)
false

As we expected, we need to define discount.

function POMDPs.discount(mdp::GridWorldMDP)
    return mdp.discount_factor
end

Let's check again:

@show_requirements POMDPs.solve(ValueIterationSolver(), mdp)
true

Solving the Grid World MDP (Value Iteration)

Now that we have defined our MDP, we can solve it using Value Iteration. We will use the ValueIterationSolver from the DiscreteValueIteration package. First, we construct the a Solver type which contains the solver parameters. Then we call POMDPs.solve to solve the MDP and return a policy.

# Initialize the problem (we have already done this, but just calling it again for completeness in the example)
mdp = GridWorldMDP()

# Initialize the solver with desired parameters
solver = ValueIterationSolver(; max_iterations=100, belres=1e-3, verbose=true)

# Solve for an optimal policy
vi_policy = POMDPs.solve(solver, mdp)
[Iteration 1   ] residual:         10 | iteration runtime:      0.218 ms, (  0.000218 s total)
[Iteration 2   ] residual:        6.3 | iteration runtime:      0.241 ms, (  0.000459 s total)
[Iteration 3   ] residual:       4.53 | iteration runtime:      0.236 ms, (  0.000694 s total)
[Iteration 4   ] residual:       3.21 | iteration runtime:      0.220 ms, (  0.000915 s total)
[Iteration 5   ] residual:       2.31 | iteration runtime:      0.223 ms, (   0.00114 s total)
[Iteration 6   ] residual:       1.62 | iteration runtime:      0.215 ms, (   0.00135 s total)
[Iteration 7   ] residual:       1.24 | iteration runtime:      0.222 ms, (   0.00157 s total)
[Iteration 8   ] residual:       1.06 | iteration runtime:      0.232 ms, (   0.00181 s total)
[Iteration 9   ] residual:      0.865 | iteration runtime:      0.227 ms, (   0.00203 s total)
[Iteration 10  ] residual:      0.657 | iteration runtime:      0.215 ms, (   0.00225 s total)
[Iteration 11  ] residual:      0.545 | iteration runtime:      0.220 ms, (   0.00247 s total)
[Iteration 12  ] residual:      0.455 | iteration runtime:      0.215 ms, (   0.00268 s total)
[Iteration 13  ] residual:      0.378 | iteration runtime:      0.214 ms, (    0.0029 s total)
[Iteration 14  ] residual:      0.306 | iteration runtime:      0.218 ms, (   0.00312 s total)
[Iteration 15  ] residual:      0.211 | iteration runtime:      0.229 ms, (   0.00335 s total)
[Iteration 16  ] residual:      0.132 | iteration runtime:      0.214 ms, (   0.00356 s total)
[Iteration 17  ] residual:     0.0778 | iteration runtime:      0.216 ms, (   0.00378 s total)
[Iteration 18  ] residual:     0.0437 | iteration runtime:      0.217 ms, (   0.00399 s total)
[Iteration 19  ] residual:     0.0237 | iteration runtime:      0.217 ms, (   0.00421 s total)
[Iteration 20  ] residual:     0.0125 | iteration runtime:      0.210 ms, (   0.00442 s total)
[Iteration 21  ] residual:    0.00649 | iteration runtime:      0.215 ms, (   0.00463 s total)
[Iteration 22  ] residual:    0.00332 | iteration runtime:      0.213 ms, (   0.00485 s total)
[Iteration 23  ] residual:    0.00167 | iteration runtime:      0.213 ms, (   0.00506 s total)
[Iteration 24  ] residual:   0.000834 | iteration runtime:      0.216 ms, (   0.00528 s total)

We can now use the policy to compute the optimal action for a given state:

s = GridWorldState(9, 2)
@show action(vi_policy, s)
:up
s = GridWorldState(8, 3)
@show action(vi_policy, s)
:right

Solving the Grid World MDP (MCTS)

Similar to the process with Value Iteration, we can solve the MDP using MCTS. We will use the MCTSSolver from the MCTS package.

# Initialize the problem (we have already done this, but just calling it again for completeness in the example)
mdp = GridWorldMDP()

# Initialize the solver with desired parameters
solver = MCTSSolver(n_iterations=1000, depth=20, exploration_constant=10.0)

# Now we construct a planner by calling POMDPs.solve. For online planners, the computation for the
# optimal action occurs in the call to `action`.
mcts_planner = POMDPs.solve(solver, mdp)

Similar to the value iteration policy, we can use the policy to compute the action for a given state:

s = GridWorldState(9, 2)
@show action(mcts_planner, s)
:up
s = GridWorldState(8, 3)
@show action(mcts_planner, s)
:right

Visualizing the Value Iteration Policy

We can visualize the value iteration policy by plotting the value function and the policy. We can use numerous plotting packages to do this, but we will use UnicodePlots for this example.

using UnicodePlots
using Printf

Value Function as a Heatmap

We can plot the value function as a heatmap. The value function is a function over the state space, so we need to iterate over the state space and store the value at each state. We can use the value function to evaluate the value function at a given state.

# Initialize the value function array
value_function = zeros(mdp.size_y, mdp.size_x)

# Iterate over the state space and store the value at each state
for s in states(mdp)
    if isterminal(mdp, s)
        continue
    end
    value_function[s.y, s.x] = value(vi_policy, s)
end

# Plot the value function
heatmap(value_function;
    title="GridWorld VI Value Function",
    xlabel="x position",
    ylabel="y position",
    colormap=:inferno
)
                 GridWorld VI Value Function     
                 ┌──────────┐  10 
              10 ▄▄▄▄▄▄▄▄ ┌──┐
                 ▄▄▄▄▄▄ ▄▄
   y position    ▄▄ ▄▄
                 ▄▄ ▄▄
               1 ▄▄▄▄ └──┘
                 └──────────┘ -7  
                  1       10      
                  x position      
Note

Rendering of unicode plots in the documentation is not optimal. For a better image, run this locally in a REPL.

Visualizing the Value Iteration Policy

One way to visualize the policy is to plot the action that the policy takes at each state.

# Initialize the policy array
policy_array = fill(:up, mdp.size_x, mdp.size_y)

# Iterate over the state space and store the action at each state
for s in states(mdp)
    if isterminal(mdp, s)
        continue
    end
    policy_array[s.x, s.y] = action(vi_policy, s)
end

# Let's define a mapping from symbols to unicode arrows
arrow_map = Dict(
    :up => " ↑ ",
    :down => " ↓ ",
    :left => " ← ",
    :right => " → "
)

# Plot the policy to the terminal, with the origin in the bottom left
@printf("        GridWorld VI Policy      \n")
for y in mdp.size_y+1:-1:0
    if y == mdp.size_y+1 || y == 0
        for xi in 0:10
            if xi == 0
                print("   ")
            elseif y == mdp.size_y+1
                print("___")
            else
                print("---")
            end
        end
    else
        for x in 0:mdp.size_x+1
            if x == 0
                @printf("%2d |", y)
            elseif x == mdp.size_x + 1
                print("|")
            else
                print(arrow_map[policy_array[x, y]])
            end
        end
    end
    println()
    if y == 0
        for xi in 0:10
            if xi == 0
                print("   ")
            else
                print(" $xi ")
            end
        end
    end
end
        GridWorld VI Policy
   ______________________________
10 | →  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓ |
 9 | →  →  →  →  ↓  ↓  ↓  →  ↓  ↓ |
 8 | →  →  →  →  →  ↓  ↓  ↑  ↓  ↓ |
 7 | →  →  →  →  →  →  ↓  ↓  ↓  ↓ |
 6 | →  ↓  ↓  →  →  →  ↓  ↓  ↓  ↓ |
 5 | →  →  →  →  →  →  →  ↓  ↓  ↓ |
 4 | →  →  →  →  →  →  →  →  ↓  ↓ |
 3 | →  ↓  ↓  →  →  →  →  →  ↑  ← |
 2 | →  →  →  →  →  →  →  →  ↑  ↑ |
 1 | →  →  →  →  →  →  ↑  ↑  ↑  ↑ |
   ------------------------------
    1  2  3  4  5  6  7  8  9  10

Seeing a Policy In Action

Another useful tool is to view the policy in action by creating a gif of a simulation. To accomplish this, we could use POMDPGifs. To use POMDPGifs, we need to extend the POMDPTools.render function to GridWorldMDP. Please reference Gallery of POMDPs.jl Problems for examples of this process.