mlpack 3.4.2
sumtree.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_RL_SUMTREE_HPP
14#define MLPACK_METHODS_RL_SUMTREE_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace rl {
20
31template<typename T>
33{
34 public:
38 SumTree() : capacity(0)
39 { /* Nothing to do here. */ }
40
46 SumTree(const size_t capacity) : capacity(capacity)
47 {
48 element = std::vector<T>(2 * capacity);
49 }
50
57 void Set(size_t idx, const T value)
58 {
59 idx += capacity;
60 element[idx] = value;
61 idx /= 2;
62 while (idx >= 1)
63 {
64 element[idx] = element[2 * idx] + element[2 * idx + 1];
65 idx /= 2;
66 }
67 }
68
75 void BatchUpdate(const arma::ucolvec& indices, const arma::Col<T>& data)
76 {
77 for (size_t i = 0; i < indices.n_rows; ++i)
78 {
79 element[indices[i] + capacity] = data[i];
80 }
81 // update the total tree with bottom-up technique.
82 for (size_t i = capacity - 1; i > 0; i--)
83 {
84 element[i] = element[2 * i] + element[2 * i + 1];
85 }
86 }
87
93 T Get(size_t idx)
94 {
95 idx += capacity;
96 return element[idx];
97 }
98
108 T SumHelper(const size_t start,
109 const size_t end,
110 const size_t node,
111 const size_t nodeStart,
112 const size_t nodeEnd)
113 {
114 if (start == nodeStart && end == nodeEnd)
115 {
116 return element[node];
117 }
118 size_t mid = (nodeStart + nodeEnd) / 2;
119 if (end <= mid)
120 {
121 return SumHelper(start, end, 2 * node, nodeStart, mid);
122 }
123 else
124 {
125 if (mid + 1 <= start)
126 {
127 return SumHelper(start, end, 2 * node + 1, mid + 1 , nodeEnd);
128 }
129 else
130 {
131 return SumHelper(start, mid, 2 * node, nodeStart, mid) +
132 SumHelper(mid + 1, end, 2 * node + 1, mid + 1 , nodeEnd);
133 }
134 }
135 }
136
143 T Sum(const size_t start, size_t end)
144 {
145 end -= 1;
146 return SumHelper(start, end, 1, 0, capacity - 1);
147 }
148
152 T Sum()
153 {
154 return Sum(0, capacity);
155 }
156
163 size_t FindPrefixSum(T mass)
164 {
165 size_t idx = 1;
166 while (idx < capacity)
167 {
168 if (element[2 * idx] > mass)
169 {
170 idx = 2 * idx;
171 }
172 else
173 {
174 mass -= element[2 * idx];
175 idx = 2 * idx + 1;
176 }
177 }
178 return idx - capacity;
179 }
180
181 private:
183 size_t capacity;
184
186 std::vector<T> element;
187};
188
189} // namespace rl
190} // namespace mlpack
191
192#endif
Implementation of SumTree.
Definition: sumtree.hpp:33
T Get(size_t idx)
Get the element of the data array at index idx.
Definition: sumtree.hpp:93
SumTree(const size_t capacity)
Construct an instance of the SumTree class.
Definition: sumtree.hpp:46
T Sum()
Shortcut for calculating the sum of the whole array.
Definition: sumtree.hpp:152
void BatchUpdate(const arma::ucolvec &indices, const arma::Col< T > &data)
Update the data in a single batch rather than looping over the indices with the Set() method.
Definition: sumtree.hpp:75
size_t FindPrefixSum(T mass)
Find the highest index idx in the array such that arr[0] + arr[1] + ... + arr[idx - 1] <= mass.
Definition: sumtree.hpp:163
T SumHelper(const size_t start, const size_t end, const size_t node, const size_t nodeStart, const size_t nodeEnd)
Helper function for the Sum() function.
Definition: sumtree.hpp:108
void Set(size_t idx, const T value)
Set the element of the data array at index idx.
Definition: sumtree.hpp:57
T Sum(const size_t start, size_t end)
Calculate the sum of a contiguous subsequence of the array.
Definition: sumtree.hpp:143
SumTree()
Default constructor.
Definition: sumtree.hpp:38
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.