prioritized_replay.hpp (mlpack 3.4.2)
#ifndef MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP
#define MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP

#include <mlpack/prereqs.hpp>
#include "sumtree.hpp"

namespace mlpack {
namespace rl {

/**
 * Implementation of prioritized experience replay. Transitions are stored
 * with priorities in a sum tree and sampled with probability proportional to
 * those priorities; importance-sampling weights correct for the resulting
 * bias.
 *
 * @tparam EnvironmentType Desired task.
 */
template <typename EnvironmentType>
class PrioritizedReplay
{
 public:
  //! Convenient typedef for action.
  using ActionType = typename EnvironmentType::Action;

  //! Convenient typedef for state.
  using StateType = typename EnvironmentType::State;

  //! A single stored transition.
  struct Transition
  {
    StateType state;
    ActionType action;
    double reward;
    StateType nextState;
    bool isEnd;
  };

  /**
   * Default constructor.
   */
  PrioritizedReplay() :
      batchSize(0),
      capacity(0),
      position(0),
      full(false),
      alpha(0),
      maxPriority(0),
      initialBeta(0),
      beta(0),
      replayBetaIters(0),
      nSteps(0)
  { /* Nothing to do here. */ }

  /**
   * Construct an instance of prioritized experience replay class.
   *
   * @param batchSize Number of examples returned at each sample.
   * @param capacity Total memory size in terms of number of examples.
   * @param alpha How much prioritization is used.
   * @param nSteps Number of steps to look in the future.
   * @param dimension The dimension of an encoded state.
   */
  PrioritizedReplay(const size_t batchSize,
                    const size_t capacity,
                    const double alpha,
                    const size_t nSteps = 1,
                    const size_t dimension = StateType::dimension) :
      batchSize(batchSize),
      capacity(capacity),
      position(0),
      full(false),
      alpha(alpha),
      maxPriority(1.0),
      initialBeta(0.6),
      replayBetaIters(10000),
      nSteps(nSteps),
      states(dimension, capacity),
      actions(capacity),
      rewards(capacity),
      nextStates(dimension, capacity),
      isTerminal(capacity)
  {
    size_t size = 1;
    while (size < capacity)
    {
      size *= 2;
    }
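    // 'size' is now the smallest power of two that is at least 'capacity';
    // the sum tree below is allocated with that many leaves.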

    beta = initialBeta;
    idxSum = SumTree<double>(size);
  }

  /**
   * Store the given experience and set the priorities for the given
   * experience.
   *
   * @param state Given state.
   * @param action Given action.
   * @param reward Given reward.
   * @param nextState Given next state.
   * @param isEnd Whether next state is terminal state.
   * @param discount The discount parameter.
   */
  void Store(StateType state,
             ActionType action,
             double reward,
             StateType nextState,
             bool isEnd,
             const double& discount)
  {
    nStepBuffer.push_back({state, action, reward, nextState, isEnd});

    // Single step transition is not ready.
    if (nStepBuffer.size() < nSteps)
      return;

    // To keep the queue size fixed to nSteps.
    if (nStepBuffer.size() > nSteps)
      nStepBuffer.pop_front();

    // Before moving ahead, let's confirm that our fixed-size buffer works.
    assert(nStepBuffer.size() == nSteps);

    // Make an n-step transition.
    GetNStepInfo(reward, nextState, isEnd, discount);

    state = nStepBuffer.front().state;
    action = nStepBuffer.front().action;
    states.col(position) = state.Encode();
    actions[position] = action;
    rewards(position) = reward;
    nextStates.col(position) = nextState.Encode();
    isTerminal(position) = isEnd;

    // New transitions are stored with the maximum priority seen so far.
    idxSum.Set(position, maxPriority * alpha);

    position++;
    if (position == capacity)
    {
      full = true;
      position = 0;
    }
  }

  /**
   * Get the reward, next state and terminal boolean for the nth step.
   *
   * @param reward Given reward.
   * @param nextState Given next state.
   * @param isEnd Whether next state is terminal state.
   * @param discount The discount parameter.
   */
  void GetNStepInfo(double& reward,
                    StateType& nextState,
                    bool& isEnd,
                    const double& discount)
  {
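    // For example, with nSteps = 3 and no terminal transition in the buffer,
    // the loop below accumulates
    //   reward = r_0 + discount * r_1 + discount^2 * r_2,
    // and nextState / isEnd are taken from the last stored transition. If an
    // earlier transition is terminal, the return is truncated there and
    // nextState / isEnd come from that transition instead.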
    reward = nStepBuffer.back().reward;
    nextState = nStepBuffer.back().nextState;
    isEnd = nStepBuffer.back().isEnd;

    // Should start from the second last transition in buffer.
    for (int i = nStepBuffer.size() - 2; i >= 0; i--)
    {
      bool iE = nStepBuffer[i].isEnd;
      reward = nStepBuffer[i].reward + discount * reward * (1 - iE);
      if (iE)
      {
        nextState = nStepBuffer[i].nextState;
        isEnd = iE;
      }
    }
  }

  /**
   * Sample some experience according to their priorities.
   *
   * @return Indices of the sampled transitions.
   */
  arma::ucolvec SampleProportional()
  {
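    // Stratified proportional sampling: split the total priority mass into
    // batchSize equal ranges and draw one point uniformly inside each range.
    // FindPrefixSum() maps each point to the transition whose cumulative
    // priority interval contains it, so transitions with larger priorities
    // are chosen more often.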
    arma::ucolvec idxes(batchSize);
    double totalSum = idxSum.Sum(0, (full ? capacity : position));
    double sumPerRange = totalSum / batchSize;
    for (size_t bt = 0; bt < batchSize; bt++)
    {
      const double mass = arma::randu() * sumPerRange + bt * sumPerRange;
      idxes(bt) = idxSum.FindPrefixSum(mass);
    }
    return idxes;
  }

  /**
   * Sample some experience according to their priorities.
   *
   * @param sampledStates Sampled encoded states.
   * @param sampledActions Sampled actions.
   * @param sampledRewards Sampled rewards.
   * @param sampledNextStates Sampled encoded next states.
   * @param isTerminal Indicators of whether the corresponding next state is
   *        terminal.
   */
  void Sample(arma::mat& sampledStates,
              std::vector<ActionType>& sampledActions,
              arma::rowvec& sampledRewards,
              arma::mat& sampledNextStates,
              arma::irowvec& isTerminal)
  {
    sampledIndices = SampleProportional();
    BetaAnneal();

    sampledStates = states.cols(sampledIndices);
    for (size_t t = 0; t < sampledIndices.n_rows; t++)
      sampledActions.push_back(actions[sampledIndices[t]]);
    sampledRewards = rewards.elem(sampledIndices).t();
    sampledNextStates = nextStates.cols(sampledIndices);
    isTerminal = this->isTerminal.elem(sampledIndices).t();

    // Calculate the weights of sampled transitions.
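    // Each weight is w_i = (N * P(i))^(-beta), where N is the number of
    // stored transitions and P(i) is the transition's share of the total
    // priority mass; dividing by the largest weight normalizes the weights
    // to (0, 1].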
    size_t numSample = full ? capacity : position;
    weights = arma::rowvec(sampledIndices.n_rows);

    for (size_t i = 0; i < sampledIndices.n_rows; ++i)
    {
      double p_sample = idxSum.Get(sampledIndices(i)) / idxSum.Sum();
      weights(i) = pow(numSample * p_sample, -beta);
    }
    weights /= weights.max();
  }

  /**
   * Update priorities of sampled transitions.
   *
   * @param indices Indices of the transitions to update.
   * @param priorities The new priorities.
   */
  void UpdatePriorities(arma::ucolvec& indices, arma::colvec& priorities)
  {
    arma::colvec alphaPri = alpha * priorities;
    maxPriority = std::max(maxPriority, arma::max(priorities));
    idxSum.BatchUpdate(indices, alphaPri);
  }

  /**
   * Get the number of transitions in the memory.
   *
   * @return Actual used memory size.
   */
  const size_t& Size()
  {
    return full ? capacity : position;
  }

  /**
   * Annealing the beta: each call moves beta one step of
   * (1 - initialBeta) / replayBetaIters toward 1.
   */
  void BetaAnneal()
  {
    beta = beta + (1 - initialBeta) * 1.0 / replayBetaIters;
  }

  /**
   * Update the priorities of transitions and update the gradients.
   *
   * @param target The learned value.
   * @param sampledActions Sampled actions.
   * @param nextActionValues Action values used to compute the TD error.
   * @param gradients The model's gradients.
   */
  void Update(arma::mat target,
              std::vector<ActionType> sampledActions,
              arma::mat nextActionValues,
              arma::mat& gradients)
  {
    arma::colvec tdError(target.n_cols);
    for (size_t i = 0; i < target.n_cols; i++)
    {
      tdError(i) = nextActionValues(sampledActions[i].action, i) -
          target(sampledActions[i].action, i);
    }
    tdError = arma::abs(tdError);
    UpdatePriorities(sampledIndices, tdError);

    // Update the gradients: scale them by the mean importance weight.
    gradients = arma::mean(weights) * gradients;
  }

  //! Get the number of steps for the n-step agent.
  const size_t& NSteps() const { return nSteps; }

 private:
  //! Number of examples returned by each call to Sample().
  size_t batchSize;

  //! Maximum number of transitions that can be stored.
  size_t capacity;

  //! Position at which the next transition will be stored.
  size_t position;

  //! Whether the memory has been filled to capacity.
  bool full;

  //! How much prioritization is used.
  double alpha;

  //! Maximum priority seen so far; assigned to newly stored transitions.
  double maxPriority;

  //! Initial value of the importance-sampling exponent beta.
  double initialBeta;

  //! Current value of beta; annealed toward 1 by BetaAnneal().
  double beta;

  //! Number of iterations over which beta is annealed.
  size_t replayBetaIters;

  //! Sum tree storing the priority of each transition.
  SumTree<double> idxSum;

  //! Indices of the most recently sampled transitions.
  arma::ucolvec sampledIndices;

  //! Importance-sampling weights of the most recently sampled transitions.
  arma::rowvec weights;

  //! Number of steps to look into the future (n-step returns).
  size_t nSteps;

  //! Buffer holding the last nSteps transitions.
  std::deque<Transition> nStepBuffer;

  //! Encoded states of stored transitions.
  arma::mat states;

  //! Actions of stored transitions.
  std::vector<ActionType> actions;

  //! Rewards of stored transitions.
  arma::rowvec rewards;

  //! Encoded next states of stored transitions.
  arma::mat nextStates;

  //! Termination flags of stored transitions.
  arma::irowvec isTerminal;
};

} // namespace rl
} // namespace mlpack

#endif
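
For orientation, a minimal usage sketch follows. The ToyEnv type is hypothetical and exists only to satisfy the interface this class template assumes: a nested State type with a static dimension and an Encode() method, and a nested Action type with an action member. The environments shipped with mlpack expose the same interface and can be used instead; the include path below is assumed from the mlpack source layout.

#include <mlpack/prereqs.hpp>
// Path assumed from the mlpack source layout; adjust if it differs in your
// installation.
#include <mlpack/methods/reinforcement_learning/replay/prioritized_replay.hpp>

// Hypothetical environment used only for illustration.
struct ToyEnv
{
  struct State
  {
    static constexpr size_t dimension = 4;
    arma::colvec Encode() const { return arma::zeros<arma::colvec>(dimension); }
  };

  struct Action
  {
    enum actions { left, right };
    actions action = left;
    static const size_t size = 2;
  };
};

int main()
{
  // Batches of 32 samples, capacity of 10000 transitions, alpha = 0.6.
  mlpack::rl::PrioritizedReplay<ToyEnv> replay(32, 10000, 0.6);

  ToyEnv::State state, nextState;
  ToyEnv::Action action;

  // Store a single transition; new transitions are assigned the current
  // maximum priority (scaled by alpha), so they are likely to be sampled soon.
  replay.Store(state, action, /* reward */ 1.0, nextState,
               /* isEnd */ false, /* discount */ 0.99);

  return 0;
}

After enough transitions have been stored, Sample() returns a prioritized batch together with importance weights, and Update() feeds the absolute TD errors back through UpdatePriorities().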