use std::f64;

use environment::{Environment, Space, Transition};

use trainer::OnlineTrainer;
use agent::Agent;
use util::{QFunction, TimePeriod};

/// An online trainer that uses the SARSA algorithm to update an agent's Q-function
/// after every step taken in the environment.
#[derive(Debug)]
pub struct SARSALearner {
    /// Discount factor applied to future rewards
    gamma: f64,
    /// Learning rate used when updating the Q-function
    alpha: f64,
    /// How long the agent should be trained for
    train_period: TimePeriod,
}
impl<T, S: Space, A: Space> OnlineTrainer<S, A, T> for SARSALearner
    where T: QFunction<S, A> + Agent<S, A>
{
    fn train_step(&mut self, agent: &mut T, transition: Transition<S, A>) {
        let (state, action, reward, next) = transition;

        // SARSA update: move Q(state, action) towards reward + gamma * Q(next, next_action),
        // where next_action is the action the agent itself would choose in the next state
        let next_action = agent.get_action(&next);
        let next_val = agent.eval(&next, &next_action);
        agent.update(&state, &action, reward + self.gamma * next_val, self.alpha);
    }
    fn train(&mut self, agent: &mut T, env: &mut Environment<State=S, Action=A>) {
        let mut obs = env.reset();
        let mut time_remaining = self.train_period.clone();
        while !time_remaining.is_none() {
            // Act in the environment and perform one SARSA update on the observed transition
            let action = agent.get_action(&obs.state);
            let new_obs = env.step(&action);
            self.train_step(agent, (obs.state, action, new_obs.reward, new_obs.state.clone()));

            // Count down the training period and reset the environment at the end of an episode
            time_remaining = time_remaining.dec(new_obs.done);
            obs = if new_obs.done { env.reset() } else { new_obs };
        }
    }
}
impl Default for SARSALearner {
    fn default() -> SARSALearner {
        SARSALearner {
            gamma: 0.95,
            alpha: 0.1,
            train_period: TimePeriod::EPISODES(100),
        }
    }
}
impl SARSALearner {
    /// Constructs a SARSALearner with the given discount factor, learning rate,
    /// and training period
    pub fn new(gamma: f64, alpha: f64, train_period: TimePeriod) -> SARSALearner {
        SARSALearner {
            gamma,
            alpha,
            train_period,
        }
    }
    /// Sets the discount factor, returning the learner (builder-style)
    pub fn gamma(mut self, gamma: f64) -> SARSALearner {
        self.gamma = gamma;
        self
    }
    /// Sets the learning rate, returning the learner (builder-style)
    pub fn alpha(mut self, alpha: f64) -> SARSALearner {
        self.alpha = alpha;
        self
    }
    /// Sets the training period, returning the learner (builder-style)
    pub fn train_period(mut self, train_period: TimePeriod) -> SARSALearner {
        self.train_period = train_period;
        self
    }
}
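
// A minimal usage sketch, not part of the library itself: it assumes some agent type
// implementing both `QFunction` and `Agent` (here just called `agent`) and an
// environment `env` implementing `Environment` have been constructed elsewhere.
//
//     let mut trainer = SARSALearner::default()
//         .gamma(0.9)
//         .alpha(0.05)
//         .train_period(TimePeriod::EPISODES(500));
//     trainer.train(&mut agent, &mut env);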