1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
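//! Table Module
//!
//! Table-backed implementations of `QFunction` (`QTable`) and `VFunction` (`VTable`),
//! each storing its values in a `HashMap`.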

use std::collections::HashMap;
use std::hash::Hash;
use std::f64;

use environment::FiniteSpace;

use util::{QFunction, VFunction};

/// QTable
///
/// Represents a QFunction implemented using a table.
/// Values are kept in a `HashMap` keyed by (state, action) pairs. The table is
/// sparse: a pair only gets an entry once it has been updated, and `eval`
/// reports 0.0 for any pair that has no entry.
///
/// The `Hash + Eq` bounds on the element types are what allow
/// (state, action) tuples to be used as `HashMap` keys.
#[derive(Debug, Clone)]
pub struct QTable<S: FiniteSpace, A: FiniteSpace>
	where S::Element: Hash + Eq, A::Element: Hash + Eq {
	// Backing store: (state, action) -> current value estimate
	map: HashMap<(S::Element, A::Element), f64>
}

impl<S: FiniteSpace, A: FiniteSpace> QFunction<S, A> for QTable<S, A>
	where S::Element: Hash + Eq, A::Element: Hash + Eq {
	/// Returns the value stored for the (state, action) pair,
	/// or 0.0 if the pair has no entry in the table.
	///
	/// Reading never inserts into the table.
	fn eval(&self, state: &S::Element, action: &A::Element) -> f64 {
		// The map is keyed by owned tuples, so the key is built once and looked up once
		let key = (state.clone(), action.clone());
		self.map.get(&key).cloned().unwrap_or(0.0)
	}
	/// Moves the value stored for the (state, action) pair towards `new_val`.
	///
	/// The stored value becomes `old + alpha*(new_val - old)`, where `old` is the
	/// current value (0.0 if the pair has no entry yet). An entry is created
	/// for the pair if one does not already exist.
	fn update(&mut self, state: &S::Element, action: &A::Element, new_val: f64, alpha: f64) {
		// The entry API does the "read or default, then write" in a single lookup
		let slot = self.map.entry((state.clone(), action.clone())).or_insert(0.0);
		let old_val = *slot;
		*slot = old_val + alpha*(new_val - old_val);
	}
}

impl<S: FiniteSpace, A: FiniteSpace> QTable<S, A>
	where S::Element: Hash + Eq, A::Element: Hash + Eq {
	/// Creates an empty QTable.
	///
	/// No entries are stored initially, so every (state, action) pair
	/// evaluates to 0 until it is updated.
	pub fn new() -> QTable<S, A> {
		let empty = HashMap::new();
		QTable { map: empty }
	}
}

/// VTable
///
/// Represents a VFunction implemented using a table.
/// Values are kept in a `HashMap` keyed by state. The table is sparse:
/// a state only gets an entry once it has been updated, and `eval`
/// reports 0.0 for any state that has no entry.
///
/// The `Hash + Eq` bound on the element type is what allows
/// states to be used as `HashMap` keys.
#[derive(Debug, Clone)]
pub struct VTable<S: FiniteSpace> where S::Element: Hash + Eq {
	// Backing store: state -> current value estimate
	map: HashMap<S::Element, f64>
}

impl<S: FiniteSpace> VFunction<S> for VTable<S> where S::Element: Hash + Eq {
	/// Returns the value stored for `state`,
	/// or 0.0 if the state has no entry in the table.
	///
	/// Reading never inserts into the table.
	fn eval(&self, state: &S::Element) -> f64 {
		self.map.get(state).cloned().unwrap_or(0.0)
	}
	/// Moves the value stored for `state` towards `new_val`.
	///
	/// The stored value becomes `old + alpha*(new_val - old)`, where `old` is the
	/// current value (0.0 if the state has no entry yet). An entry is created
	/// for the state if one does not already exist.
	fn update(&mut self, state: &S::Element, new_val: f64, alpha: f64) {
		// The entry API does the "read or default, then write" in a single lookup
		let slot = self.map.entry(state.clone()).or_insert(0.0);
		let old_val = *slot;
		*slot = old_val + alpha*(new_val - old_val);
	}
}