14#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
15#define MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
50 State(
const arma::colvec& data) : data(data)
54 arma::colvec
Data()
const {
return data; }
56 arma::colvec&
Data() {
return data; }
69 double Angle(
const size_t i)
const {
return data[2 * i]; }
71 double&
Angle(
const size_t i) {
return data[2 * i]; }
79 const arma::colvec&
Encode()
const {
return data; }
117 const double m2 = 0.01,
118 const double l1 = 0.5,
119 const double l2 = 0.05,
120 const double gravity = 9.8,
121 const double massCart = 1.0,
122 const double forceMag = 10.0,
123 const double tau = 0.02,
124 const double thetaThresholdRadians = 36 * 2 *
126 const double xThreshold = 2.4,
127 const double doneReward = 0.0,
128 const size_t maxSteps = 0) :
137 thetaThresholdRadians(thetaThresholdRadians),
138 xThreshold(xThreshold),
139 doneReward(doneReward),
160 arma::vec dydx(6, arma::fill::zeros);
164 Dsdt(state, action, dydx);
165 RK4(state, action, dydx, nextState);
171 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
195 double totalForce = action.
action[0];
196 double totalMass = massCart;
199 double sinTheta1 = std::sin(state.
Angle(1));
200 double sinTheta2 = std::sin(state.
Angle(2));
201 double cosTheta1 = std::cos(state.
Angle(1));
202 double cosTheta2 = std::cos(state.
Angle(2));
205 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
206 std::sin(2 * state.
Angle(1));
207 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
208 std::sin(2 * state.
Angle(2));
211 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
212 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
215 double xAcc = totalForce / totalMass;
219 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
220 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
237 const double hh = tau * 0.5;
238 const double h6 = tau / 6;
243 yt = state.
Data() + (hh * dydx);
248 yt = state.
Data() + (hh * dyt);
254 yt = state.
Data() + (tau * dym);
261 nextState.
Data() = state.
Data() + h6 * (dydx + dyt + 2 * dym);
275 return Sample(state, action, nextState);
286 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
297 if (maxSteps != 0 && stepsPerformed >= maxSteps)
299 Log::Info <<
"Episode terminated due to the maximum number of steps"
303 if (std::abs(state.
Position()) > xThreshold)
305 Log::Info <<
"Episode terminated due to cart crossing threshold";
308 if (std::abs(state.
Angle(1)) > thetaThresholdRadians ||
309 std::abs(state.
Angle(2)) > thetaThresholdRadians)
311 Log::Info <<
"Episode terminated due to pole falling";
351 double thetaThresholdRadians;
363 size_t stepsPerformed;
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Implementation of the state of Continuous Double Pole Cart.
const arma::colvec & Encode() const
Encode the state to a vector.
arma::colvec Data() const
Get the internal representation of the state.
double & Velocity()
Modify the velocity of the cart.
double Velocity() const
Get the velocity of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
State()
Construct a state instance.
State(const arma::colvec &data)
Construct a state instance from given data.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double & Position()
Modify the position of the cart.
double Position() const
Get the position of the cart.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
static constexpr size_t dimension
Dimension of the encoded state.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Continuous Double Pole Cart Balancing task.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Continuous Double Pole Cart instance.
size_t & MaxSteps()
Set the maximum number of steps allowed.
size_t StepsPerformed() const
Get the number of steps performed.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
ContinuousDoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
size_t MaxSteps() const
Get the maximum number of steps allowed.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method.
double Sample(const State &state, const Action &action)
Dynamics of Continuous Double Pole Cart.
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of action of Continuous Double Pole Cart.