Sierra Toolkit  Version of the Day
AlgorithmRunnerTBB.cpp
1 /*------------------------------------------------------------------------*/
2 /* Copyright 2010 Sandia Corporation. */
3 /* Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive */
4 /* license for use of this work by or on behalf of the U.S. Government. */
5 /* Export of this program may require a license from the */
6 /* United States Government. */
7 /*------------------------------------------------------------------------*/
8 
9 
10 #include <stk_algsup/AlgorithmRunner.hpp>
11 
12 #ifdef STK_HAVE_TBB
13 
#include <cstdlib>
#include <new>

#include <tbb/task_scheduler_init.h>
#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>
#include <tbb/parallel_reduce.h>
#include <tbb/scalable_allocator.h>
#include <tbb/partitioner.h>
20 
21 #include <stk_mesh/base/BulkData.hpp>
22 #include <stk_mesh/base/Bucket.hpp>
23 
24 namespace stk_classic {
25 namespace {
26 
27 //----------------------------------------------------------------------
28 
29 struct RunTBB {
30  const mesh::Selector & selector ;
31  const mesh::PartVector & union_parts ;
32  const std::vector<mesh::Bucket*> & buckets ;
33  const AlgorithmInterface & alg ;
34 
35  void operator()(const tbb::blocked_range<int>& r) const;
36 
37  RunTBB( const mesh::Selector & arg_selector ,
38  const mesh::PartVector & arg_union_parts ,
39  const std::vector<mesh::Bucket*> & arg_buckets ,
40  const AlgorithmInterface & arg_alg );
41 
42  ~RunTBB();
43 };
44 
45 RunTBB::RunTBB(
46  const mesh::Selector & arg_selector ,
47  const mesh::PartVector & arg_union_parts ,
48  const std::vector<mesh::Bucket*> & arg_buckets ,
49  const AlgorithmInterface & arg_alg )
50  : selector( arg_selector ),
51  union_parts( arg_union_parts ),
52  buckets( arg_buckets ),
53  alg( arg_alg )
54 {}
55 
56 RunTBB::~RunTBB()
57 {
58 }
59 
60 void RunTBB::operator()( const tbb::blocked_range<int> & r ) const
61 {
62  for ( int i = r.begin() ; i < r.end() ; ++i ) {
63  alg.apply_one( selector , union_parts , * buckets[i] , NULL );
64  }
65 }
66 
67 struct RunTBBreduce {
68  const mesh::Selector & selector ;
69  const mesh::PartVector & union_parts ;
70  const std::vector<mesh::Bucket*> & buckets ;
71  const AlgorithmInterface & alg ;
72  void * reduce ;
73 
74  void operator()(const tbb::blocked_range<int>& r);
75 
76  void join( const RunTBBreduce & rhs ) const ;
77 
78  RunTBBreduce( const RunTBBreduce & rhs , tbb::split );
79 
80  RunTBBreduce( const mesh::Selector & arg_selector ,
81  const mesh::PartVector & arg_union_parts ,
82  const std::vector<mesh::Bucket*> & arg_buckets ,
83  const AlgorithmInterface & arg_alg ,
84  void * arg_reduce = NULL );
85 
86  ~RunTBBreduce();
87 };
88 
89 RunTBBreduce::RunTBBreduce( const RunTBBreduce & rhs , tbb::split )
90  : selector( rhs.selector ),
91  union_parts( rhs.union_parts ),
92  buckets( rhs.buckets ),
93  alg( rhs.alg ),
94  reduce( NULL )
95 {
96  if ( rhs.reduce ) {
97  reduce = malloc( alg.m_reduce_allocation_size ); //scalable_malloc ?
98  alg.init( reduce );
99  }
100 }
101 
102 RunTBBreduce::~RunTBBreduce()
103 {
104  if ( reduce ) { free( reduce ); /* scalable_free ? */}
105 }
106 
107 void RunTBBreduce::join( const RunTBBreduce & rhs ) const
108 {
109  alg.join( reduce , rhs.reduce );
110 }
111 
112 void RunTBBreduce::operator()( const tbb::blocked_range<int> & r )
113 {
114  for ( int i = r.begin() ; i < r.end() ; ++i ) {
115  alg.apply_one( selector , union_parts, * buckets[i] , reduce );
116  }
117 }
118 
119 RunTBBreduce::RunTBBreduce(
120  const mesh::Selector & arg_selector ,
121  const mesh::PartVector & arg_union_parts ,
122  const std::vector<mesh::Bucket*> & arg_buckets ,
123  const AlgorithmInterface & arg_alg ,
124  void * arg_reduce )
125  : selector( arg_selector ),
126  union_parts( arg_union_parts ),
127  buckets( arg_buckets ),
128  alg( arg_alg ),
129  reduce( arg_reduce )
130 {}
131 
132 //----------------------------------------------------------------------
133 
134 class AlgorithmRunnerTBB : public AlgorithmRunnerInterface {
135 public:
136  AlgorithmRunnerTBB(int nthreads)
137  : tbb_task_init_(NULL)
138  {
139  tbb_task_init_ = new tbb::task_scheduler_init(nthreads);
140  }
141 
142  ~AlgorithmRunnerTBB()
143  {
144  delete tbb_task_init_;
145  }
146 
147  void run_alg( const mesh::Selector & selector ,
148  const mesh::PartVector & union_parts ,
149  const std::vector< mesh::Bucket * > & buckets ,
150  const AlgorithmInterface & alg ,
151  void * reduce ) const ;
152 
153 private:
154  tbb::task_scheduler_init* tbb_task_init_;
155 };
156 
157 void AlgorithmRunnerTBB::run_alg(
158  const mesh::Selector & selector ,
159  const mesh::PartVector & union_parts ,
160  const std::vector< mesh::Bucket * > & buckets ,
161  const AlgorithmInterface & alg ,
162  void * reduce ) const
163 {
164  static tbb::affinity_partitioner ap;
165 
166  if ( reduce && ! alg.m_reduce_allocation_size ) {
167  std::string msg("AlgorithmRunnerTBB: ERROR reduce value with zero size");
168  throw std::invalid_argument(msg);
169  }
170 
171  if ( ! buckets.empty() ) {
172 
173  tbb::blocked_range<int> range( 0 , buckets.size() );
174 
175  if ( reduce ) {
176  RunTBBreduce tmp( selector , union_parts , buckets , alg, reduce );
177 
178  tbb::parallel_reduce( range , tmp , ap );
179  tmp.reduce = NULL ; /* Prevent the tbb::scalable_free( tmp.reduce ); */
180  }
181  else {
182  RunTBB tmp( selector , union_parts , buckets , alg );
183 
184  tbb::parallel_for( range, tmp , ap);
185  }
186  }
187 }
188 
189 } // namespace
190 
191 AlgorithmRunnerInterface * algorithm_runner_tbb( int nthreads )
192 {
193  static AlgorithmRunnerTBB runner(nthreads) ;
194 
195  return & runner ;
196 }
197 
198 } // namespace stk_classic
199 
200 #else
201 
202 namespace stk_classic {
203 
204 AlgorithmRunnerInterface * algorithm_runner_tbb( int nthreads )
205 {
206  return NULL ;
207 }
208 
209 } // namespace stk_classic
210 
211 #endif
212 
AlgorithmRunnerInterface * algorithm_runner_tbb(int nthreads)
Sierra Toolkit.
std::vector< Part *> PartVector
Collections of parts are frequently maintained as a vector of Part pointers.
Definition: Types.hpp:31