1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
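//! Implementations of the `Chooser` trait: strategies for picking one element from a list of choices, optionally guided by per-choice weights.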

use rand::{Rng, thread_rng};

use util::Chooser;

/// A `Chooser` that ignores any weights and gives every choice the same
/// chance of being picked.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Uniform;

impl<T: Clone> Chooser<T> for Uniform {
	/// Returns a clone of an element of `choices` picked uniformly at random.
	///
	/// The weights argument is ignored.
	///
	/// # Panics
	///
	/// Panics if `choices` is empty.
	fn choose(&self, choices: &Vec<T>, _: Vec<f64>) -> T {
		let mut rng = thread_rng();
		// `choices` is already a reference; it coerces to a slice directly.
		rng.choose(choices)
			.expect("Uniform::choose requires at least one choice")
			.clone()
	}
}

/// Represents a Chooser that picks each element with probability according to a softmax distribution.
///
/// An element with weight `w` is chosen with probability proportional to
/// `exp(w / temp)`. For a positive temperature, a larger `temp` makes the
/// distribution flatter and a smaller one makes it sharper.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Softmax {
	/// Temperature that every weight is divided by before exponentiation.
	temp: f64
}

impl Default for Softmax {
	/// Creates a Softmax with temperature 1.0
	fn default() -> Softmax {
		Softmax { temp: 1.0 }
	}
}

impl<T: Clone> Chooser<T> for Softmax {
	/// Picks an element of `choices` with probability proportional to
	/// `exp(weight / temp)` and returns a clone of it.
	///
	/// `weights[i]` is paired with `choices[i]`. Weights without a matching
	/// choice are ignored, and choices without a weight are never picked. A
	/// scaled weight of `-inf` or NaN gives its element zero probability.
	///
	/// If the largest scaled weight is not finite, the element is picked
	/// uniformly from all of `choices` instead. This covers `weights` being
	/// empty, every scaled weight being `-inf`/NaN, or one being `+inf`.
	///
	/// # Panics
	///
	/// Panics if `choices` is empty.
	fn choose(&self, choices: &Vec<T>, weights: Vec<f64>) -> T {
		let mut rng = thread_rng();

		// Only weights that have a matching choice can ever be selected.
		let scaled: Vec<f64> = weights
			.iter()
			.take(choices.len())
			.map(|w| w / self.temp)
			.collect();

		// Shift by the largest scaled weight before exponentiating so `exp` cannot
		// overflow to infinity. The shift cancels out when normalising.
		// `f64::max` ignores NaN, so a NaN weight does not poison the maximum.
		let max = scaled
			.iter()
			.cloned()
			.fold(::std::f64::NEG_INFINITY, f64::max);
		if !max.is_finite() {
			return rng.choose(choices)
				.expect("Softmax::choose requires at least one choice")
				.clone();
		}

		// The maximal element maps to exp(0) = 1, so `total` is at least 1 here.
		let exps: Vec<f64> = scaled.iter().map(|s| (s - max).exp()).collect();
		let total: f64 = exps.iter().filter(|&&e| e > 0.0).sum();

		// Walk the cumulative distribution. The `<` test means zero-probability
		// elements are never selected, even when the sample is exactly 0.0.
		let mut remaining = rng.gen_range(0.0, total);
		let mut picked = None;
		for (choice, &weight) in choices.iter().zip(&exps).filter(|&(_, &e)| e > 0.0) {
			picked = Some(choice);
			if remaining < weight {
				break;
			}
			remaining -= weight;
		}
		// Rounding can leave `remaining` marginally above the final weight. In that
		// case the last positive-weight element is kept instead of indexing past
		// the end.
		picked
			.expect("a finite maximum guarantees at least one positive weight")
			.clone()
	}
}

impl Softmax {
	/// Creates a new Softmax with the given temp.
	///
	/// Every weight is divided by `temp` before being exponentiated, so the
	/// temperature controls how strongly larger weights are favoured.
	pub fn new(temp: f64) -> Softmax {
		Softmax { temp }
	}
}

/// A `Chooser` whose chance of picking each element is proportional to that
/// element's weight.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Weighted; // TODO: Come up with a better name

impl<T: Clone> Chooser<T> for Weighted {
	/// Picks an element of `choices` with probability proportional to its weight
	/// and returns a clone of it.
	///
	/// `weights[i]` is paired with `choices[i]`. Weights without a matching
	/// choice are ignored, and choices without a weight are never picked.
	/// Zero, negative and NaN weights give their element no chance of being
	/// picked.
	///
	/// If the usable weights sum to zero or to infinity, the element is picked
	/// uniformly from all of `choices` instead.
	///
	/// # Panics
	///
	/// Panics if `choices` is empty.
	fn choose(&self, choices: &Vec<T>, weights: Vec<f64>) -> T {
		let mut rng = thread_rng();

		// Only positive weights that have a matching choice contribute.
		let total: f64 = weights
			.iter()
			.take(choices.len())
			.filter(|&&w| w > 0.0)
			.sum();

		// `gen_range` needs a non-empty, finite interval to sample from.
		if total <= 0.0 || !total.is_finite() {
			return rng.choose(choices)
				.expect("Weighted::choose requires at least one choice")
				.clone();
		}

		// Walk the cumulative distribution. The `<` test means zero-weight
		// elements are never selected, even when the sample is exactly 0.0.
		let mut remaining = rng.gen_range(0.0, total);
		let mut picked = None;
		for (choice, &weight) in choices.iter().zip(&weights).filter(|&(_, &w)| w > 0.0) {
			picked = Some(choice);
			if remaining < weight {
				break;
			}
			remaining -= weight;
		}
		// Rounding can leave `remaining` marginally above the final weight. In that
		// case the last positive-weight element is kept instead of indexing past
		// the end.
		picked
			.expect("a positive total guarantees at least one positive weight")
			.clone()
	}
}