1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
use num::Float;
use num::cast::NumCast;
use rulinalg::matrix::{Matrix, BaseMatrix};
use rulinalg::vector::Vector;
use environment::{Space, Transition};
use trainer::BatchTrainer;
use agent::Agent;
use util::{ParameterizedFunc, FeatureExtractor};
/// Batch trainer that fits an agent's parameters by solving a single
/// least-squares linear system built from observed transitions.
///
/// `F` is the floating-point type used for all of the linear algebra.
#[derive(Debug)]
pub struct LSPolicyIteration<F>
    where F: Float
{
    /// Factor applied to the next state-action features when the system is
    /// assembled in `train`; the constructors keep it within `[0, 1]`.
    gamma: F,
}
/// Least-squares fit of a linear agent's parameters from a batch of transitions.
impl<F: Float + 'static, S: Space, A: Space, T> BatchTrainer<S, A, T> for LSPolicyIteration<F>
    where T: Agent<S, A> + ParameterizedFunc<F> + FeatureExtractor<S, A, F>
{
    /// Builds and solves the linear system `mat * w = vec`, then installs `w`
    /// as the agent's parameters.
    ///
    /// For every transition `(state, action, reward, next)`, with `phi` the
    /// features of `(state, action)` and `phi'` the features of `next` paired
    /// with the action the agent currently selects there, the system
    /// accumulates:
    ///
    /// * `mat += phi * (phi - gamma * phi')^T`
    /// * `vec += phi * reward`
    ///
    /// An empty batch leaves the agent untouched, because there is nothing to
    /// fit and normalising by a zero count would produce non-finite values.
    ///
    /// # Panics
    ///
    /// Panics if the transition count or a reward cannot be converted to `F`,
    /// or if the linear system cannot be solved.
    fn train(&mut self, agent: &mut T, transitions: Vec<Transition<S, A>>) {
        // Guard against dividing both sides of the system by zero below.
        if transitions.is_empty() {
            return;
        }

        let num_features = agent.num_features();
        let mut mat: Matrix<F> = Matrix::zeros(num_features, num_features);
        // Kept as an n x 1 matrix while accumulating so it can be combined
        // with the column-shaped feature matrices directly.
        let mut vec: Matrix<F> = Matrix::zeros(num_features, 1);
        let num: F = NumCast::from(transitions.len())
            .expect("transition count must be representable in the float type");

        for (state, action, reward, next) in transitions {
            // Action the agent's current policy picks in the successor state.
            let next_action = agent.get_action(&next);
            let feats = Matrix::new(num_features, 1, agent.extract(&state, &action));
            let next_feats = Matrix::new(num_features, 1, agent.extract(&next, &next_action));

            // `transpose` borrows its receiver, so no clone of `feats` is needed.
            let feats_t = feats.transpose();
            mat += &feats * &(&feats_t - &next_feats.transpose() * self.gamma);

            let reward: F = NumCast::from(reward)
                .expect("reward must be representable in the float type");
            vec += &feats * reward;
        }

        let vec = Vector::new(vec.into_vec());
        // Both sides are scaled by the batch size; the solution is unchanged
        // mathematically but the entries stay at a per-sample magnitude.
        let weights = (&mat / num)
            .solve(&vec / num)
            .expect("least-squares system could not be solved");
        agent.set_params(weights.into_vec());
    }
}
impl<F> Default for LSPolicyIteration<F>
    where F: Float
{
    /// Creates a trainer whose `gamma` is 0.99 converted into `F`.
    ///
    /// # Panics
    ///
    /// Panics if 0.99 cannot be converted into `F`.
    fn default() -> Self {
        let default_gamma: F = NumCast::from(0.99).unwrap();
        LSPolicyIteration { gamma: default_gamma }
    }
}
impl<F: Float> LSPolicyIteration<F> {
    /// Creates a trainer with the given `gamma`.
    ///
    /// # Panics
    ///
    /// Panics if `gamma` is not within `[0, 1]`. A NaN value fails both
    /// comparisons and therefore panics as well.
    pub fn new(gamma: F) -> LSPolicyIteration<F> {
        assert!(F::zero() <= gamma && gamma <= F::one(), "gamma must be between 0 and 1");
        LSPolicyIteration {
            gamma: gamma
        }
    }

    /// Builder-style setter: consumes the trainer and returns it with `gamma`
    /// replaced by the given value.
    ///
    /// # Panics
    ///
    /// Panics if `gamma` is not within `[0, 1]`. A NaN value fails both
    /// comparisons and therefore panics as well.
    pub fn gamma(mut self, gamma: F) -> LSPolicyIteration<F> {
        assert!(F::zero() <= gamma && gamma <= F::one(), "gamma must be between 0 and 1");
        self.gamma = gamma;
        self
    }
}