mlpack 3.4.2
double_pole_cart.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
14#define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
15
16#include <mlpack/prereqs.hpp>
17
18namespace mlpack {
19namespace rl {
20
28{
29 public:
35 class State
36 {
37 public:
41 State() : data(dimension)
42 { /* Nothing to do here. */ }
43
49 State(const arma::colvec& data) : data(data)
50 { /* Nothing to do here */ }
51
53 arma::colvec Data() const { return data; }
55 arma::colvec& Data() { return data; }
56
58 double Position() const { return data[0]; }
60 double& Position() { return data[0]; }
61
63 double Velocity() const { return data[1]; }
65 double& Velocity() { return data[1]; }
66
68 double Angle(const size_t i) const { return data[2 * i]; }
70 double& Angle(const size_t i) { return data[2 * i]; }
71
73 double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
75 double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
76
78 const arma::colvec& Encode() const { return data; }
79
81 static constexpr size_t dimension = 6;
82
83 private:
85 arma::colvec data;
86 };
87
// Implementation of action of Double Pole Cart.
//
// NOTE(review): this listing jumps from source line 93 to 95, from 95 to 98
// and from 99 to 101, so the declaration that opens the braces below
// (presumably an enum), its enumerators, and the data member that stores the
// chosen action are not visible here. Dsdt() reads `action.action` and applies
// +forceMag when it is non-zero and -forceMag otherwise -- confirm the missing
// declarations against the original header before relying on this block.
91 class Action
92 {
93 public:
95 {
98 };
99 // To store the action.
101
102 // Track the size of the action space.
103 static const size_t size = 2;
104 };
105
123 DoublePoleCart(const size_t maxSteps = 0,
124 const double m1 = 0.1,
125 const double m2 = 0.01,
126 const double l1 = 0.5,
127 const double l2 = 0.05,
128 const double gravity = 9.8,
129 const double massCart = 1.0,
130 const double forceMag = 10.0,
131 const double tau = 0.02,
132 const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
133 const double xThreshold = 2.4,
134 const double doneReward = 0.0) :
135 maxSteps(maxSteps),
136 m1(m1),
137 m2(m2),
138 l1(l1),
139 l2(l2),
140 gravity(gravity),
141 massCart(massCart),
142 forceMag(forceMag),
143 tau(tau),
144 thetaThresholdRadians(thetaThresholdRadians),
145 xThreshold(xThreshold),
146 doneReward(doneReward),
147 stepsPerformed(0)
148 { /* Nothing to do here */ }
149
159 double Sample(const State& state,
160 const Action& action,
161 State& nextState)
162 {
163 // Update the number of steps performed.
164 stepsPerformed++;
165
166 arma::vec dydx(6, arma::fill::zeros);
167 dydx[0] = state.Velocity();
168 dydx[2] = state.AngularVelocity(1);
169 dydx[4] = state.AngularVelocity(2);
170 Dsdt(state, action, dydx);
171 RK4(state, action, dydx, nextState);
172
173 // Check if the episode has terminated.
174 bool done = IsTerminal(nextState);
175
176 // Do not reward agent if it failed.
177 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
178 return doneReward;
179 else if (done)
180 return 0;
181
186 return 1.0;
187 }
188
197 void Dsdt(const State& state,
198 const Action& action,
199 arma::vec& dydx)
200 {
201 double totalForce = action.action ? forceMag : -forceMag;
202 double totalMass = massCart;
203 double omega1 = state.AngularVelocity(1);
204 double omega2 = state.AngularVelocity(2);
205 double sinTheta1 = std::sin(state.Angle(1));
206 double sinTheta2 = std::sin(state.Angle(2));
207 double cosTheta1 = std::cos(state.Angle(1));
208 double cosTheta2 = std::cos(state.Angle(2));
209
210 // Calculate total effective force.
211 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
212 std::sin(2 * state.Angle(1));
213 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
214 std::sin(2 * state.Angle(2));
215
216 // Calculate total effective mass.
217 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
218 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
219
220 // Calculate acceleration.
221 double xAcc = totalForce / totalMass;
222 dydx[1] = xAcc;
223
224 // Calculate angular acceleration.
225 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
226 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
227 }
228
238 void RK4(const State& state,
239 const Action& action,
240 arma::vec& dydx,
241 State& nextState)
242 {
243 const double hh = tau * 0.5;
244 const double h6 = tau / 6;
245 arma::vec yt(6);
246 arma::vec dyt(6);
247 arma::vec dym(6);
248
249 yt = state.Data() + (hh * dydx);
250 Dsdt(State(yt), action, dyt);
251 dyt[0] = yt[1];
252 dyt[2] = yt[3];
253 dyt[4] = yt[5];
254 yt = state.Data() + (hh * dyt);
255
256 Dsdt(State(yt), action, dym);
257 dym[0] = yt[1];
258 dym[2] = yt[3];
259 dym[4] = yt[5];
260 yt = state.Data() + (tau * dym);
261 dym += dyt;
262
263 Dsdt(State(yt), action, dyt);
264 dyt[0] = yt[1];
265 dyt[2] = yt[3];
266 dyt[4] = yt[5];
267 nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
268 }
269
278 double Sample(const State& state, const Action& action)
279 {
280 State nextState;
281 return Sample(state, action, nextState);
282 }
283
290 {
291 stepsPerformed = 0;
292 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
293 }
294
301 bool IsTerminal(const State& state) const
302 {
303 if (maxSteps != 0 && stepsPerformed >= maxSteps)
304 {
305 Log::Info << "Episode terminated due to the maximum number of steps"
306 "being taken.";
307 return true;
308 }
309 if (std::abs(state.Position()) > xThreshold)
310 {
311 Log::Info << "Episode terminated due to cart crossing threshold";
312 return true;
313 }
314 if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
315 std::abs(state.Angle(2)) > thetaThresholdRadians)
316 {
317 Log::Info << "Episode terminated due to pole falling";
318 return true;
319 }
320 return false;
321 }
322
324 size_t StepsPerformed() const { return stepsPerformed; }
325
327 size_t MaxSteps() const { return maxSteps; }
329 size_t& MaxSteps() { return maxSteps; }
330
331 private:
//! Maximum number of steps per episode; IsTerminal() ignores it when it is 0.
333 size_t maxSteps;
334
//! Mass of the first pole, as used by Dsdt().
336 double m1;
337
//! Mass of the second pole, as used by Dsdt().
339 double m2;
340
//! Length parameter of the first pole; divides its angular acceleration.
342 double l1;
343
//! Length parameter of the second pole; divides its angular acceleration.
345 double l2;
346
//! Gravitational acceleration used by Dsdt().
348 double gravity;
349
//! Mass of the cart; the starting value of the total mass in Dsdt().
351 double massCart;
352
//! Magnitude of the force applied to the cart; the action selects its sign.
354 double forceMag;
355
//! Time step integrated by RK4().
357 double tau;
358
//! Absolute pole angle beyond which IsTerminal() returns true.
360 double thetaThresholdRadians;
361
//! Absolute cart position beyond which IsTerminal() returns true.
363 double xThreshold;
364
//! Reward returned by Sample() when a terminal step coincides with the step
//! limit having been reached.
366 double doneReward;
367
//! Number of Sample() steps taken so far; set to 0 at construction.
369 size_t stepsPerformed;
370};
371
372} // namespace rl
373} // namespace mlpack
374
375#endif
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
Implementation of action of Double Pole Cart.
Implementation of the state of Double Pole Cart.
const arma::colvec & Encode() const
Encode the state to a vector.
arma::colvec Data() const
Get the internal representation of the state.
double & Velocity()
Modify the velocity of the cart.
double Velocity() const
Get the velocity of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
State()
Construct a state instance.
State(const arma::colvec &data)
Construct a state instance from given data.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double & Position()
Modify the position of the cart.
double Position() const
Get the position of the cart.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
static constexpr size_t dimension
Dimension of the encoded state.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Double Pole Cart Balancing task.
DoublePoleCart(const size_t maxSteps=0, const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Double Pole Cart instance using the given constants.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.
size_t & MaxSteps()
Set the maximum number of steps allowed.
size_t StepsPerformed() const
Get the number of steps performed.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
size_t MaxSteps() const
Get the maximum number of steps allowed.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state with the RK4 method.
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.