1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
use rand::{Rng, thread_rng};
use util::Chooser;
/// Chooser that picks every candidate with equal probability.
///
/// The type carries no state; any weights handed to it are ignored.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Uniform;
impl<T: Clone> Chooser<T> for Uniform {
    /// Returns a clone of one element of `choices`, selected by the
    /// thread-local random number generator.
    ///
    /// The weight vector is accepted only to satisfy the trait signature and
    /// is never inspected.
    ///
    /// # Panics
    ///
    /// Panics when the generator yields no element, which is what happens for
    /// an empty `choices`.
    fn choose(&self, choices: &Vec<T>, _weights: Vec<f64>) -> T {
        let picked: Option<&T> = thread_rng().choose(choices);
        picked.cloned().unwrap()
    }
}
/// Chooser that turns weights into selection probabilities through an
/// exponential (softmax) transform.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Softmax {
    /// Temperature: every weight is divided by this value before being
    /// exponentiated.
    temp: f64,
}
impl Default for Softmax {
fn default() -> Softmax {
Softmax{temp: 1.0}
}
}
impl<T: Clone> Chooser<T> for Softmax {
    /// Samples one element of `choices` with probability proportional to
    /// `exp(weight / temp)`.
    ///
    /// `choices` and `weights` are matched by position. A weight without a
    /// matching choice can never be selected; a draw that lands on one
    /// resolves to the last selectable choice instead.
    ///
    /// The largest scaled weight is subtracted before exponentiating. This
    /// leaves the resulting probabilities unchanged but keeps `exp` from
    /// overflowing to infinity when weights are large.
    ///
    /// When the weights give no usable distribution (no weights at all, or a
    /// total mass that is zero, NaN or infinite) the choice is made uniformly
    /// at random over `choices`.
    ///
    /// # Panics
    ///
    /// Panics if `choices` is empty.
    fn choose(&self, choices: &Vec<T>, weights: Vec<f64>) -> T {
        let mut rng = thread_rng();

        let scaled: Vec<f64> = weights.iter().map(|w| w / self.temp).collect();
        // `f64::max` ignores NaN operands, so a NaN weight does not become the
        // peak; it surfaces as a NaN total below and triggers the fallback.
        let peak = scaled.iter().cloned().fold(::std::f64::NEG_INFINITY, f64::max);
        let masses: Vec<f64> = scaled.iter().map(|s| (s - peak).exp()).collect();
        let total: f64 = masses.iter().sum();

        if !(total.is_finite() && total > 0.0) {
            return rng
                .choose(choices)
                .expect("cannot choose from an empty list of choices")
                .clone();
        }

        let mut remaining = rng.gen_range(0.0, total);
        // Last candidate with non-zero mass, used if floating-point rounding
        // carries the walk past the final bucket.
        let mut last_viable = None;
        for (item, &mass) in choices.iter().zip(masses.iter()) {
            if mass > 0.0 {
                // Strict `<` so that a zero-mass entry is never selected.
                if remaining < mass {
                    return item.clone();
                }
                last_viable = Some(item);
            }
            remaining -= mass;
        }

        match last_viable {
            Some(item) => item.clone(),
            None => rng
                .choose(choices)
                .expect("cannot choose from an empty list of choices")
                .clone(),
        }
    }
}
impl Softmax {
pub fn new(temp: f64) -> Softmax {
Softmax {
temp: temp
}
}
}
/// Chooser that picks a candidate with probability proportional to the
/// weight supplied for it.
///
/// The type carries no state of its own.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Weighted;
impl<T: Clone> Chooser<T> for Weighted {
    /// Samples one element of `choices` with probability proportional to its
    /// entry in `weights`.
    ///
    /// `choices` and `weights` are matched by position. Entries whose weight
    /// is zero or negative are never selected. A weight without a matching
    /// choice can never be selected either; a draw that lands on one resolves
    /// to the last selectable choice instead.
    ///
    /// When the weights give no usable distribution (no weights at all, or a
    /// sum that is zero, negative, NaN or infinite) the choice is made
    /// uniformly at random over `choices`.
    ///
    /// # Panics
    ///
    /// Panics if `choices` is empty.
    fn choose(&self, choices: &Vec<T>, weights: Vec<f64>) -> T {
        let mut rng = thread_rng();
        let total: f64 = weights.iter().sum();

        if !(total.is_finite() && total > 0.0) {
            return rng
                .choose(choices)
                .expect("cannot choose from an empty list of choices")
                .clone();
        }

        let mut remaining = rng.gen_range(0.0, total);
        // Last candidate with positive weight, used if floating-point rounding
        // carries the walk past the final bucket.
        let mut last_viable = None;
        for (item, &weight) in choices.iter().zip(weights.iter()) {
            if weight > 0.0 {
                // Strict `<` so that a zero-weight entry is never selected.
                if remaining < weight {
                    return item.clone();
                }
                last_viable = Some(item);
            }
            remaining -= weight;
        }

        match last_viable {
            Some(item) => item.clone(),
            None => rng
                .choose(choices)
                .expect("cannot choose from an empty list of choices")
                .clone(),
        }
    }
}