1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
//! Utilities Module

pub mod table;
pub mod chooser;
pub mod approx;
pub mod feature;
pub mod graddesc;

mod metric;

use std::fmt::Debug;

use num::Num;
use num::Float;

use environment::Space;

/// A function whose evaluation depends on a set of adjustable parameters
pub trait ParameterizedFunc<T>
where
	T: Num,
{
	/// Reports how many parameters the function uses
	fn num_params(&self) -> usize;
	/// Hands back the parameters the function is using
	fn get_params(&self) -> Vec<T>;
	/// Replaces the parameters the function uses with `params`
	fn set_params(&mut self, params: Vec<T>);
}

/// A parameterized function over (state, action) pairs that can be differentiated
pub trait DifferentiableFunc<S, A, T>: ParameterizedFunc<T>
where
	S: Space,
	A: Space,
	T: Num,
{
	/// Gradient of the output with respect to this function's parameters,
	/// for the given state and action
	fn get_grad(&self, state: &S::Element, action: &A::Element) -> Vec<T>;
	/// Result of applying the function to the given state and action
	fn calculate(&self, state: &S::Element, action: &A::Element) -> T;
}

/// A parameterized function over (state, action) pairs whose logarithm can be differentiated
pub trait LogDiffFunc<S, A, T>: ParameterizedFunc<T>
where
	S: Space,
	A: Space,
	T: Num,
{
	/// Gradient of the log of the output with respect to the parameters,
	/// for the given state and action
	fn log_grad(&self, state: &S::Element, action: &A::Element) -> Vec<T>;
}

/// An algorithm that turns gradients into descent steps
pub trait GradientDescAlgo<F>
where
	F: Float,
{
	/// Computes the local step for minimizing a function from its gradient `grad`
	/// and `lr` (presumably the learning rate)
	fn calculate(&mut self, grad: Vec<F>, lr: F) -> Vec<F>;
}

/// Something able to extract features from state-action pairs
pub trait FeatureExtractor<S, A, F>
where
	S: Space,
	A: Space,
	F: Float,
{
	/// How many features can be calculated
	fn num_features(&self) -> usize;
	/// Vector containing the values of all the features for the given state and action
	fn extract(&self, state: &S::Element, action: &A::Element) -> Vec<F>;
}

/// QFunction Trait
///
/// A function Q: S x A -> R, mapping each (state, action) pair
/// to the value of that pair
pub trait QFunction<S, A>: Debug
where
	S: Space,
	A: Space,
{
	/// Evaluates the function at the given state and action
	fn eval(&self, state: &S::Element, action: &A::Element) -> f64;
	/// Updates the function with the given information, where `alpha` is the learning rate
	fn update(&mut self, state: &S::Element, action: &A::Element, new_val: f64, alpha: f64);
}

/// VFunction Trait
///
/// A function V: S -> R, mapping each state to its value
pub trait VFunction<S>: Debug
where
	S: Space,
{
	/// Evaluates the function at the given state
	fn eval(&self, state: &S::Element) -> f64;
	/// Updates the function with the given information, where `alpha` is the learning rate
	fn update(&mut self, state: &S::Element, new_val: f64, alpha: f64);
}

/// Chooser Trait
///
/// A way of randomly picking an element of a list according to some weights
pub trait Chooser<T>: Debug {
	/// Picks and returns an element of `choices`
	fn choose(&self, choices: &Vec<T>, weights: Vec<f64>) -> T;
}

/// A real-valued feature defined on the elements of some state space
pub trait Feature<S, F>: Debug
where
	S: Space,
	F: Float,
{
	/// Computes this feature's real value for the given state
	fn extract(&self, state: &S::Element) -> F;
	/// Produces a boxed trait object that is a clone of self
	fn box_clone(&self) -> Box<Feature<S, F>>;
}

/// Cloning a boxed feature delegates to the feature's own `box_clone`
impl<S, F> Clone for Box<Feature<S, F>>
where
	S: Space,
	F: Float,
{
	fn clone(&self) -> Self {
		// Dereference down to the trait object so `box_clone` is called on the feature itself
		(**self).box_clone()
	}
}

/// A type equipped with a notion of distance
///
/// Implementations should make the distance obey the triangle inequality
/// (along with the other [metric](https://www.wikiwand.com/en/Metric_(mathematics)) properties):
///
/// d(x,z) <= d(x,y) + d(y,z)
pub trait Metric {
	/// Squared distance between x and y
	fn dist2(x: &Self, y: &Self) -> f64;
	/// Distance between x and y; by default the square root of `dist2`
	fn dist(x: &Self, y: &Self) -> f64 {
		let squared = Self::dist2(x, y);
		squared.sqrt()
	}
}

/// A length of time as experienced by an agent
#[derive(Clone, Debug)]
pub enum TimePeriod {
	/// Time period measured in whole episodes
	EPISODES(usize),
	/// Time period measured in individual timesteps
	TIMESTEPS(usize),
	/// Time period that ends as soon as either the first or the second one ends
	OR(Box<TimePeriod>, Box<TimePeriod>),
}

impl TimePeriod {
	/// Tells whether self represents an empty time period
	///
	/// A count of episodes or timesteps is empty when it is zero; an `OR`
	/// is empty as soon as either of its two periods is empty.
	pub fn is_none(&self) -> bool {
		match *self {
			TimePeriod::EPISODES(remaining) | TimePeriod::TIMESTEPS(remaining) => remaining == 0,
			TimePeriod::OR(ref first, ref second) => first.is_none() || second.is_none(),
		}
	}
	/// Returns the time period remaining after one time step
	///
	/// Timesteps always go down by one, episodes only go down when `done` is
	/// true, and an `OR` decrements both of its periods. A period that is
	/// already empty is returned unchanged.
	pub fn dec(&self, done: bool) -> TimePeriod {
		// Nothing left to consume; this guard also keeps the subtractions below from underflowing
		if self.is_none() {
			return self.clone();
		}
		match *self {
			TimePeriod::EPISODES(episodes) => {
				let left = if done { episodes - 1 } else { episodes };
				TimePeriod::EPISODES(left)
			},
			TimePeriod::TIMESTEPS(steps) => TimePeriod::TIMESTEPS(steps - 1),
			TimePeriod::OR(ref first, ref second) => {
				let first_left = Box::new(first.dec(done));
				let second_left = Box::new(second.dec(done));
				TimePeriod::OR(first_left, second_left)
			},
		}
	}
}