13#ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
14#define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
49 State(
const arma::colvec& data) : data(data)
//! Get the internal representation of the state.
//! Read-only accessor: the vector is returned by value, so the caller
//! receives a copy of `data` rather than a reference to it.
53 arma::colvec
Data()
const {
return data; }
//! Modify the internal representation of the state.
//! Returns a mutable reference to `data`, so writes through the returned
//! reference change this state directly.
55 arma::colvec&
Data() {
return data; }
//! Get the angle of the i-th pole.
//! Reads element `2 * i` of the state vector.
//! NOTE(review): no bounds check is performed here; the callers visible in
//! this file pass i = 1 and i = 2 -- confirm that is the intended range.
68 double Angle(
const size_t i)
const {
return data[2 * i]; }
//! Modify the angle of the i-th pole.
//! Returns a mutable reference to element `2 * i` of the state vector.
//! NOTE(review): no bounds check is performed here; presumably i is 1 or 2,
//! matching the read-only overload's callers -- confirm.
70 double&
Angle(
const size_t i) {
return data[2 * i]; }
//! Encode the state to a vector.
//! The encoding is the stored vector itself: a const reference to `data` is
//! returned, so no copy is made and the reference refers to this object's
//! member.
78 const arma::colvec&
Encode()
const {
return data; }
124 const double m1 = 0.1,
125 const double m2 = 0.01,
126 const double l1 = 0.5,
127 const double l2 = 0.05,
128 const double gravity = 9.8,
129 const double massCart = 1.0,
130 const double forceMag = 10.0,
131 const double tau = 0.02,
132 const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
133 const double xThreshold = 2.4,
134 const double doneReward = 0.0) :
144 thetaThresholdRadians(thetaThresholdRadians),
145 xThreshold(xThreshold),
146 doneReward(doneReward),
166 arma::vec dydx(6, arma::fill::zeros);
170 Dsdt(state, action, dydx);
171 RK4(state, action, dydx, nextState);
177 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
201 double totalForce = action.
action ? forceMag : -forceMag;
202 double totalMass = massCart;
205 double sinTheta1 = std::sin(state.
Angle(1));
206 double sinTheta2 = std::sin(state.
Angle(2));
207 double cosTheta1 = std::cos(state.
Angle(1));
208 double cosTheta2 = std::cos(state.
Angle(2));
211 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
212 std::sin(2 * state.
Angle(1));
213 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
214 std::sin(2 * state.
Angle(2));
217 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
218 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
221 double xAcc = totalForce / totalMass;
225 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
226 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
243 const double hh = tau * 0.5;
244 const double h6 = tau / 6;
249 yt = state.
Data() + (hh * dydx);
254 yt = state.
Data() + (hh * dyt);
260 yt = state.
Data() + (tau * dym);
267 nextState.
Data() = state.
Data() + h6 * (dydx + dyt + 2 * dym);
281 return Sample(state, action, nextState);
292 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
303 if (maxSteps != 0 && stepsPerformed >= maxSteps)
305 Log::Info <<
"Episode terminated due to the maximum number of steps"
309 if (std::abs(state.
Position()) > xThreshold)
311 Log::Info <<
"Episode terminated due to cart crossing threshold";
314 if (std::abs(state.
Angle(1)) > thetaThresholdRadians ||
315 std::abs(state.
Angle(2)) > thetaThresholdRadians)
317 Log::Info <<
"Episode terminated due to pole falling";
360 double thetaThresholdRadians;
369 size_t stepsPerformed;
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Implementation of action of Double Pole Cart.
Implementation of the state of Double Pole Cart.
const arma::colvec & Encode() const
Encode the state to a vector.
arma::colvec Data() const
Get the internal representation of the state.
double & Velocity()
Modify the velocity of the cart.
double Velocity() const
Get the velocity of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
State()
Construct a state instance.
State(const arma::colvec &data)
Construct a state instance from given data.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double & Position()
Modify the position of the cart.
double Position() const
Get the position of the cart.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
static constexpr size_t dimension
Dimension of the encoded state.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Double Pole Cart Balancing task.
DoublePoleCart(const size_t maxSteps=0, const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Double Pole Cart instance using the given constants.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.
size_t & MaxSteps()
Set the maximum number of steps allowed.
size_t StepsPerformed() const
Get the number of steps performed.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
size_t MaxSteps() const
Get the maximum number of steps allowed.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method.
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.