mlpack 3.4.2
cart_pole.hpp
Go to the documentation of this file.
1
15#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
16#define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
17
18#include <mlpack/prereqs.hpp>
19
20namespace mlpack {
21namespace rl {
22
27{
28 public:
33 class State
34 {
35 public:
39 State() : data(dimension)
40 { /* Nothing to do here. */ }
41
47 State(const arma::colvec& data) : data(data)
48 { /* Nothing to do here */ }
49
51 arma::colvec& Data() { return data; }
52
54 double Position() const { return data[0]; }
56 double& Position() { return data[0]; }
57
59 double Velocity() const { return data[1]; }
61 double& Velocity() { return data[1]; }
62
64 double Angle() const { return data[2]; }
66 double& Angle() { return data[2]; }
67
69 double AngularVelocity() const { return data[3]; }
71 double& AngularVelocity() { return data[3]; }
72
74 const arma::colvec& Encode() const { return data; }
75
77 static constexpr size_t dimension = 4;
78
79 private:
81 arma::colvec data;
82 };
83
87 class Action
88 {
89 public:
91 {
94 };
95 // To store the action.
97
98 // Track the size of the action space.
99 static const size_t size = 2;
100 };
101
117 CartPole(const size_t maxSteps = 200,
118 const double gravity = 9.8,
119 const double massCart = 1.0,
120 const double massPole = 0.1,
121 const double length = 0.5,
122 const double forceMag = 10.0,
123 const double tau = 0.02,
124 const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
125 const double xThreshold = 2.4,
126 const double doneReward = 1.0) :
127 maxSteps(maxSteps),
128 gravity(gravity),
129 massCart(massCart),
130 massPole(massPole),
131 totalMass(massCart + massPole),
132 length(length),
133 poleMassLength(massPole * length),
134 forceMag(forceMag),
135 tau(tau),
136 thetaThresholdRadians(thetaThresholdRadians),
137 xThreshold(xThreshold),
138 doneReward(doneReward),
139 stepsPerformed(0)
140 { /* Nothing to do here */ }
141
151 double Sample(const State& state,
152 const Action& action,
153 State& nextState)
154 {
155 // Update the number of steps performed.
156 stepsPerformed++;
157
158 // Calculate acceleration.
159 double force = action.action ? forceMag : -forceMag;
160 double cosTheta = std::cos(state.Angle());
161 double sinTheta = std::sin(state.Angle());
162 double temp = (force + poleMassLength * state.AngularVelocity() *
163 state.AngularVelocity() * sinTheta) / totalMass;
164 double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
165 (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
166 double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
167
168 // Update states.
169 nextState.Position() = state.Position() + tau * state.Velocity();
170 nextState.Velocity() = state.Velocity() + tau * xAcc;
171 nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
172 nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;
173
174 // Check if the episode has terminated.
175 bool done = IsTerminal(nextState);
176
177 // Do not reward agent if it failed.
178 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
179 return doneReward;
180
185 return 1.0;
186 }
187
196 double Sample(const State& state, const Action& action)
197 {
198 State nextState;
199 return Sample(state, action, nextState);
200 }
201
208 {
209 stepsPerformed = 0;
210 return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
211 }
212
219 bool IsTerminal(const State& state) const
220 {
221 if (maxSteps != 0 && stepsPerformed >= maxSteps)
222 {
223 Log::Info << "Episode terminated due to the maximum number of steps"
224 "being taken.";
225 return true;
226 }
227 else if (std::abs(state.Position()) > xThreshold ||
228 std::abs(state.Angle()) > thetaThresholdRadians)
229 {
230 Log::Info << "Episode terminated due to agent failing.";
231 return true;
232 }
233 return false;
234 }
235
237 size_t StepsPerformed() const { return stepsPerformed; }
238
240 size_t MaxSteps() const { return maxSteps; }
242 size_t& MaxSteps() { return maxSteps; }
243
244 private:
246 size_t maxSteps;
247
249 double gravity;
250
252 double massCart;
253
255 double massPole;
256
258 double totalMass;
259
261 double length;
262
264 double poleMassLength;
265
267 double forceMag;
268
270 double tau;
271
273 double thetaThresholdRadians;
274
276 double xThreshold;
277
279 double doneReward;
280
282 size_t stepsPerformed;
283};
284
285} // namespace rl
286} // namespace mlpack
287
288#endif
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
Implementation of action of Cart Pole.
Definition: cart_pole.hpp:88
static const size_t size
Definition: cart_pole.hpp:99
Implementation of the state of Cart Pole.
Definition: cart_pole.hpp:34
double & Angle()
Modify the angle.
Definition: cart_pole.hpp:66
const arma::colvec & Encode() const
Encode the state to a column vector.
Definition: cart_pole.hpp:74
double & Velocity()
Modify the velocity.
Definition: cart_pole.hpp:61
double Velocity() const
Get the velocity.
Definition: cart_pole.hpp:59
State()
Construct a state instance.
Definition: cart_pole.hpp:39
double AngularVelocity() const
Get the angular velocity.
Definition: cart_pole.hpp:69
State(const arma::colvec &data)
Construct a state instance from given data.
Definition: cart_pole.hpp:47
double & AngularVelocity()
Modify the angular velocity.
Definition: cart_pole.hpp:71
double & Position()
Modify the position.
Definition: cart_pole.hpp:56
double Position() const
Get the position.
Definition: cart_pole.hpp:54
static constexpr size_t dimension
Dimension of the encoded state.
Definition: cart_pole.hpp:77
double Angle() const
Get the angle.
Definition: cart_pole.hpp:64
arma::colvec & Data()
Modify the internal representation of the state.
Definition: cart_pole.hpp:51
Implementation of Cart Pole task.
Definition: cart_pole.hpp:27
CartPole(const size_t maxSteps=200, const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=1.0)
Construct a Cart Pole instance using the given constants.
Definition: cart_pole.hpp:117
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Cart Pole instance.
Definition: cart_pole.hpp:151
size_t & MaxSteps()
Set the maximum number of steps allowed.
Definition: cart_pole.hpp:242
size_t StepsPerformed() const
Get the number of steps performed.
Definition: cart_pole.hpp:237
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
Definition: cart_pole.hpp:219
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
Definition: cart_pole.hpp:207
size_t MaxSteps() const
Get the maximum number of steps allowed.
Definition: cart_pole.hpp:240
double Sample(const State &state, const Action &action)
Dynamics of Cart Pole.
Definition: cart_pole.hpp:196
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.