mlpack 3.4.2
random_replay.hpp
Go to the documentation of this file.
1
12#ifndef MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
13#define MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
14
15#include <mlpack/prereqs.hpp>
16#include <cassert>
17
18namespace mlpack {
19namespace rl {
20
43template <typename EnvironmentType>
45{
46 public:
48 using ActionType = typename EnvironmentType::Action;
49
51 using StateType = typename EnvironmentType::State;
52
54 {
57 double reward;
59 bool isEnd;
60 };
61
63 batchSize(0),
64 capacity(0),
65 position(0),
66 full(false),
67 nSteps(0)
68 { /* Nothing to do here. */ }
69
78 RandomReplay(const size_t batchSize,
79 const size_t capacity,
80 const size_t nSteps = 1,
81 const size_t dimension = StateType::dimension) :
82 batchSize(batchSize),
83 capacity(capacity),
84 position(0),
85 full(false),
86 nSteps(nSteps),
87 states(dimension, capacity),
88 actions(capacity),
89 rewards(capacity),
90 nextStates(dimension, capacity),
91 isTerminal(capacity)
92 { /* Nothing to do here. */ }
93
104 void Store(StateType state,
105 ActionType action,
106 double reward,
107 StateType nextState,
108 bool isEnd,
109 const double& discount)
110 {
111 nStepBuffer.push_back({state, action, reward, nextState, isEnd});
112
113 // Single step transition is not ready.
114 if (nStepBuffer.size() < nSteps)
115 return;
116
117 // To keep the queue size fixed to nSteps.
118 if (nStepBuffer.size() > nSteps)
119 nStepBuffer.pop_front();
120
121 // Before moving ahead, lets confirm if our fixed size buffer works.
122 assert(nStepBuffer.size() == nSteps);
123
124 // Make a n-step transition.
125 GetNStepInfo(reward, nextState, isEnd, discount);
126
127 state = nStepBuffer.front().state;
128 action = nStepBuffer.front().action;
129
130 states.col(position) = state.Encode();
131 actions[position] = action;
132 rewards(position) = reward;
133 nextStates.col(position) = nextState.Encode();
134 isTerminal(position) = isEnd;
135 position++;
136 if (position == capacity)
137 {
138 full = true;
139 position = 0;
140 }
141 }
142
151 void GetNStepInfo(double& reward,
152 StateType& nextState,
153 bool& isEnd,
154 const double& discount)
155 {
156 reward = nStepBuffer.back().reward;
157 nextState = nStepBuffer.back().nextState;
158 isEnd = nStepBuffer.back().isEnd;
159
160 // Should start from the second last transition in buffer.
161 for (int i = nStepBuffer.size() - 2; i >= 0; i--)
162 {
163 bool iE = nStepBuffer[i].isEnd;
164 reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
165 if (iE)
166 {
167 nextState = nStepBuffer[i].nextState;
168 isEnd = iE;
169 }
170 }
171 }
172
183 void Sample(arma::mat& sampledStates,
184 std::vector<ActionType>& sampledActions,
185 arma::rowvec& sampledRewards,
186 arma::mat& sampledNextStates,
187 arma::irowvec& isTerminal)
188 {
189 size_t upperBound = full ? capacity : position;
190 arma::uvec sampledIndices = arma::randi<arma::uvec>(
191 batchSize, arma::distr_param(0, upperBound - 1));
192
193 sampledStates = states.cols(sampledIndices);
194 for (size_t t = 0; t < sampledIndices.n_rows; t ++)
195 sampledActions.push_back(actions[sampledIndices[t]]);
196 sampledRewards = rewards.elem(sampledIndices).t();
197 sampledNextStates = nextStates.cols(sampledIndices);
198 isTerminal = this->isTerminal.elem(sampledIndices).t();
199 }
200
206 const size_t& Size()
207 {
208 return full ? capacity : position;
209 }
210
219 void Update(arma::mat /* target */,
220 std::vector<ActionType> /* sampledActions */,
221 arma::mat /* nextActionValues */,
222 arma::mat& /* gradients */)
223 {
224 /* Do nothing for random replay. */
225 }
226
228 const size_t& NSteps() const { return nSteps; }
229
230 private:
232 size_t batchSize;
233
235 size_t capacity;
236
238 size_t position;
239
241 bool full;
242
244 size_t nSteps;
245
247 std::deque<Transition> nStepBuffer;
248
250 arma::mat states;
251
253 std::vector<ActionType> actions;
254
256 arma::rowvec rewards;
257
259 arma::mat nextStates;
260
262 arma::irowvec isTerminal;
263};
264
265} // namespace rl
266} // namespace mlpack
267
268#endif
Implementation of random experience replay.
const size_t & Size()
Get the number of transitions in the memory.
void Sample(arma::mat &sampledStates, std::vector< ActionType > &sampledActions, arma::rowvec &sampledRewards, arma::mat &sampledNextStates, arma::irowvec &isTerminal)
Sample some experiences.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
void Store(StateType state, ActionType action, double reward, StateType nextState, bool isEnd, const double &discount)
Store the given experience.
void Update(arma::mat, std::vector< ActionType >, arma::mat, arma::mat &)
Update the priorities of transitions and update the gradients.
void GetNStepInfo(double &reward, StateType &nextState, bool &isEnd, const double &discount)
Get the reward, next state and terminal boolean for nth step.
const size_t & NSteps() const
Get the number of steps for n-step agent.
RandomReplay(const size_t batchSize, const size_t capacity, const size_t nSteps=1, const size_t dimension=StateType::dimension)
Construct an instance of random experience replay class.
typename EnvironmentType::State StateType
Convenient typedef for state.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.