Using Different Solvers

Many solvers are available out of the box. See the repository README for the lists of MDP solvers and POMDP solvers implemented and maintained by the JuliaPOMDP community. Below, we show how to use a small subset of them.

Checking Requirements

Before using a solver, it is prudent to check that the problem implements everything the solver requires. Please refer to each solver's documentation for detailed information about its requirements.

We can use POMDPLinter to help determine whether we have defined all of the components a particular solver requires. However, not all solvers specify their requirements. If you encounter a solver that does not, please open an issue on that solver's repository.
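
Conceptually, a requirement is a pair of a function and argument types that must have a method defined. The plain-Julia sketch below imitates the [✔]/[X] report using Base's `hasmethod`. The model types and the `toy_actions`/`toy_discount` functions are hypothetical stand-ins for POMDPs functions, not POMDPLinter's implementation. The real linter does more, for example following the requirements of solvers that are called internally, as the nested entries in the output below show.

```julia
# Hypothetical stand-ins for problem types and interface functions
abstract type ToyPOMDP end
struct ExplicitBaby <: ToyPOMDP end
struct GenBaby <: ToyPOMDP end

function toy_actions end                                # generic function, no methods yet
toy_discount(::ToyPOMDP) = 0.9                          # implemented for every ToyPOMDP
toy_actions(::ExplicitBaby) = [:feed, :sing, :ignore]   # only ExplicitBaby implements it

# Print a [✔]/[X] line per requirement; return true if all are implemented
function toy_show_requirements(model_type::Type)
    requirements = [(toy_discount, Tuple{model_type}),
                    (toy_actions, Tuple{model_type})]
    all_ok = true
    for (f, argtypes) in requirements
        ok = hasmethod(f, argtypes)
        all_ok &= ok
        println(ok ? "  [✔] " : "  [X] ", nameof(f), "(::", model_type, ")")
    end
    return all_ok
end

toy_show_requirements(ExplicitBaby)   # both requirements met
toy_show_requirements(GenBaby)        # toy_actions is missing
```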

Let's check whether each of our problem definitions provides everything the QMDP solver requires.
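
The requirements follow from how QMDP works. As the nested solve(::ValueIterationSolver, ...) entries in the output indicate, it runs value iteration on the fully observable UnderlyingMDP. In the standard QMDP formulation, the resulting Q-values Q(s, a) then serve as one alpha vector per action, and the chosen action is the one maximizing the sum over s of b(s) Q(s, a). Because QMDP assumes the state becomes fully observable after one step, it never values an action purely for the information it gathers. The sketch below does this in plain Julia for a hand-coded crying baby model. The model itself is an assumption and may not match the tutorial's definitions exactly: a sated baby becomes hungry with probability 0.1, feeding always sates, the rewards are -10 while hungry, -5 for feeding and -0.5 for singing, and the discount is 0.9. The names `qmdp_q` and `qmdp_action` are hypothetical.

```julia
# Assumed crying baby model; states ordered (sated, hungry)
actions_list = [:feed, :sing, :ignore]

# T[s, a, s′]: probability of reaching s′ from s under action a
T = zeros(2, 3, 2)
T[:, 1, 1] .= 1.0                 # :feed always leaves the baby sated
for a in 2:3                      # :sing and :ignore
    T[1, a, :] = [0.9, 0.1]       # a sated baby becomes hungry w.p. 0.1
    T[2, a, 2] = 1.0              # a hungry baby stays hungry unless fed
end

# R[s, a]: -10 per step while hungry, -5 for feeding, -0.5 for singing
R = [-5.0 -0.5 0.0; -15.0 -10.5 -10.0]

# Value iteration on the fully observable MDP; returns the Q-matrix Q[s, a]
function qmdp_q(T, R; γ = 0.9, tol = 1e-10)
    nS, nA = size(R)
    V = zeros(nS)
    while true
        # Bellman backup: Q(s, a) = R(s, a) + γ Σ_s′ T(s′ | s, a) V(s′)
        Q = [R[s, a] + γ * sum(T[s, a, :] .* V) for s in 1:nS, a in 1:nA]
        Vnew = vec(maximum(Q; dims = 2))
        maximum(abs.(Vnew .- V)) < tol && return Q
        V = Vnew
    end
end

Q = qmdp_q(T, R)

# QMDP: column Q[:, a] is the alpha vector of action a; pick the best one for belief b
qmdp_action(Q, b) = actions_list[argmax(Q' * b)]

qmdp_action(Q, [1.0, 0.0])   # certainly sated
qmdp_action(Q, [0.0, 1.0])   # certainly hungry
```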

using POMDPs
using POMDPLinter
using QMDP

qmdp_solver = QMDPSolver()

println("Quick Crying Baby POMDP")
@show_requirements POMDPs.solve(qmdp_solver, quick_crying_baby_pomdp)

println("\nExplicit Crying Baby POMDP")
@show_requirements POMDPs.solve(qmdp_solver, explicit_crying_baby_pomdp)

println("\nTabular Crying Baby POMDP")
@show_requirements POMDPs.solve(qmdp_solver, tabular_crying_baby_pomdp)

println("\nGen Crying Baby POMDP")
# actions(::GenCryingBabyPOMDP) is not implemented, so this check throws a MethodError
try
    @show_requirements POMDPs.solve(qmdp_solver, gen_crying_baby_pomdp)
catch err_msg
    println(err_msg)
end
Quick Crying Baby POMDP
INFO: POMDPLinter requirements for solve(::QMDP.QMDPSolver, ::POMDP) and dependencies. ([✔] = implemented correctly; [X] = not implemented; [?] = could not determine)

For solve(::QMDP.QMDPSolver, ::POMDP):
  [No additional requirements]
For solve(::ValueIterationSolver, ::Union{MDP,POMDP}) (in solve(::QMDP.QMDPSolver, ::POMDP)):
  [✔] discount(::UnderlyingMDP{QuickPOMDPs.QuickPOMDP{UUID("6d74353b-4751-4772-8680-9df9dfca7fc9"), Symbol, Symbol, Symbol, @NamedTuple{stateindex::Dict{Symbol, Int64}, isterminal::Bool, obsindex::Dict{Symbol, Int64}, states::Vector{Symbol}, observations::Vector{Symbol}, discount::Float64, actions::Vector{Symbol}, observation::Main.var"#2#5", actionindex::Dict{Symbol, Int64}, initialstate::Deterministic{Symbol}, transition::Main.var"#1#4", reward::Main.var"#3#6"}}, Symbol, Symbol})
  [✔] transition(::UnderlyingMDP{QuickPOMDPs.QuickPOMDP{UUID("6d74353b-4751-4772-8680-9df9dfca7fc9"), Symbol, Symbol, Symbol, @NamedTuple{stateindex::Dict{Symbol, Int64}, isterminal::Bool, obsindex::Dict{Symbol, Int64}, states::Vector{Symbol}, observations::Vector{Symbol}, discount::Float64, actions::Vector{Symbol}, observation::Main.var"#2#5", actionindex::Dict{Symbol, Int64}, initialstate::Deterministic{Symbol}, transition::Main.var"#1#4", reward::Main.var"#3#6"}}, Symbol, Symbol}, ::Symbol, ::Symbol)
  [✔] reward(::UnderlyingMDP{QuickPOMDPs.QuickPOMDP{UUID("6d74353b-4751-4772-8680-9df9dfca7fc9"), Symbol, Symbol, Symbol, @NamedTuple{stateindex::Dict{Symbol, Int64}, isterminal::Bool, obsindex::Dict{Symbol, Int64}, states::Vector{Symbol}, observations::Vector{Symbol}, discount::Float64, actions::Vector{Symbol}, observation::Main.var"#2#5", actionindex::Dict{Symbol, Int64}, initialstate::Deterministic{Symbol}, transition::Main.var"#1#4", reward::Main.var"#3#6"}}, Symbol, Symbol}, ::Symbol, ::Symbol, ::Symbol)
  [✔] stateindex(::UnderlyingMDP{QuickPOMDPs.QuickPOMDP{UUID("6d74353b-4751-4772-8680-9df9dfca7fc9"), Symbol, Symbol, Symbol, @NamedTuple{stateindex::Dict{Symbol, Int64}, isterminal::Bool, obsindex::Dict{Symbol, Int64}, states::Vector{Symbol}, observations::Vector{Symbol}, discount::Float64, actions::Vector{Symbol}, observation::Main.var"#2#5", actionindex::Dict{Symbol, Int64}, initialstate::Deterministic{Symbol}, transition::Main.var"#1#4", reward::Main.var"#3#6"}}, Symbol, Symbol}, ::Symbol)
  [✔] actionindex(::UnderlyingMDP{QuickPOMDPs.QuickPOMDP{UUID("6d74353b-4751-4772-8680-9df9dfca7fc9"), Symbol, Symbol, Symbol, @NamedTuple{stateindex::Dict{Symbol, Int64}, isterminal::Bool, obsindex::Dict{Symbol, Int64}, states::Vector{Symbol}, observations::Vector{Symbol}, discount::Float64, actions::Vector{Symbol}, observation::Main.var"#2#5", actionindex::Dict{Symbol, Int64}, initialstate::Deterministic{Symbol}, transition::Main.var"#1#4", reward::Main.var"#3#6"}}, Symbol, Symbol}, ::Symbol)
  [✔] actions(::UnderlyingMDP{QuickPOMDPs.QuickPOMDP{UUID("6d74353b-4751-4772-8680-9df9dfca7fc9"), Symbol, Symbol, Symbol, @NamedTuple{stateindex::Dict{Symbol, Int64}, isterminal::Bool, obsindex::Dict{Symbol, Int64}, states::Vector{Symbol}, observations::Vector{Symbol}, discount::Float64, actions::Vector{Symbol}, observation::Main.var"#2#5", actionindex::Dict{Symbol, Int64}, initialstate::Deterministic{Symbol}, transition::Main.var"#1#4", reward::Main.var"#3#6"}}, Symbol, Symbol}, ::Symbol)
  [✔] length(::Array{Symbol, 1})
  [✔] support(::Deterministic{Symbol})
  [✔] pdf(::Deterministic{Symbol}, ::Symbol)
For ordered_states(::Union{MDP,POMDP}) (in solve(::ValueIterationSolver, ::Union{MDP,POMDP})):
  [✔] states(::UnderlyingMDP{QuickPOMDPs.QuickPOMDP{UUID("6d74353b-4751-4772-8680-9df9dfca7fc9"), Symbol, Symbol, Symbol, @NamedTuple{stateindex::Dict{Symbol, Int64}, isterminal::Bool, obsindex::Dict{Symbol, Int64}, states::Vector{Symbol}, observations::Vector{Symbol}, discount::Float64, actions::Vector{Symbol}, observation::Main.var"#2#5", actionindex::Dict{Symbol, Int64}, initialstate::Deterministic{Symbol}, transition::Main.var"#1#4", reward::Main.var"#3#6"}}, Symbol, Symbol})
For ordered_actions(::Union{MDP,POMDP}) (in solve(::ValueIterationSolver, ::Union{MDP,POMDP})):
  [✔] actions(::UnderlyingMDP{QuickPOMDPs.QuickPOMDP{UUID("6d74353b-4751-4772-8680-9df9dfca7fc9"), Symbol, Symbol, Symbol, @NamedTuple{stateindex::Dict{Symbol, Int64}, isterminal::Bool, obsindex::Dict{Symbol, Int64}, states::Vector{Symbol}, observations::Vector{Symbol}, discount::Float64, actions::Vector{Symbol}, observation::Main.var"#2#5", actionindex::Dict{Symbol, Int64}, initialstate::Deterministic{Symbol}, transition::Main.var"#1#4", reward::Main.var"#3#6"}}, Symbol, Symbol})

Explicit Crying Baby POMDP
INFO: POMDPLinter requirements for solve(::QMDP.QMDPSolver, ::POMDP) and dependencies. ([✔] = implemented correctly; [X] = not implemented; [?] = could not determine)

For solve(::QMDP.QMDPSolver, ::POMDP):
  [No additional requirements]
For solve(::ValueIterationSolver, ::Union{MDP,POMDP}) (in solve(::QMDP.QMDPSolver, ::POMDP)):
  [✔] discount(::UnderlyingMDP{Main.CryingBabyPOMDP, Main.CryingBabyState, Symbol})
  [✔] transition(::UnderlyingMDP{Main.CryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::CryingBabyState, ::Symbol)
  [✔] reward(::UnderlyingMDP{Main.CryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::CryingBabyState, ::Symbol, ::CryingBabyState)
  [✔] stateindex(::UnderlyingMDP{Main.CryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::CryingBabyState)
  [✔] actionindex(::UnderlyingMDP{Main.CryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::Symbol)
  [✔] actions(::UnderlyingMDP{Main.CryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::CryingBabyState)
  [✔] length(::Array{Main.CryingBabyState, 1})
  [✔] length(::Array{Symbol, 1})
  [✔] support(::Deterministic{Main.CryingBabyState})
  [✔] pdf(::Deterministic{Main.CryingBabyState}, ::CryingBabyState)
For ordered_states(::Union{MDP,POMDP}) (in solve(::ValueIterationSolver, ::Union{MDP,POMDP})):
  [✔] states(::UnderlyingMDP{Main.CryingBabyPOMDP, Main.CryingBabyState, Symbol})
For ordered_actions(::Union{MDP,POMDP}) (in solve(::ValueIterationSolver, ::Union{MDP,POMDP})):
  [✔] actions(::UnderlyingMDP{Main.CryingBabyPOMDP, Main.CryingBabyState, Symbol})

Tabular Crying Baby POMDP
INFO: POMDPLinter requirements for solve(::QMDP.QMDPSolver, ::POMDP) and dependencies. ([✔] = implemented correctly; [X] = not implemented; [?] = could not determine)

For solve(::QMDP.QMDPSolver, ::POMDP):
  [No additional requirements]
For solve(::ValueIterationSolver, ::Union{MDP,POMDP}) (in solve(::QMDP.QMDPSolver, ::POMDP)):
  [✔] discount(::UnderlyingMDP{POMDPModels.TabularPOMDP, Int64, Int64})
  [✔] transition(::UnderlyingMDP{POMDPModels.TabularPOMDP, Int64, Int64}, ::Int64, ::Int64)
  [✔] reward(::UnderlyingMDP{POMDPModels.TabularPOMDP, Int64, Int64}, ::Int64, ::Int64, ::Int64)
  [✔] stateindex(::UnderlyingMDP{POMDPModels.TabularPOMDP, Int64, Int64}, ::Int64)
  [✔] actionindex(::UnderlyingMDP{POMDPModels.TabularPOMDP, Int64, Int64}, ::Int64)
  [✔] actions(::UnderlyingMDP{POMDPModels.TabularPOMDP, Int64, Int64}, ::Int64)
  [✔] length(::UnitRange{Int64})
  [✔] support(::DiscreteDistribution{SubArray{Float64, 1, Array{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Int64}, true}})
  [✔] pdf(::DiscreteDistribution{SubArray{Float64, 1, Array{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Int64}, true}}, ::Int64)
For ordered_states(::Union{MDP,POMDP}) (in solve(::ValueIterationSolver, ::Union{MDP,POMDP})):
  [✔] states(::UnderlyingMDP{POMDPModels.TabularPOMDP, Int64, Int64})
For ordered_actions(::Union{MDP,POMDP}) (in solve(::ValueIterationSolver, ::Union{MDP,POMDP})):
  [✔] actions(::UnderlyingMDP{POMDPModels.TabularPOMDP, Int64, Int64})

Gen Crying Baby POMDP
INFO: POMDPLinter requirements for solve(::QMDP.QMDPSolver, ::POMDP) and dependencies. ([✔] = implemented correctly; [X] = not implemented; [?] = could not determine)

For solve(::QMDP.QMDPSolver, ::POMDP):
  [No additional requirements]
For solve(::ValueIterationSolver, ::Union{MDP,POMDP}) (in solve(::QMDP.QMDPSolver, ::POMDP)):
  [✔] discount(::UnderlyingMDP{Main.GenCryingBabyPOMDP, Main.CryingBabyState, Symbol})
  [✔] transition(::UnderlyingMDP{Main.GenCryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::CryingBabyState, ::Symbol)
  [✔] reward(::UnderlyingMDP{Main.GenCryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::CryingBabyState, ::Symbol, ::CryingBabyState)
  [✔] stateindex(::UnderlyingMDP{Main.GenCryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::CryingBabyState)
  [✔] actionindex(::UnderlyingMDP{Main.GenCryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::Symbol)
  [✔] actions(::UnderlyingMDP{Main.GenCryingBabyPOMDP, Main.CryingBabyState, Symbol}, ::CryingBabyState)
  WARNING: Some requirements may not be shown because a MethodError was thrown.
For ordered_states(::Union{MDP,POMDP}) (in solve(::ValueIterationSolver, ::Union{MDP,POMDP})):
  [✔] states(::UnderlyingMDP{Main.GenCryingBabyPOMDP, Main.CryingBabyState, Symbol})
  WARNING: Some requirements may not be shown because a MethodError was thrown.
For ordered_actions(::Union{MDP,POMDP}) (in solve(::ValueIterationSolver, ::Union{MDP,POMDP})):
  [✔] actions(::UnderlyingMDP{Main.GenCryingBabyPOMDP, Main.CryingBabyState, Symbol})
  WARNING: Some requirements may not be shown because a MethodError was thrown.
Note: Missing methods are often due to incorrect importing. You must explicitly import POMDPs functions to add new methods.

Throwing the first exception (from processing solve(::ValueIterationSolver, ::Union{MDP,POMDP}) requirements):

MethodError(POMDPs.actions, (Main.GenCryingBabyPOMDP(0.1, 0.8, 0.9, 0.8, 0.1, 0.0, 0.1, -10.0, -5.0, -0.5, 0.9),), 0x0000000000007c04)
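
The note about importing reflects how Julia methods work. A method only extends POMDPs.actions if it is defined with the qualified name (for example `POMDPs.actions(m::MyPOMDP) = ...`) or after an explicit `import POMDPs: actions`. Otherwise, Julia creates a new, unrelated function that solvers never call. Under that reading, the Gen model would need a method defined on the qualified name POMDPs.actions, and other requirements may surface once it is added. The self-contained sketch below reproduces the mechanism with a hypothetical module `ToyInterface` standing in for POMDPs. The functions `baby_actions` and `count_actions` are hypothetical as well.

```julia
# Hypothetical interface module standing in for POMDPs
module ToyInterface
    function baby_actions end                       # interface function owned by ToyInterface
    count_actions(m) = length(baby_actions(m))      # "solver" code that calls the interface
end

struct ToyBaby end
struct ShadowBaby end

# Correct: the qualified name adds a method to ToyInterface.baby_actions
ToyInterface.baby_actions(::ToyBaby) = [:feed, :sing, :ignore]

# Mistake: this creates a new function Main.baby_actions that ToyInterface never sees
baby_actions(::ShadowBaby) = [:feed]

n = ToyInterface.count_actions(ToyBaby())

shadow_ok = try
    ToyInterface.count_actions(ShadowBaby())
    true
catch err
    err isa MethodError || rethrow()
    false                                           # MethodError, like the one shown above
end
```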

Offline (SARSOP)

In this example, we use the NativeSARSOP solver. The process for generating offline policies is similar for all offline solvers. First, define the solver with the desired parameters. Then, call POMDPs.solve with the solver and the problem. The returned policy can be queried with the action function.

using NativeSARSOP

# Define the solver with the desired parameters
sarsop_solver = SARSOPSolver(; max_time=10.0)

# Solve the problem by calling POMDPs.solve. SARSOP will compute the policy and return an `AlphaVectorPolicy`
sarsop_policy = POMDPs.solve(sarsop_solver, quick_crying_baby_pomdp)

# We can query the policy using the `action` function
b = initialstate(quick_crying_baby_pomdp)
a = action(sarsop_policy, b)

@show a
:ignore
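
SARSOP returns an AlphaVectorPolicy, which pairs a set of alpha vectors with actions. Each alpha vector is linear in the belief. The value of a belief is the largest dot product over the vectors, and action returns the action attached to the maximizing vector. The plain-Julia sketch below shows only that lookup. Its two alpha vectors are not SARSOP's output. Instead, they are the exact values of the "blind" policies that always feed or always ignore, computed under an assumed crying baby model: a sated baby becomes hungry with probability 0.1, feeding always sates, the rewards are -10 while hungry and -5 for feeding, and the discount is 0.9. All names in the block are hypothetical.

```julia
using LinearAlgebra

# Alpha vector of a "blind" policy that repeats action a forever:
# α = r_a + γ T_a α, so α = (I - γ T_a) \ r_a
blind_alpha(Ta, ra; γ = 0.9) = (I - γ * Ta) \ ra

# Assumed crying baby model; states ordered (sated, hungry), T_a[s, s′] = P(s′ | s, a)
T_feed   = [1.0 0.0; 1.0 0.0]    # feeding always leaves the baby sated
T_ignore = [0.9 0.1; 0.0 1.0]    # sated → hungry w.p. 0.1; hungry stays hungry
r_feed   = [-5.0, -15.0]         # feeding cost, plus -10 while hungry
r_ignore = [0.0, -10.0]

alphas     = [blind_alpha(T_feed, r_feed), blind_alpha(T_ignore, r_ignore)]
action_map = [:feed, :ignore]

# Value of a belief = upper envelope of the alpha vectors; action = action of the maximizer
belief_value(b)  = maximum(dot(α, b) for α in alphas)
belief_action(b) = action_map[argmax([dot(α, b) for α in alphas])]

belief_action([1.0, 0.0])   # certainly sated
belief_action([0.0, 1.0])   # certainly hungry
```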

Online (POMCP)

For the online solver, we use Partially Observable Monte Carlo Planning (POMCP). We define an online solver the same way as an offline one. However, calling POMDPs.solve returns an online planner rather than a precomputed policy. As with the offline solver, we query it with the action function. The online planner does its work at that point, searching from the given belief when action is called.

using BasicPOMCP
using Random  # provides MersenneTwister

pomcp_solver = POMCPSolver(; c=5.0, tree_queries=1000, rng=MersenneTwister(1))
pomcp_planner = POMDPs.solve(pomcp_solver, quick_crying_baby_pomdp)

b = initialstate(quick_crying_baby_pomdp)
a = action(pomcp_planner, b)

@show a
:ignore
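
The c keyword above is POMCP's UCB exploration constant. During the tree_queries simulations, POMCP selects actions at a belief node with the UCB1 rule, Q(h, a) + c * sqrt(log N(h) / N(h, a)), so a larger c favors actions that have been tried less often. The sketch below isolates that rule using made-up visit counts and value estimates, approximating N(h) by the sum of the action visit counts. `ActionStats` and `ucb_select` are hypothetical names, not BasicPOMCP internals.

```julia
# Hypothetical per-action statistics stored at one belief node of the search tree
struct ActionStats
    n::Int        # N(h, a): times action a was tried from this node
    q::Float64    # Q(h, a): running average of the simulated returns
end

# UCB1 selection: exploit high Q, but add a bonus for rarely tried actions
function ucb_select(stats::Dict{Symbol,ActionStats}, c::Real)
    total = sum(s.n for s in values(stats))   # stands in for N(h)
    best, best_score = :none, -Inf
    for (a, s) in stats
        s.n == 0 && return a                  # try every action at least once
        score = s.q + c * sqrt(log(total) / s.n)
        if score > best_score
            best, best_score = a, score
        end
    end
    return best
end

# Made-up statistics, for illustration only
stats = Dict(:feed   => ActionStats(10, -20.0),
             :sing   => ActionStats(2, -22.0),
             :ignore => ActionStats(10, -19.0))

ucb_select(stats, 0.0)   # pure exploitation: highest Q wins (:ignore)
ucb_select(stats, 5.0)   # c = 5.0 as in the POMCPSolver call: rarely tried :sing wins
```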

Heuristic Policy

While we often want to use a solver to compute a policy, sometimes a heuristic policy is useful instead, for example as the rollout policy of an online solver or as a baseline. In this example, we define a simple heuristic policy that feeds the baby if our belief that the baby is hungry exceeds 50% and otherwise randomly chooses between ignoring and singing to the baby.

using POMDPs, POMDPTools  # Policy, DiscreteUpdater, SparseCat

struct HeuristicFeedPolicy{P<:POMDP} <: Policy
    pomdp::P
end

# We need to implement the action function for our policy
function POMDPs.action(policy::HeuristicFeedPolicy, b)
    if pdf(b, :hungry) > 0.5
        return :feed
    else
        return rand([:ignore, :sing])
    end
end

# Let's also define the default updater for our policy
function POMDPs.updater(policy::HeuristicFeedPolicy)
    return DiscreteUpdater(policy.pomdp)
end

heuristic_policy = HeuristicFeedPolicy(quick_crying_baby_pomdp)

# Let's query the policy a few times
b = SparseCat([:sated, :hungry], [0.1, 0.9])
a1 = action(heuristic_policy, b)

b = SparseCat([:sated, :hungry], [0.9, 0.1])
a2 = action(heuristic_policy, b)  # randomly :ignore or :sing

@show [a1, a2]
2-element Vector{Symbol}:
 :feed
 :sing
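
The updater matters because the heuristic only ever looks at the belief. For a discrete model, a Bayesian update of the kind DiscreteUpdater performs first predicts the next-state distribution with the transition model, then reweights it by the likelihood of the received observation. The plain-Julia sketch below applies that filter to an assumed crying baby model; its transition and crying probabilities are assumptions and may differ from the tutorial's definitions. It shows the heuristic switching to :feed after the baby cries twice while being ignored. The names `TRANS`, `P_CRY`, `update_belief` and `heuristic_action` are hypothetical.

```julia
# Assumed crying baby model; states ordered (sated, hungry)
# TRANS[a][s, s′] = P(s′ | s, a)
TRANS = Dict(:feed   => [1.0 0.0; 1.0 0.0],
             :sing   => [0.9 0.1; 0.0 1.0],
             :ignore => [0.9 0.1; 0.0 1.0])
# P_CRY[a][s′] = P(crying | s′, a)
P_CRY = Dict(:feed => [0.1, 0.8], :sing => [0.0, 0.9], :ignore => [0.1, 0.8])

# Discrete Bayes filter: predict with the transition model, then weight by the observation likelihood
function update_belief(b::Vector{Float64}, a::Symbol, crying::Bool)
    predicted = TRANS[a]' * b                      # P(s′ | b, a) = Σ_s T(s′ | s, a) b(s)
    likelihood = crying ? P_CRY[a] : 1 .- P_CRY[a]
    posterior = likelihood .* predicted
    return posterior ./ sum(posterior)
end

# Same decision rule as HeuristicFeedPolicy, applied to a plain probability vector
heuristic_action(b) = b[2] > 0.5 ? :feed : rand([:ignore, :sing])

b0 = [1.0, 0.0]                         # start certain the baby is sated
b1 = update_belief(b0, :ignore, true)   # baby cries once while ignored: P(hungry) still < 0.5
b2 = update_belief(b1, :ignore, true)   # ... and cries again
heuristic_action(b2)                    # now P(hungry) > 0.5, so :feed
```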