ViennaCL - The Vienna Computing Library  1.7.0
Free open-source GPU-accelerated linear algebra and solver library.
viennacl/device_specific/templates/row_wise_reduction_template.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_DEVICE_SPECIFIC_TEMPLATES_ROW_WISE_REDUCTION_HPP
00002 #define VIENNACL_DEVICE_SPECIFIC_TEMPLATES_ROW_WISE_REDUCTION_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2015, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00027 #include <vector>
00028 
00029 #include "viennacl/scheduler/forwards.h"
00030 
00031 #include "viennacl/device_specific/mapped_objects.hpp"
00032 #include "viennacl/device_specific/tree_parsing.hpp"
00033 #include "viennacl/device_specific/utils.hpp"
00034 
00035 #include "viennacl/device_specific/templates/template_base.hpp"
00036 #include "viennacl/device_specific/templates/utils.hpp"
00037 
00038 #include "viennacl/tools/tools.hpp"
00039 
00040 #include "viennacl/scheduler/io.hpp"
00041 
00042 namespace viennacl
00043 {
00044 namespace device_specific
00045 {
00046 
00047 struct row_wise_reduction_parameters : public template_base::parameters_type
00048 {
00049   row_wise_reduction_parameters(unsigned int _simd_width,
00050                                 unsigned int _local_size_0, unsigned int _local_size_1,
00051                                 unsigned int _num_groups_0, fetching_policy_type _fetch_policy): template_base::parameters_type(_simd_width, _local_size_0, _local_size_1, 1),
00052     num_groups_0(_num_groups_0), fetch_policy(_fetch_policy) { }
00053 
00054   unsigned int num_groups_0;
00055   fetching_policy_type fetch_policy;
00056 };
00057 
00058 class row_wise_reduction_template : public template_base_impl<row_wise_reduction_template, row_wise_reduction_parameters>
00059 {
00060 private:
00061   virtual int check_invalid_impl(viennacl::ocl::device const & /*dev*/) const
00062   {
00063     if (p_.fetch_policy==FETCH_FROM_LOCAL)
00064       return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
00065     return TEMPLATE_VALID;
00066   }
00067 
00068   unsigned int n_lmem_elements() const
00069   {
00070     return p_.local_size_0*(p_.local_size_1+1);
00071   }
00072 
00073   static void parse(scheduler::statement const & statement, std::vector<vcl_size_t> & idx, bool & is_trans, scheduler::lhs_rhs_element & matrix)
00074   {
00075     tree_parsing::traverse(statement, statement.root(), tree_parsing::filter(&utils::is_reduction, idx), false);
00076     is_trans = is_node_trans(statement.array(), idx[0], LHS_NODE_TYPE);
00077     matrix = lhs_most(statement.array(), idx[0]).lhs;
00078   }
00079 
00080   std::string generate_impl(std::string const & kernel_prefix, statements_container const & statements, std::vector<mapping_type> const & mappings, unsigned int simd_width, bool is_trans, std::vector<mapped_row_wise_reduction*> const & exprs) const
00081   {
00082     using tools::to_string;
00083 
00084     unsigned int lsize0 = p_.local_size_0;
00085     unsigned int lsize1 = p_.local_size_1+1;
00086     std::string lsize1str = to_string(lsize1);
00087 
00088     utils::kernel_generation_stream stream;
00089 
00090     stream << " __attribute__((reqd_work_group_size(" << p_.local_size_0 << "," << p_.local_size_1 << ",1)))" << std::endl;
00091     generate_prototype(stream, kernel_prefix, "unsigned int M, unsigned int N,", mappings, statements);
00092     stream << "{" << std::endl;
00093     stream.inc_tab();
00094 
00095     tree_parsing::process(stream, PARENT_NODE_TYPE, "scalar", "#scalartype #namereg = *#pointer;", statements, mappings);
00096     tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#pointer += #start1 + #start2*#ld;", statements, mappings);
00097     tree_parsing::process(stream, PARENT_NODE_TYPE, "vector", "#pointer += #start;", statements, mappings);
00098 
00099     tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#ld *= #nldstride;", statements, mappings);
00100 
00101     for (std::vector<mapped_row_wise_reduction*>::const_iterator it = exprs.begin(); it != exprs.end(); ++it)
00102       stream << (*it)->process("__local #scalartype #name_buf[" + to_string(lsize0*lsize1) + "];") << std::endl;
00103 
00104     stream << "unsigned int lid0 = get_local_id(0);" << std::endl;
00105     stream << "unsigned int lid1 = get_local_id(1);" << std::endl;
00106     stream << "unsigned int upper_bound_0 = ( M +" << p_.local_size_0 - 1 << ")/" << p_.local_size_0 << "*" << p_.local_size_0 << ";" << std::endl;
00107     stream << "for(unsigned int r = get_global_id(0); r < upper_bound_0; r += get_global_size(0)){" << std::endl;
00108     stream.inc_tab();
00109 
00110     for (std::vector<mapped_row_wise_reduction*>::const_iterator it = exprs.begin(); it != exprs.end(); ++it)
00111       stream << (*it)->process("#scalartype #name_acc = " + neutral_element((*it)->root_op()) + ";") << std::endl;
00112 
00113     stream << "if (r < M)" << std::endl;
00114     stream << "{" << std::endl;
00115     stream.inc_tab();
00116 
00117     class loop_body : public loop_body_base
00118     {
00119     public:
00120       loop_body(std::vector<mapped_row_wise_reduction*> const & _exprs, bool _is_trans) : exprs(_exprs), is_trans(_is_trans){ }
00121       void operator()(utils::kernel_generation_stream & kernel_stream, unsigned int loop_simd_width) const
00122       {
00123         std::set<std::string> already_fetched;
00124         for (std::vector<mapped_row_wise_reduction*>::const_iterator it = exprs.begin(); it != exprs.end(); ++it)
00125         {
00126           if (is_trans)
00127             (*it)->process_recursive(kernel_stream, LHS_NODE_TYPE, "matrix_trans", utils::append_width("#scalartype",loop_simd_width) + " #namereg = " + vload(loop_simd_width, "c*#stride1", "#pointer + r*#ld")+";", already_fetched);
00128           else
00129             (*it)->process_recursive(kernel_stream, LHS_NODE_TYPE, "matrix", "#scalartype #namereg = #pointer[r*#stride1 + c*#ld];", already_fetched);
00130           (*it)->process_recursive(kernel_stream, RHS_NODE_TYPE, "vector", utils::append_width("#scalartype",loop_simd_width) + " #namereg = " + vload(loop_simd_width, "c*#stride", "#pointer")+";", already_fetched);
00131         }
00132 
00133 
00134         //Update accumulators
00135         std::vector<std::string> str(loop_simd_width);
00136         if (loop_simd_width==1)
00137           str[0] = "#namereg";
00138         else
00139           for (unsigned int a = 0; a < loop_simd_width; ++a)
00140             str[a] = append_simd_suffix("#namereg.s", a);
00141 
00142 
00143         for (unsigned int k = 0; k < exprs.size(); ++k)
00144         {
00145           for (unsigned int a = 0; a < loop_simd_width; ++a)
00146           {
00147             std::map<std::string, std::string> accessors;
00148             if (is_trans)
00149               accessors["matrix_trans"] = str[a];
00150             else
00151               accessors["matrix"] = str[a];
00152             accessors["vector"] = str[a];
00153             accessors["scalar"] = "#namereg";
00154             std::string value = exprs[k]->evaluate_recursive(LHS_NODE_TYPE, accessors);
00155             if (exprs[k]->root_node().op.type==scheduler::OPERATION_BINARY_MAT_VEC_PROD_TYPE)
00156               value+= "*" + exprs[k]->evaluate_recursive(RHS_NODE_TYPE, accessors);
00157 
00158             if (exprs[k]->is_index_reduction())
00159               compute_index_reduction(kernel_stream, exprs[k]->process("#name_acc"), "c*"+to_string(loop_simd_width) + to_string(a), exprs[k]->process("#name_acc_value"), value,exprs[k]->root_op());
00160             else
00161               compute_reduction(kernel_stream, exprs[k]->process("#name_acc"), value,exprs[k]->root_op());
00162           }
00163         }
00164       }
00165     private:
00166       std::vector<mapped_row_wise_reduction*> exprs;
00167       bool is_trans;
00168     };
00169 
00170     element_wise_loop_1D(stream, loop_body(exprs, is_trans), p_.fetch_policy, simd_width, "c", "N", "get_local_id(1)", "get_local_size(1)");
00171     stream.dec_tab();
00172     stream << "}" << std::endl;
00173 
00174     for (unsigned int k = 0; k < exprs.size(); ++k)
00175       stream << exprs[k]->process("#name_buf[lid0*" + lsize1str + "+ lid1] = #name_acc;") << std::endl;
00176 
00177     stream << "#pragma unroll" << std::endl;
00178     stream << "for(unsigned int stride = " << p_.local_size_1/2 << "; stride >0; stride /=2)" << std::endl;
00179     stream << "{" << std::endl;
00180     stream.inc_tab();
00181 
00182     stream << "barrier(CLK_LOCAL_MEM_FENCE); " << std::endl;
00183     stream <<  "if (lid1 < stride)" << std::endl;
00184     stream << "{" << std::endl;
00185     stream.inc_tab();
00186 
00187     for (unsigned int k = 0; k < exprs.size(); k++)
00188       if (exprs[k]->is_index_reduction())
00189         compute_index_reduction(stream, exprs[k]->process("#name_buf[lid0*" + lsize1str + " + lid1]"), exprs[k]->process("#name_buf[lid0*" + lsize1str + " + lid1 + stride]")
00190                                 , exprs[k]->process("#name_buf_value[lid0*" + lsize1str + " + lid1]"), exprs[k]->process("#name_buf_value[lid0*" + lsize1str + " + lid1 + stride]"),
00191                                 exprs[k]->root_op());
00192       else
00193         compute_reduction(stream,exprs[k]->process("#name_buf[lid0*" + lsize1str + " + lid1]"), exprs[k]->process("#name_buf[lid0*" + lsize1str + " + lid1 + stride]"), exprs[k]->root_op());
00194 
00195     stream.dec_tab();
00196     stream << "}" << std::endl;
00197 
00198     stream.dec_tab();
00199     stream << "}" << std::endl;
00200 
00201 
00202     stream <<  "if (lid1 == 0 && r < M)";
00203     stream << "{" << std::endl;
00204     stream.inc_tab();
00205     std::map<std::string, std::string> accessors;
00206     accessors["row_wise_reduction"] = "#name_buf[lid0*" + lsize1str + "]";
00207     accessors["vector"] = "#pointer[r*#stride]";
00208     tree_parsing::evaluate(stream, PARENT_NODE_TYPE, accessors, statements, mappings);
00209     stream.dec_tab();
00210     stream << "}" << std::endl;
00211 
00212 
00213     stream.dec_tab();
00214     stream << "}" << std::endl;
00215 
00216     stream.dec_tab();
00217     stream << "}" << std::endl;
00218 
00219     return stream.str();
00220   }
00221 
00222   std::vector<std::string> generate_impl(std::string const & kernel_prefix, statements_container const & statements, std::vector<mapping_type> const & mappings) const
00223   {
00224     std::vector<mapped_row_wise_reduction*> exprs;
00225     bool is_trans = false;
00226     bool row_major = false;
00227     statements_container::data_type::const_iterator sit;
00228     std::vector<mapping_type>::const_iterator mit;
00229     for (mit = mappings.begin(), sit = statements.data().begin(); mit != mappings.end(); ++mit, ++sit)
00230     {
00231       std::vector<vcl_size_t> idx;
00232       scheduler::lhs_rhs_element A;
00233       parse(*sit, idx, is_trans, A);
00234       row_major = utils::call_on_matrix(A, utils::row_major_fun());
00235       for (unsigned int j = 0; j < idx.size(); ++j)
00236         exprs.push_back((mapped_row_wise_reduction*)(at(*mit, mapping_key(idx[j], PARENT_NODE_TYPE)).get()));
00237     }
00238     is_trans = is_trans ^ row_major;
00239 
00240     std::vector<std::string> res;
00241     if (is_trans && p_.simd_width>1)
00242     {
00243       res.push_back(generate_impl(kernel_prefix, statements, mappings, p_.simd_width, is_trans, exprs));
00244       res.push_back(generate_impl(kernel_prefix, statements, mappings, 1, is_trans, exprs));
00245     }
00246     else
00247       res.push_back(generate_impl(kernel_prefix, statements, mappings, 1, is_trans, exprs));
00248 
00249     return res;
00250   }
00251 public:
00252   row_wise_reduction_template(row_wise_reduction_template::parameters_type const & parameters, char A_trans, binding_policy_t binding_policy = BIND_ALL_UNIQUE) : template_base_impl<row_wise_reduction_template, row_wise_reduction_parameters>(parameters, binding_policy), A_trans_(A_trans){ }
00253 
00254   void enqueue(std::string const & kernel_prefix, std::vector<lazy_program_compiler> & programs, statements_container const & statements)
00255   {
00256     std::vector<vcl_size_t> idx;
00257     scheduler::lhs_rhs_element A;
00258     bool is_trans;
00259     parse(statements.data().front(), idx, is_trans, A);
00260     bool row_major = utils::call_on_matrix(A, utils::row_major_fun());
00261 
00262     viennacl::ocl::kernel * kernel;
00263     if ((is_trans  ^ row_major)&& p_.simd_width>1)
00264     {
00265       if (has_strided_access(statements))
00266         kernel = &programs[1].program().get_kernel(kernel_prefix);
00267       else
00268         kernel = &programs[0].program().get_kernel(kernel_prefix);
00269     }
00270     else
00271       kernel = &programs[0].program().get_kernel(kernel_prefix);
00272 
00273     kernel->local_work_size(0,p_.local_size_0);
00274     kernel->local_work_size(1,p_.local_size_1);
00275     kernel->global_work_size(0,p_.local_size_0*p_.num_groups_0);
00276     kernel->global_work_size(1,p_.local_size_1);
00277 
00278     unsigned int current_arg = 0;
00279     if (is_trans)
00280     {
00281       kernel->arg(current_arg++, cl_uint(utils::call_on_matrix(A, utils::size2_fun())));
00282       kernel->arg(current_arg++, cl_uint(utils::call_on_matrix(A, utils::size1_fun())));
00283     }
00284     else
00285     {
00286       kernel->arg(current_arg++, cl_uint(utils::call_on_matrix(A, utils::size1_fun())));
00287       kernel->arg(current_arg++, cl_uint(utils::call_on_matrix(A, utils::size2_fun())));
00288     }
00289 
00290 
00291     set_arguments(statements, *kernel, current_arg);
00292     viennacl::ocl::enqueue(*kernel);
00293   }
00294 
00295 private:
00296   const char A_trans_;
00297 };
00298 
00299 }
00300 }
00301 
00302 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines