mlpack 3.4.2
continuous_double_pole_cart.hpp
Go to the documentation of this file.
1
14#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
15#define MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
16
17#include <mlpack/prereqs.hpp>
18
19namespace mlpack {
20namespace rl {
21
29{
30 public:
36 class State
37 {
38 public:
42 State() : data(dimension)
43 { /* Nothing to do here. */ }
44
50 State(const arma::colvec& data) : data(data)
51 { /* Nothing to do here */ }
52
54 arma::colvec Data() const { return data; }
56 arma::colvec& Data() { return data; }
57
59 double Position() const { return data[0]; }
61 double& Position() { return data[0]; }
62
64 double Velocity() const { return data[1]; }
66 double& Velocity() { return data[1]; }
67
69 double Angle(const size_t i) const { return data[2 * i]; }
71 double& Angle(const size_t i) { return data[2 * i]; }
72
74 double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
76 double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
77
79 const arma::colvec& Encode() const { return data; }
80
82 static constexpr size_t dimension = 6;
83
84 private:
86 arma::colvec data;
87 };
88
struct Action
{
  //! Force applied to the cart; Dsdt() reads action[0] directly. It is
  //! value-initialized so that a default-constructed Action never carries an
  //! indeterminate value into the dynamics.
  double action[1] = { 0.0 };
  //! Degrees of freedom of the action (one continuous force).
  const int size = 1;
};
98
116 ContinuousDoublePoleCart(const double m1 = 0.1,
117 const double m2 = 0.01,
118 const double l1 = 0.5,
119 const double l2 = 0.05,
120 const double gravity = 9.8,
121 const double massCart = 1.0,
122 const double forceMag = 10.0,
123 const double tau = 0.02,
124 const double thetaThresholdRadians = 36 * 2 *
125 3.1416 / 360,
126 const double xThreshold = 2.4,
127 const double doneReward = 0.0,
128 const size_t maxSteps = 0) :
129 m1(m1),
130 m2(m2),
131 l1(l1),
132 l2(l2),
133 gravity(gravity),
134 massCart(massCart),
135 forceMag(forceMag),
136 tau(tau),
137 thetaThresholdRadians(thetaThresholdRadians),
138 xThreshold(xThreshold),
139 doneReward(doneReward),
140 maxSteps(maxSteps),
141 stepsPerformed(0)
142 { /* Nothing to do here */ }
143
153 double Sample(const State& state,
154 const Action& action,
155 State& nextState)
156 {
157 // Update the number of steps performed.
158 stepsPerformed++;
159
160 arma::vec dydx(6, arma::fill::zeros);
161 dydx[0] = state.Velocity();
162 dydx[2] = state.AngularVelocity(1);
163 dydx[4] = state.AngularVelocity(2);
164 Dsdt(state, action, dydx);
165 RK4(state, action, dydx, nextState);
166
167 // Check if the episode has terminated.
168 bool done = IsTerminal(nextState);
169
170 // Do not reward agent if it failed.
171 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
172 return doneReward;
173 else if (done)
174 return 0;
175
180 return 1.0;
181 }
182
191 void Dsdt(const State& state,
192 const Action& action,
193 arma::vec& dydx)
194 {
195 double totalForce = action.action[0];
196 double totalMass = massCart;
197 double omega1 = state.AngularVelocity(1);
198 double omega2 = state.AngularVelocity(2);
199 double sinTheta1 = std::sin(state.Angle(1));
200 double sinTheta2 = std::sin(state.Angle(2));
201 double cosTheta1 = std::cos(state.Angle(1));
202 double cosTheta2 = std::cos(state.Angle(2));
203
204 // Calculate total effective force.
205 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
206 std::sin(2 * state.Angle(1));
207 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
208 std::sin(2 * state.Angle(2));
209
210 // Calculate total effective mass.
211 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
212 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
213
214 // Calculate acceleration.
215 double xAcc = totalForce / totalMass;
216 dydx[1] = xAcc;
217
218 // Calculate angular acceleration.
219 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
220 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
221 }
222
232 void RK4(const State& state,
233 const Action& action,
234 arma::vec& dydx,
235 State& nextState)
236 {
237 const double hh = tau * 0.5;
238 const double h6 = tau / 6;
239 arma::vec yt(6);
240 arma::vec dyt(6);
241 arma::vec dym(6);
242
243 yt = state.Data() + (hh * dydx);
244 Dsdt(State(yt), action, dyt);
245 dyt[0] = yt[1];
246 dyt[2] = yt[3];
247 dyt[4] = yt[5];
248 yt = state.Data() + (hh * dyt);
249
250 Dsdt(State(yt), action, dym);
251 dym[0] = yt[1];
252 dym[2] = yt[3];
253 dym[4] = yt[5];
254 yt = state.Data() + (tau * dym);
255 dym += dyt;
256
257 Dsdt(State(yt), action, dyt);
258 dyt[0] = yt[1];
259 dyt[2] = yt[3];
260 dyt[4] = yt[5];
261 nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
262 }
263
272 double Sample(const State& state, const Action& action)
273 {
274 State nextState;
275 return Sample(state, action, nextState);
276 }
277
284 {
285 stepsPerformed = 0;
286 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
287 }
288
295 bool IsTerminal(const State& state) const
296 {
297 if (maxSteps != 0 && stepsPerformed >= maxSteps)
298 {
299 Log::Info << "Episode terminated due to the maximum number of steps"
300 "being taken.";
301 return true;
302 }
303 if (std::abs(state.Position()) > xThreshold)
304 {
305 Log::Info << "Episode terminated due to cart crossing threshold";
306 return true;
307 }
308 if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
309 std::abs(state.Angle(2)) > thetaThresholdRadians)
310 {
311 Log::Info << "Episode terminated due to pole falling";
312 return true;
313 }
314 return false;
315 }
316
318 size_t StepsPerformed() const { return stepsPerformed; }
319
321 size_t MaxSteps() const { return maxSteps; }
323 size_t& MaxSteps() { return maxSteps; }
324
325 private:
327 double m1;
328
330 double m2;
331
333 double l1;
334
336 double l2;
337
339 double gravity;
340
342 double massCart;
343
345 double forceMag;
346
348 double tau;
349
351 double thetaThresholdRadians;
352
354 double xThreshold;
355
357 double doneReward;
358
360 size_t maxSteps;
361
363 size_t stepsPerformed;
364};
365
366} // namespace rl
367} // namespace mlpack
368
369#endif
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
Implementation of the state of Continuous Double Pole Cart.
const arma::colvec & Encode() const
Encode the state to a vector.
arma::colvec Data() const
Get the internal representation of the state.
double & Velocity()
Modify the velocity of the cart.
double Velocity() const
Get the velocity of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
State(const arma::colvec &data)
Construct a state instance from given data.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double & Position()
Modify the position of the cart.
double Position() const
Get the position of the cart.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
static constexpr size_t dimension
Dimension of the encoded state.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Continuous Double Pole Cart Balancing task.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Continuous Double Pole Cart instance.
size_t & MaxSteps()
Set the maximum number of steps allowed.
size_t StepsPerformed() const
Get the number of steps performed.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
ContinuousDoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
size_t MaxSteps() const
Get the maximum number of steps allowed.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations used to estimate the next state through the RK4 method.
double Sample(const State &state, const Action &action)
Dynamics of Continuous Double Pole Cart.
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of action of Continuous Double Pole Cart.