32 #ifndef SACADO_FAD_BLAS_HPP 33 #define SACADO_FAD_BLAS_HPP 35 #include "Teuchos_BLAS.hpp" 44 template <
typename OrdinalType,
typename FadType>
64 OrdinalType& n_dot, OrdinalType& inc_val,
69 OrdinalType lda, OrdinalType& n_dot,
70 OrdinalType& lda_val, OrdinalType& lda_dot,
77 OrdinalType& n_dot, OrdinalType& inc_val,
82 OrdinalType lda, OrdinalType& n_dot,
83 OrdinalType& lda_val, OrdinalType& lda_dot,
90 OrdinalType& n_dot, OrdinalType& inc_val,
95 OrdinalType lda, OrdinalType& n_dot,
96 OrdinalType& lda_val, OrdinalType& lda_dot,
99 void unpack(
FadType&
a, OrdinalType& n_dot, OrdinalType& final_n_dot,
103 OrdinalType& n_dot, OrdinalType& final_n_dot,
104 OrdinalType& inc_val, OrdinalType& inc_dot,
107 void unpack(
FadType*
A, OrdinalType m, OrdinalType n, OrdinalType lda,
108 OrdinalType& n_dot, OrdinalType& final_n_dot,
109 OrdinalType& lda_val, OrdinalType& lda_dot,
116 OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot,
120 OrdinalType lda, OrdinalType n_dot,
121 OrdinalType lda_val, OrdinalType lda_dot,
127 void free(
const FadType*
a, OrdinalType n, OrdinalType n_dot,
128 OrdinalType inc_val, OrdinalType inc_dot,
131 void free(
const FadType*
A, OrdinalType m, OrdinalType n,
132 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot,
139 OrdinalType inc_val, OrdinalType inc_dot,
143 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot,
150 OrdinalType inc_val, OrdinalType inc_dot,
154 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot,
162 OrdinalType n_dot)
const;
183 template <
typename OrdinalType,
typename FadType>
184 class BLAS :
public Teuchos::DefaultBLASImpl<OrdinalType,FadType> {
186 typedef typename Teuchos::ScalarTraits<FadType>::magnitudeType
MagnitudeType;
190 typedef Teuchos::DefaultBLASImpl<OrdinalType,FadType>
BLASType;
198 bool use_dynamic =
true, OrdinalType static_workspace_size = 0);
214 BLASType::ROTG(da,db,
c,s);
221 BLASType::ROT(n,
dx,incx,dy,incy,
c,s);
226 const OrdinalType incx)
const;
230 const OrdinalType incx,
FadType* y,
231 const OrdinalType incy)
const;
234 template <
typename alpha_type,
typename x_type>
235 void AXPY(
const OrdinalType n,
const alpha_type& alpha,
236 const x_type* x,
const OrdinalType incx,
FadType* y,
237 const OrdinalType incy)
const;
240 typename Teuchos::ScalarTraits<FadType>::magnitudeType
242 const OrdinalType incx)
const {
243 return BLASType::ASUM(n,x,incx);
247 template <
typename x_type,
typename y_type>
248 FadType DOT(
const OrdinalType n,
const x_type* x,
249 const OrdinalType incx,
const y_type* y,
250 const OrdinalType incy)
const;
254 const OrdinalType incx)
const;
258 const OrdinalType incx)
const {
259 return BLASType::IAMAX(n,x,incx);
272 template <
typename alpha_type,
typename A_type,
typename x_type,
274 void GEMV(Teuchos::ETransp trans,
const OrdinalType m,
276 const alpha_type& alpha,
const A_type*
A,
277 const OrdinalType lda,
const x_type* x,
278 const OrdinalType incx,
const beta_type& beta,
279 FadType* y,
const OrdinalType incy)
const;
286 template <
typename A_type>
287 void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans,
288 Teuchos::EDiag diag,
const OrdinalType n,
289 const A_type*
A,
const OrdinalType lda,
FadType* x,
290 const OrdinalType incx)
const;
293 template <
typename alpha_type,
typename x_type,
typename y_type>
294 void GER(
const OrdinalType m,
const OrdinalType n,
295 const alpha_type& alpha,
296 const x_type* x,
const OrdinalType incx,
297 const y_type* y,
const OrdinalType incy,
298 FadType*
A,
const OrdinalType lda)
const;
311 template <
typename alpha_type,
typename A_type,
typename B_type,
313 void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb,
314 const OrdinalType m,
const OrdinalType n,
const OrdinalType k,
315 const alpha_type& alpha,
const A_type*
A,
const OrdinalType lda,
316 const B_type*
B,
const OrdinalType ldb,
const beta_type& beta,
317 FadType*
C,
const OrdinalType ldc)
const;
325 template <
typename alpha_type,
typename A_type,
typename B_type,
327 void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo,
const OrdinalType m,
329 const alpha_type& alpha,
const A_type*
A,
330 const OrdinalType lda,
const B_type*
B,
331 const OrdinalType ldb,
333 const OrdinalType ldc)
const;
341 template <
typename alpha_type,
typename A_type>
342 void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo,
343 Teuchos::ETransp transa, Teuchos::EDiag diag,
344 const OrdinalType m,
const OrdinalType n,
345 const alpha_type& alpha,
346 const A_type*
A,
const OrdinalType lda,
347 FadType*
B,
const OrdinalType ldb)
const;
356 template <
typename alpha_type,
typename A_type>
357 void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo,
358 Teuchos::ETransp transa, Teuchos::EDiag diag,
359 const OrdinalType m,
const OrdinalType n,
360 const alpha_type& alpha,
361 const A_type*
A,
const OrdinalType lda,
362 FadType*
B,
const OrdinalType ldb)
const;
372 Teuchos::BLAS<OrdinalType, ValueType>
blas;
386 template <
typename x_type,
typename y_type>
387 void Fad_DOT(
const OrdinalType n,
389 const OrdinalType incx,
390 const OrdinalType n_x_dot,
392 const OrdinalType incx_dot,
394 const OrdinalType incy,
395 const OrdinalType n_y_dot,
397 const OrdinalType incy_dot,
399 const OrdinalType n_z_dot,
403 template <
typename alpha_type,
typename A_type,
typename x_type,
405 void Fad_GEMV(Teuchos::ETransp trans,
408 const alpha_type& alpha,
409 const OrdinalType n_alpha_dot,
410 const alpha_type* alpha_dot,
412 const OrdinalType lda,
413 const OrdinalType n_A_dot,
415 const OrdinalType lda_dot,
417 const OrdinalType incx,
418 const OrdinalType n_x_dot,
420 const OrdinalType incx_dot,
421 const beta_type& beta,
422 const OrdinalType n_beta_dot,
423 const beta_type* beta_dot,
425 const OrdinalType incy,
426 const OrdinalType n_y_dot,
428 const OrdinalType incy_dot,
429 const OrdinalType n_dot)
const;
432 template <
typename alpha_type,
typename x_type,
typename y_type>
433 void Fad_GER(
const OrdinalType m,
435 const alpha_type& alpha,
436 const OrdinalType n_alpha_dot,
437 const alpha_type* alpha_dot,
439 const OrdinalType incx,
440 const OrdinalType n_x_dot,
442 const OrdinalType incx_dot,
444 const OrdinalType incy,
445 const OrdinalType n_y_dot,
447 const OrdinalType incy_dot,
449 const OrdinalType lda,
450 const OrdinalType n_A_dot,
452 const OrdinalType lda_dot,
453 const OrdinalType n_dot)
const;
456 template <
typename alpha_type,
typename A_type,
typename B_type,
458 void Fad_GEMM(Teuchos::ETransp transa,
459 Teuchos::ETransp transb,
463 const alpha_type& alpha,
464 const OrdinalType n_alpha_dot,
465 const alpha_type* alpha_dot,
467 const OrdinalType lda,
468 const OrdinalType n_A_dot,
470 const OrdinalType lda_dot,
472 const OrdinalType ldb,
473 const OrdinalType n_B_dot,
475 const OrdinalType ldb_dot,
476 const beta_type& beta,
477 const OrdinalType n_beta_dot,
478 const beta_type* beta_dot,
480 const OrdinalType ldc,
481 const OrdinalType n_C_dot,
483 const OrdinalType ldc_dot,
484 const OrdinalType n_dot)
const;
487 template <
typename alpha_type,
typename A_type,
typename B_type,
493 const alpha_type& alpha,
494 const OrdinalType n_alpha_dot,
495 const alpha_type* alpha_dot,
497 const OrdinalType lda,
498 const OrdinalType n_A_dot,
500 const OrdinalType lda_dot,
502 const OrdinalType ldb,
503 const OrdinalType n_B_dot,
505 const OrdinalType ldb_dot,
506 const beta_type& beta,
507 const OrdinalType n_beta_dot,
508 const beta_type* beta_dot,
510 const OrdinalType ldc,
511 const OrdinalType n_C_dot,
513 const OrdinalType ldc_dot,
514 const OrdinalType n_dot)
const;
517 template <
typename alpha_type,
typename A_type>
520 Teuchos::ETransp transa,
524 const alpha_type& alpha,
525 const OrdinalType n_alpha_dot,
526 const alpha_type* alpha_dot,
528 const OrdinalType lda,
529 const OrdinalType n_A_dot,
531 const OrdinalType lda_dot,
533 const OrdinalType ldb,
534 const OrdinalType n_B_dot,
536 const OrdinalType ldb_dot,
537 const OrdinalType n_dot)
const;
540 template <
typename alpha_type,
typename A_type>
543 Teuchos::ETransp transa,
547 const alpha_type& alpha,
548 const OrdinalType n_alpha_dot,
549 const alpha_type* alpha_dot,
551 const OrdinalType lda,
552 const OrdinalType n_A_dot,
554 const OrdinalType lda_dot,
556 const OrdinalType ldb,
557 const OrdinalType n_B_dot,
559 const OrdinalType ldb_dot,
560 const OrdinalType n_dot)
const;
573 #define TEUCHOS_BLAS_FAD_SPEC(FADTYPE) \ 574 namespace Teuchos { \ 575 template <typename OrdinalType, typename ValueT> \ 576 class BLAS< OrdinalType, FADTYPE<ValueT> > : \ 577 public Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT> > { \ 579 BLAS(bool use_default_impl = true, bool use_dynamic = true, \ 580 OrdinalType static_workspace_size = 0) : \ 581 Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT> >( \ 582 use_default_impl, use_dynamic,static_workspace_size) {} \ 583 BLAS(const BLAS& x) : \ 584 Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT> >(x) {} \ 590 template <typename ValueT> \ 591 struct ArrayValueType< FADTYPE<ValueT> > { \ 592 typedef ValueT type; \ 596 #define TEUCHOS_BLAS_SFAD_SPEC(FADTYPE) \ 597 namespace Teuchos { \ 598 template <typename OrdinalType, typename ValueT, int Num> \ 599 class BLAS< OrdinalType, FADTYPE<ValueT,Num> > : \ 600 public Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT,Num> > { \ 602 BLAS(bool use_default_impl = true, bool use_dynamic = true, \ 603 OrdinalType static_workspace_size = 0) : \ 604 Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT,Num> >( \ 605 use_default_impl, use_dynamic, static_workspace_size) {} \ 606 BLAS(const BLAS& x) : \ 607 Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT,Num> >(x) {} \ 613 template <typename ValueT, int Num> \ 614 struct ArrayValueType< FADTYPE<ValueT,Num> > { \ 615 typedef ValueT type; \ 629 #undef TEUCHOS_BLAS_FAD_SPEC 630 #undef TEUCHOS_BLAS_SFAD_SPEC 634 #endif // SACADO_FAD_BLAS_HPP void free(const ScalarType *A, OrdinalType m, OrdinalType n, OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot, const ScalarType *val, const ScalarType *dot) const
ValueType * allocate_array(OrdinalType size) const
void GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy, FadType *A, const OrdinalType lda) const
Performs the rank 1 operation: A <- alpha*x*y'+A.
Teuchos::DefaultBLASImpl< OrdinalType, FadType > BLASType
void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Performs the matrix-matrix operation: C <- alpha*op(A)*B+beta*C or C <- alpha*B*op(A)+beta*C where op...
Sacado::dummy< ValueType, scalar_type >::type ScalarType
bool use_default_impl
Use custom or default implementation.
void Fad_TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRMM.
void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Solves the matrix equations: op(A)*X=alpha*B or X*op(A)=alpha*B where X and B are m by n matrices...
std::vector< ValueType > gemv_Ax
Temporary array for GEMV.
FadType DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy) const
Form the dot product of the vectors x and y.
virtual ~BLAS()
Destructor.
Fad specializations for Teuchos::BLAS wrappers.
MagnitudeType NRM2(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Compute the 2-norm of the std::vector x.
void free(const ScalarType &a, OrdinalType n_dot, const ScalarType *dot) const
void ROTG(FadType *da, FadType *db, MagnitudeType *c, FadType *s) const
Computes a Givens plane rotation.
bool use_dynamic
Use dynamic memory allocation.
#define TEUCHOS_BLAS_SFAD_SPEC(FADTYPE)
ArrayTraits< OrdinalType, FadType > arrayTraits
ArrayTraits for packing/unpacking value/derivative arrays.
Sacado::ScalarType< FadType >::type scalar_type
#define TEUCHOS_BLAS_FAD_SPEC(FADTYPE)
void COPY(const OrdinalType n, const FadType *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Copy the std::vector x to the std::vector y.
void ROT(const OrdinalType n, FadType *dx, const OrdinalType incx, FadType *dy, const OrdinalType incy, MagnitudeType *c, FadType *s) const
Applies a Givens plane rotation.
ValueType * workspace_pointer
Pointer to current free entry in workspace.
void GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const x_type *x, const OrdinalType incx, const beta_type &beta, FadType *y, const OrdinalType incy) const
Performs the matrix-std::vector operation: y <- alpha*A*x+beta*y or y <- alpha*A'*x+beta*y where A i...
void Fad_GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of GEMM.
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
OrdinalType workspace_size
Size of static workspace.
void Fad_GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *y, const OrdinalType incy, const OrdinalType n_y_dot, ValueType *y_dot, const OrdinalType incy_dot, const OrdinalType n_dot) const
Implementation of GEMV.
void Fad_DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType &z, const OrdinalType n_z_dot, ValueType *zdot) const
Implementation of DOT.
void free(const ScalarType *a, OrdinalType n, OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot, const ScalarType *val, const ScalarType *dot) const
std::vector< ValueType > gemm_AB
Temporary array for GEMM.
Teuchos::BLAS< OrdinalType, ValueType > blas
BLAS for values.
Sacado::dummy< ValueType, scalar_type >::type ScalarType
void free_array(const ValueType *ptr, OrdinalType size) const
void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*A*B+beta*C or C <- alpha*B*A+beta*C where A is an m ...
void free(const ValueType &a, OrdinalType n_dot, const ValueType *dot) const
void Fad_SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of SYMM.
ArrayTraits(bool use_dynamic=true, OrdinalType workspace_size=0)
void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, Teuchos::EDiag diag, const OrdinalType n, const A_type *A, const OrdinalType lda, FadType *x, const OrdinalType incx) const
Performs the matrix-std::vector operation: x <- A*x or x <- A'*x where A is a unit/non-unit n by n u...
Forward-mode AD class using dynamic memory allocation and expression templates.
void Fad_TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRMM.
Teuchos::ScalarTraits< FadType >::magnitudeType MagnitudeType
void Fad_GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType *A, const OrdinalType lda, const OrdinalType n_A_dot, ValueType *A_dot, const OrdinalType lda_dot, const OrdinalType n_dot) const
Implementation of GER.
bool is_array_contiguous(const FadType *a, OrdinalType n, OrdinalType n_dot) const
void free(const ValueType *A, OrdinalType m, OrdinalType n, OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot, const ValueType *val, const ValueType *dot) const
ValueType * workspace
Workspace for holding contiguous values/derivatives.
void SCAL(const OrdinalType n, const FadType &alpha, FadType *x, const OrdinalType incx) const
Scale the std::vector x by the constant alpha.
Teuchos::ScalarTraits< FadType >::magnitudeType ASUM(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Sum the absolute values of the entries of x.
Sacado::ScalarType< FadType >::type scalar_type
Forward-mode AD class using dynamic memory allocation and expression templates.
void free(const ValueType *a, OrdinalType n, OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot, const ValueType *val, const ValueType *dot) const
Sacado::ValueType< FadType >::type ValueType
OrdinalType IAMAX(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Return the index of the element of x with the maximum magnitude.
void AXPY(const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Perform the operation: y <- y+alpha*x.
void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*op(A)*op(B)+beta*C where op(A) is either A or A'...
Base template specification for ValueType.
Sacado::ValueType< FadType >::type ValueType
BLAS(bool use_default_impl=true, bool use_dynamic=true, OrdinalType static_workspace_size=0)
Default constructor.