Eigen  3.2.9
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
BlasUtil.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_BLASUTIL_H
11 #define EIGEN_BLASUTIL_H
12 
13 // This file contains many lightweight helper classes used to
14 // implement and control fast level 2 and level 3 BLAS-like routines.
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 // forward declarations
21 template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
22 struct gebp_kernel;
23 
24 template<typename Scalar, typename Index, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
25 struct gemm_pack_rhs;
26 
27 template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
28 struct gemm_pack_lhs;
29 
30 template<
31  typename Index,
32  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
33  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
34  int ResStorageOrder>
35 struct general_matrix_matrix_product;
36 
37 template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version=Specialized>
38 struct general_matrix_vector_product;
39 
40 
41 template<bool Conjugate> struct conj_if;
42 
43 template<> struct conj_if<true> {
44  template<typename T>
45  inline T operator()(const T& x) { return numext::conj(x); }
46  template<typename T>
47  inline T pconj(const T& x) { return internal::pconj(x); }
48 };
49 
50 template<> struct conj_if<false> {
51  template<typename T>
52  inline const T& operator()(const T& x) { return x; }
53  template<typename T>
54  inline const T& pconj(const T& x) { return x; }
55 };
56 
57 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
58 {
59  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
60  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
61 };
62 
63 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
64 {
65  typedef std::complex<RealScalar> Scalar;
66  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
67  { return c + pmul(x,y); }
68 
69  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
70  { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
71 };
72 
73 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
74 {
75  typedef std::complex<RealScalar> Scalar;
76  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
77  { return c + pmul(x,y); }
78 
79  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
80  { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
81 };
82 
83 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
84 {
85  typedef std::complex<RealScalar> Scalar;
86  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
87  { return c + pmul(x,y); }
88 
89  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
90  { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
91 };
92 
93 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
94 {
95  typedef std::complex<RealScalar> Scalar;
96  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
97  { return padd(c, pmul(x,y)); }
98  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
99  { return conj_if<Conj>()(x)*y; }
100 };
101 
102 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
103 {
104  typedef std::complex<RealScalar> Scalar;
105  EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
106  { return padd(c, pmul(x,y)); }
107  EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
108  { return x*conj_if<Conj>()(y); }
109 };
110 
111 template<typename From,typename To> struct get_factor {
112  static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
113 };
114 
115 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
116  static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
117 };
118 
119 // Lightweight helper class to access matrix coefficients.
120 // Yes, this is somehow redundant with Map<>, but this version is much much lighter,
121 // and so I hope better compilation performance (time and code quality).
122 template<typename Scalar, typename Index, int StorageOrder>
123 class blas_data_mapper
124 {
125  public:
126  blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
127  EIGEN_STRONG_INLINE Scalar& operator()(Index i, Index j)
128  { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
129  protected:
130  Scalar* EIGEN_RESTRICT m_data;
131  Index m_stride;
132 };
133 
134 // lightweight helper class to access matrix coefficients (const version)
135 template<typename Scalar, typename Index, int StorageOrder>
136 class const_blas_data_mapper
137 {
138  public:
139  const_blas_data_mapper(const Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
140  EIGEN_STRONG_INLINE const Scalar& operator()(Index i, Index j) const
141  { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
142  protected:
143  const Scalar* EIGEN_RESTRICT m_data;
144  Index m_stride;
145 };
146 
147 
148 /* Helper class to analyze the factors of a Product expression.
149  * In particular it allows to pop out operator-, scalar multiples,
150  * and conjugate */
151 template<typename XprType> struct blas_traits
152 {
153  typedef typename traits<XprType>::Scalar Scalar;
154  typedef const XprType& ExtractType;
155  typedef XprType _ExtractType;
156  enum {
157  IsComplex = NumTraits<Scalar>::IsComplex,
158  IsTransposed = false,
159  NeedToConjugate = false,
160  HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
161  && ( bool(XprType::IsVectorAtCompileTime)
162  || int(inner_stride_at_compile_time<XprType>::ret) == 1)
163  ) ? 1 : 0
164  };
165  typedef typename conditional<bool(HasUsableDirectAccess),
166  ExtractType,
167  typename _ExtractType::PlainObject
168  >::type DirectLinearAccessType;
169  static inline ExtractType extract(const XprType& x) { return x; }
170  static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
171 };
172 
173 // pop conjugate
174 template<typename Scalar, typename Xpr>
175 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, Xpr> >
176  : blas_traits<typename internal::remove_all<typename Xpr::Nested>::type>
177 {
178  typedef typename internal::remove_all<typename Xpr::Nested>::type NestedXpr;
179  typedef blas_traits<NestedXpr> Base;
180  typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, Xpr> XprType;
181  typedef typename Base::ExtractType ExtractType;
182 
183  enum {
184  IsComplex = NumTraits<Scalar>::IsComplex,
185  NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
186  };
187  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
188  static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
189 };
190 
191 // pop scalar multiple
192 template<typename Scalar, typename Xpr>
193 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, Xpr> >
194  : blas_traits<typename internal::remove_all<typename Xpr::Nested>::type>
195 {
196  typedef typename internal::remove_all<typename Xpr::Nested>::type NestedXpr;
197  typedef blas_traits<NestedXpr> Base;
198  typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, Xpr> XprType;
199  typedef typename Base::ExtractType ExtractType;
200  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
201  static inline Scalar extractScalarFactor(const XprType& x)
202  { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
203 };
204 
205 // pop opposite
206 template<typename Scalar, typename Xpr>
207 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, Xpr> >
208  : blas_traits<typename internal::remove_all<typename Xpr::Nested>::type>
209 {
210  typedef typename internal::remove_all<typename Xpr::Nested>::type NestedXpr;
211  typedef blas_traits<NestedXpr> Base;
212  typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, Xpr> XprType;
213  typedef typename Base::ExtractType ExtractType;
214  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
215  static inline Scalar extractScalarFactor(const XprType& x)
216  { return - Base::extractScalarFactor(x.nestedExpression()); }
217 };
218 
219 // pop/push transpose
220 template<typename Xpr>
221 struct blas_traits<Transpose<Xpr> >
222  : blas_traits<typename internal::remove_all<typename Xpr::Nested>::type>
223 {
224  typedef typename internal::remove_all<typename Xpr::Nested>::type NestedXpr;
225  typedef typename NestedXpr::Scalar Scalar;
226  typedef blas_traits<NestedXpr> Base;
227  typedef Transpose<Xpr> XprType;
228  typedef Transpose<const typename Base::_ExtractType> ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
229  typedef Transpose<const typename Base::_ExtractType> _ExtractType;
230  typedef typename conditional<bool(Base::HasUsableDirectAccess),
231  ExtractType,
232  typename ExtractType::PlainObject
233  >::type DirectLinearAccessType;
234  enum {
235  IsTransposed = Base::IsTransposed ? 0 : 1
236  };
237  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
238  static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
239 };
240 
241 template<typename T>
242 struct blas_traits<const T>
243  : blas_traits<T>
244 {};
245 
246 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
247 struct extract_data_selector {
248  static const typename T::Scalar* run(const T& m)
249  {
250  return blas_traits<T>::extract(m).data();
251  }
252 };
253 
254 template<typename T>
255 struct extract_data_selector<T,false> {
256  static typename T::Scalar* run(const T&) { return 0; }
257 };
258 
259 template<typename T> const typename T::Scalar* extract_data(const T& m)
260 {
261  return extract_data_selector<T>::run(m);
262 }
263 
264 } // end namespace internal
265 
266 } // end namespace Eigen
267 
268 #endif // EIGEN_BLASUTIL_H
Definition: Constants.h:266
const unsigned int DirectAccessBit
Definition: Constants.h:142