1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
use std::f64;
use environment::Environment;
use environment::Transition;
use environment::{Space, FiniteSpace};
use trainer::OnlineTrainer;
use agent::Agent;
use util::{QFunction, TimePeriod};
/// An `OnlineTrainer` that performs tabular/approximate Q-learning:
/// after each transition it nudges the agent's Q-value for (state, action)
/// toward `reward + gamma * max_a' Q(next, a')`.
#[derive(Debug)]
pub struct QLearner<A: FiniteSpace> {
// Finite action space enumerated to find the greedy next-state value.
action_space: A,
// Discount factor applied to the estimated future return.
gamma: f64,
// Learning rate passed through to `QFunction::update`.
alpha: f64,
// How long `train` runs (episodes/steps — see `TimePeriod`) before stopping.
train_period: TimePeriod,
}
impl<T, S: Space, A: FiniteSpace> OnlineTrainer<S, A, T> for QLearner<A>
where T: QFunction<S, A> + Agent<S, A> {
    /// Performs one Q-learning update from a single observed transition:
    /// moves Q(state, action) toward `reward + gamma * max_a' Q(next, a')`
    /// at learning rate `alpha`.
    fn train_step(&mut self, agent: &mut T, transition: Transition<S, A>) {
        let (state, action, reward, next) = transition;
        // Greedy one-step lookahead over the finite action space.
        // NEG_INFINITY is the identity for `max`; the previous seed, f64::MIN,
        // is the smallest *finite* f64 and would silently clamp any evaluation
        // below it (e.g. -inf), corrupting the target value.
        let max_next_val = self.action_space.enumerate()
            .into_iter()
            .map(|a| agent.eval(&next, &a))
            .fold(f64::NEG_INFINITY, f64::max);
        agent.update(&state, &action, reward + self.gamma*max_next_val, self.alpha);
    }
    /// Trains `agent` on `env` by repeatedly acting, observing the resulting
    /// transition, and applying `train_step`, until `train_period` elapses.
    /// Episodes are restarted (via `env.reset()`) whenever the environment
    /// reports a terminal observation.
    fn train(&mut self, agent: &mut T, env: &mut Environment<State=S, Action=A>) {
        let mut obs = env.reset();
        let mut time_remaining = self.train_period.clone();
        while !time_remaining.is_none() {
            let action = agent.get_action(&obs.state);
            let new_obs = env.step(&action);
            self.train_step(agent, (obs.state, action, new_obs.reward, new_obs.state.clone()));
            // `dec` counts down by steps or episodes depending on the period kind.
            time_remaining = time_remaining.dec(new_obs.done);
            obs = if new_obs.done {env.reset()} else {new_obs};
        }
    }
}
impl<A: FiniteSpace> QLearner<A> {
pub fn new(action_space: A, gamma: f64, alpha: f64, train_period: TimePeriod) -> QLearner<A> {
QLearner {
action_space: action_space,
gamma: gamma,
alpha: alpha,
train_period: train_period
}
}
pub fn default(action_space: A) -> QLearner<A> {
QLearner {
action_space: action_space,
gamma: 0.95,
alpha: 0.1,
train_period: TimePeriod::EPISODES(100)
}
}
pub fn gamma(mut self, gamma: f64) -> QLearner<A> {
self.gamma = gamma;
self
}
pub fn alpha(mut self, alpha: f64) -> QLearner<A> {
self.alpha = alpha;
self
}
pub fn train_period(mut self, train_period: TimePeriod) -> QLearner<A> {
self.train_period = train_period;
self
}
}