ROL
ROL_FDivergence.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ************************************************************************
3 //
4 // Rapid Optimization Library (ROL) Package
5 // Copyright (2014) Sandia Corporation
6 //
7 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8 // license for use of this work by or on behalf of the U.S. Government.
9 //
10 // Redistribution and use in source and binary forms, with or without
11 // modification, are permitted provided that the following conditions are
12 // met:
13 //
14 // 1. Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimer.
16 //
17 // 2. Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimer in the
19 // documentation and/or other materials provided with the distribution.
20 //
21 // 3. Neither the name of the Corporation nor the names of the
22 // contributors may be used to endorse or promote products derived from
23 // this software without specific prior written permission.
24 //
25 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 //
37 // Questions? Contact lead developers:
38 // Drew Kouri (dpkouri@sandia.gov) and
39 // Denis Ridzal (dridzal@sandia.gov)
40 //
41 // ************************************************************************
42 // @HEADER
43 
44 #ifndef ROL_FDIVERGENCE_HPP
45 #define ROL_FDIVERGENCE_HPP
46 
47 #include "ROL_RiskVector.hpp"
48 #include "ROL_RiskMeasure.hpp"
49 #include "ROL_Types.hpp"
50 
85 namespace ROL {
86 
87 template<class Real>
88 class FDivergence : public RiskMeasure<Real> {
89 private:
90 
91  Real thresh_;
92 
93  Teuchos::RCP<Vector<Real> > dualVector_;
94 
95  Real xlam_;
96  Real xmu_;
97  Real vlam_;
98  Real vmu_;
99 
100  Real valLam_;
101  Real valLam2_;
102  Real valMu_;
103  Real valMu2_;
104 
106 
107  void checkInputs(void) const {
108  Real zero(0);
109  TEUCHOS_TEST_FOR_EXCEPTION((thresh_ <= zero), std::invalid_argument,
110  ">>> ERROR (ROL::FDivergence): Threshold must be positive!");
111  }
112 
113 public:
118  FDivergence(const Real thresh) : RiskMeasure<Real>(), thresh_(thresh),
119  xlam_(0), xmu_(0), vlam_(0), vmu_(0), valLam_(0), valMu_(0),
120  firstReset_(true) {
121  checkInputs();
122  }
123 
132  FDivergence(Teuchos::ParameterList &parlist) : RiskMeasure<Real>(),
133  xlam_(0), xmu_(0), vlam_(0), vmu_(0), valLam_(0), valMu_(0),
134  firstReset_(true) {
135  Teuchos::ParameterList &list
136  = parlist.sublist("SOL").sublist("Risk Measure").sublist("F-Divergence");
137  thresh_ = list.get<Real>("Threshold");
138  checkInputs();
139  }
140 
/** \brief Implementation of the scalar primal F function.

    Pure virtual: each concrete F-divergence supplies it.

    @param[in] x      point at which to evaluate.
    @param[in] deriv  derivative order requested (0, the default, asks for
                      the function value itself).
*/
virtual Real Fprimal(Real x, int deriv = 0) = 0;

/** \brief Implementation of the scalar dual F function.

    Pure virtual: each concrete F-divergence supplies it. The update()
    overloads of this class evaluate it at x = (val - xmu_) / xlam_ with
    deriv = 0 (value accumulation), deriv = 0 and 1 (gradient accumulation)
    and deriv = 1 and 2 (Hessian-times-a-vector accumulation).

    @param[in] x      point at which to evaluate.
    @param[in] deriv  derivative order requested (0, the default, asks for
                      the function value itself).
*/
virtual Real Fdual(Real x, int deriv = 0) = 0;
163 
164  void reset(Teuchos::RCP<Vector<Real> > &x0, const Vector<Real> &x) {
166  xlam_ = Teuchos::dyn_cast<const RiskVector<Real> >(x).getStatistic(0);
167  xmu_ = Teuchos::dyn_cast<const RiskVector<Real> >(x).getStatistic(1);
168  if (firstReset_) {
169  dualVector_ = (x0->dual()).clone();
170  firstReset_ = false;
171  }
172  dualVector_->zero();
173  valLam_ = 0; valLam2_ = 0; valMu_ = 0; valMu2_ = 0;
174  }
175 
176  void reset(Teuchos::RCP<Vector<Real> > &x0, const Vector<Real> &x,
177  Teuchos::RCP<Vector<Real> > &v0, const Vector<Real> &v) {
178  reset(x0,x);
179  v0 = Teuchos::rcp_const_cast<Vector<Real> >(
180  Teuchos::dyn_cast<const RiskVector<Real> >(v).getVector());
181  vlam_ = Teuchos::dyn_cast<const RiskVector<Real> >(v).getStatistic(0);
182  vmu_ = Teuchos::dyn_cast<const RiskVector<Real> >(v).getStatistic(1);
183  }
184 
185  // Value update and get functions
186  void update(const Real val, const Real weight) {
187  Real r = Fdual((val-xmu_)/xlam_,0);
188  RiskMeasure<Real>::val_ += weight * r;
189  }
190 
192  Real val = RiskMeasure<Real>::val_, gval = 0;
193  sampler.sumAll(&val,&gval,1);
194  return xlam_*(thresh_ + gval) + xmu_;
195  }
196 
197  // Gradient update and get functions
198  void update(const Real val, const Vector<Real> &g, const Real weight) {
199  Real x = (val-xmu_)/xlam_;
200  Real r0 = Fdual(x,0), r1 = Fdual(x,1);
201 
202  RiskMeasure<Real>::val_ += weight * r0;
203  valLam_ -= weight * r1 * x;
204  valMu_ -= weight * r1;
205 
206  RiskMeasure<Real>::g_->axpy(weight*r1,g);
207  }
208 
210  RiskVector<Real> &gs = Teuchos::dyn_cast<RiskVector<Real> >(g);
211 
212  std::vector<Real> mygval(3), gval(3);
213  mygval[0] = RiskMeasure<Real>::val_;
214  mygval[1] = valLam_;
215  mygval[2] = valMu_;
216  sampler.sumAll(&mygval[0],&gval[0],3);
217 
218  std::vector<Real> stat(2);
219  stat[0] = thresh_ + gval[0] + gval[1];
220  stat[1] = (Real)1 + gval[2];
221  gs.setStatistic(stat);
222 
224  gs.setVector(*dualVector_);
225  }
226 
227  void update(const Real val, const Vector<Real> &g, const Real gv,
228  const Vector<Real> &hv, const Real weight) {
229  Real x = (val-xmu_)/xlam_;
230  Real r1 = Fdual(x,1), r2 = Fdual(x,2);
231  RiskMeasure<Real>::val_ += weight * r2 * x;
232  valLam_ += weight * r2 * x * x;
233  valLam2_ -= weight * r2 * gv * x;
234  valMu_ += weight * r2;
235  valMu2_ -= weight * r2 * gv;
236  RiskMeasure<Real>::hv_->axpy(weight * r2 * (gv - vmu_ - vlam_*x)/xlam_, g);
237  RiskMeasure<Real>::hv_->axpy(weight * r1, hv);
238  }
239 
241  RiskVector<Real> &hs = Teuchos::dyn_cast<RiskVector<Real> >(hv);
242 
243  std::vector<Real> myhval(5), hval(5);
244  myhval[0] = RiskMeasure<Real>::val_;
245  myhval[1] = valLam_;
246  myhval[2] = valLam2_;
247  myhval[3] = valMu_;
248  myhval[4] = valMu2_;
249  sampler.sumAll(&myhval[0],&hval[0],5);
250 
251  std::vector<Real> stat(2);
252  stat[0] = (vlam_ * hval[1] + vmu_ * hval[0] + hval[2])/xlam_;
253  stat[1] = (vlam_ * hval[0] + vmu_ * hval[3] + hval[4])/xlam_;
254  hs.setStatistic(stat);
255 
257  hs.setVector(*dualVector_);
258  }
259 };
260 
261 }
262 
263 #endif
void checkInputs(void) const
Contains definitions of custom data types in ROL.
void sumAll(Real *input, Real *output, int dim) const
void update(const Real val, const Vector< Real > &g, const Real weight)
Update internal risk measure storage for gradient computation.
void reset(Teuchos::RCP< Vector< Real > > &x0, const Vector< Real > &x)
Reset internal risk measure storage. Called for value and gradient computation.
Defines the linear algebra or vector space interface.
Definition: ROL_Vector.hpp:74
void getHessVec(Vector< Real > &hv, SampleGenerator< Real > &sampler)
Return risk measure Hessian-times-a-vector.
void update(const Real val, const Real weight)
Update internal risk measure storage for value computation.
void reset(Teuchos::RCP< Vector< Real > > &x0, const Vector< Real > &x, Teuchos::RCP< Vector< Real > > &v0, const Vector< Real > &v)
Reset internal risk measure storage. Called for Hessian-times-a-vector computation.
void setVector(const Vector< Real > &vec)
void update(const Real val, const Vector< Real > &g, const Real gv, const Vector< Real > &hv, const Real weight)
Update internal risk measure storage for Hessian-times-a-vector computation.
Teuchos::RCP< Vector< Real > > dualVector_
void setStatistic(const Real stat)
FDivergence(Teuchos::ParameterList &parlist)
Constructor.
virtual Real Fdual(Real x, int deriv=0)=0
Implementation of the scalar dual F function.
FDivergence(const Real thresh)
Constructor.
virtual Real Fprimal(Real x, int deriv=0)=0
Implementation of the scalar primal F function.
Real getValue(SampleGenerator< Real > &sampler)
Return risk measure value.
Provides a general interface for the F-divergence distributionally robust expectation.
virtual void reset(Teuchos::RCP< Vector< Real > > &x0, const Vector< Real > &x)
Reset internal risk measure storage. Called for value and gradient computation.
void getGradient(Vector< Real > &g, SampleGenerator< Real > &sampler)
Return risk measure (sub)gradient.
Provides the interface to implement risk measures.