use std::f64;

use environment::{Transition, Space, FiniteSpace};
use trainer::BatchTrainer;
use agent::Agent;
use util::QFunction;

/// A batch trainer that performs Fitted Q-Iteration: it repeatedly sweeps over
/// a fixed batch of transitions, regressing the agent's Q-function toward the
/// one-step bootstrap target r + gamma * max_a' Q(s', a').
#[derive(Debug)]
pub struct FittedQIteration<A: FiniteSpace> {
    /// All actions available to the agent, enumerated from the action space
    actions: Vec<A::Element>,
    /// Discount factor applied to future rewards
    gamma: f64,
    /// Learning rate passed to the Q-function on each update
    alpha: f64,
    /// Number of sweeps over the batch per call to `train`
    iters: usize,
}
impl<S: Space, A: FiniteSpace, T> BatchTrainer<S, A, T> for FittedQIteration<A>
    where T: QFunction<S, A> + Agent<S, A> {
    fn train(&mut self, agent: &mut T, transitions: Vec<Transition<S, A>>) {
        for _ in 0..self.iters {
            // Compute every (state, action, target) pattern before applying any
            // updates, so all targets in a sweep come from the same Q-function.
            let mut patterns = Vec::with_capacity(transitions.len());
            for &(ref s0, ref a, r, ref s1) in &transitions {
                // Greedy one-step lookahead over every action in the next state.
                // NEG_INFINITY is the identity element for max; `next_a` avoids
                // shadowing the transition's action `a`.
                let mut max_next_val = f64::NEG_INFINITY;
                for next_a in &self.actions {
                    max_next_val = max_next_val.max(agent.eval(s1, next_a));
                }
                // Bootstrapped target: r + gamma * max_a' Q(s', a')
                let target = r + self.gamma * max_next_val;
                patterns.push((s0, a, target));
            }
            // Regress the Q-function toward the collected targets
            for (s, a, q) in patterns {
                agent.update(s, a, q, self.alpha);
            }
        }
    }
}
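
// If `QFunction::update` applies the usual soft update (an assumption here;
// the trait is defined elsewhere in the crate), each pattern above amounts to
//
//     Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
//
// i.e. one step of Q-learning applied to every stored transition, repeated
// `iters` times over the same batch.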
impl<A: FiniteSpace> FittedQIteration<A> {
    /// Creates a new FittedQIteration with the given parameters
    pub fn new(action_space: A, gamma: f64, alpha: f64, iters: usize) -> FittedQIteration<A> {
        FittedQIteration {
            actions: action_space.enumerate(),
            gamma: gamma,
            alpha: alpha,
            iters: iters,
        }
    }
    /// Creates a new FittedQIteration with default parameters:
    /// gamma = 0.95, alpha = 0.1, iters = 10
    pub fn default(action_space: A) -> FittedQIteration<A> {
        FittedQIteration {
            actions: action_space.enumerate(),
            gamma: 0.95,
            alpha: 0.1,
            iters: 10,
        }
    }
    /// Builder-style setter for the discount factor
    pub fn gamma(mut self, gamma: f64) -> FittedQIteration<A> {
        self.gamma = gamma;
        self
    }
    /// Builder-style setter for the learning rate
    pub fn alpha(mut self, alpha: f64) -> FittedQIteration<A> {
        self.alpha = alpha;
        self
    }
    /// Builder-style setter for the number of sweeps per batch
    pub fn iters(mut self, iters: usize) -> FittedQIteration<A> {
        self.iters = iters;
        self
    }
}
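
// A minimal usage sketch (an illustration, not part of the crate). It assumes
// some agent type implementing both `QFunction<S, A>` and `Agent<S, A>` -- for
// example a tabular Q-learner -- plus a finite `action_space` and a batch of
// `transitions: Vec<Transition<S, A>>` collected from an environment. The
// names `agent`, `action_space`, and `transitions` are placeholders.
//
//     let mut trainer = FittedQIteration::default(action_space)
//         .gamma(0.99)
//         .alpha(0.05)
//         .iters(20);
//     trainer.train(&mut agent, transitions);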