mlpack 3.4.2
naive_convolution.hpp
Go to the documentation of this file.
13#ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_NAIVE_CONVOLUTION_HPP
14#define MLPACK_METHODS_ANN_CONVOLUTION_RULES_NAIVE_CONVOLUTION_HPP
15
16#include <mlpack/prereqs.hpp>
17#include "border_modes.hpp"
18
19namespace mlpack {
20namespace ann {
21
34template<typename BorderMode = FullConvolution>
36{
37 public:
38 /*
39 * Perform a convolution (valid mode).
40 *
41 * @param input Input used to perform the convolution.
42 * @param filter Filter used to perform the convolution.
43 * @param output Output data that contains the results of the convolution.
44 * @param dW Stride of filter application in the x direction.
45 * @param dH Stride of filter application in the y direction.
46 * @param dilationW The dilation factor in x direction.
47 * @param dilationH The dilation factor in y direction.
48 */
49 template<typename eT, typename Border = BorderMode>
50 static typename std::enable_if<
51 std::is_same<Border, ValidConvolution>::value, void>::type
52 Convolution(const arma::Mat<eT>& input,
53 const arma::Mat<eT>& filter,
54 arma::Mat<eT>& output,
55 const size_t dW = 1,
56 const size_t dH = 1,
57 const size_t dilationW = 1,
58 const size_t dilationH = 1)
59 {
60 output = arma::zeros<arma::Mat<eT> >(
61 (input.n_rows - (filter.n_rows - 1) * dilationW - 1) / dW + 1,
62 (input.n_cols - (filter.n_cols - 1) * dilationH - 1) / dH + 1);
63
64 // It seems to be about 3.5 times faster to use pointers instead of
65 // filter(ki, kj) * input(leftInput + ki, topInput + kj) and output(i, j).
66 eT* outputPtr = output.memptr();
67
68 for (size_t j = 0; j < output.n_cols; ++j)
69 {
70 for (size_t i = 0; i < output.n_rows; ++i, outputPtr++)
71 {
72 const eT* kernelPtr = filter.memptr();
73 for (size_t kj = 0; kj < filter.n_cols; ++kj)
74 {
75 const eT* inputPtr = input.colptr(kj * dilationW + j * dW) + i * dH;
76 for (size_t ki = 0; ki < filter.n_rows; ++ki, ++kernelPtr,
77 inputPtr += dilationH)
78 *outputPtr += *kernelPtr * (*inputPtr);
79 }
80 }
81 }
82 }
83
84 /*
85 * Perform a convolution (full mode).
86 *
87 * @param input Input used to perform the convolution.
88 * @param filter Filter used to perform the convolution.
89 * @param output Output data that contains the results of the convolution.
90 * @param dW Stride of filter application in the x direction.
91 * @param dH Stride of filter application in the y direction.
92 * @param dilationW The dilation factor in x direction.
93 * @param dilationH The dilation factor in y direction.
94 */
95 template<typename eT, typename Border = BorderMode>
96 static typename std::enable_if<
97 std::is_same<Border, FullConvolution>::value, void>::type
98 Convolution(const arma::Mat<eT>& input,
99 const arma::Mat<eT>& filter,
100 arma::Mat<eT>& output,
101 const size_t dW = 1,
102 const size_t dH = 1,
103 const size_t dilationW = 1,
104 const size_t dilationH = 1)
105 {
106 size_t outputRows = (input.n_rows - 1) * dW + 2 * (filter.n_rows - 1)
107 * dilationW + 1;
108 size_t outputCols = (input.n_cols - 1) * dH + 2 * (filter.n_cols - 1)
109 * dilationH + 1;
110
111 for (size_t i = 0; i < dW; ++i)
112 {
113 if (((((i + outputRows - 2 * (filter.n_rows - 1) * dilationW - 1) % dW)
114 + dW) % dW) == i){
115 outputRows += i;
116 break;
117 }
118 }
119 for (size_t i = 0; i < dH; ++i)
120 {
121 if (((((i + outputCols - 2 * (filter.n_cols - 1) * dilationH - 1) % dH)
122 + dH) % dH) == i){
123 outputCols += i;
124 break;
125 }
126 }
127
128 // Pad filter and input to the working output shape.
129 arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
130 outputCols);
131 inputPadded.submat((filter.n_rows - 1) * dilationW, (filter.n_cols - 1)
132 * dilationH, (filter.n_rows - 1) * dilationW + input.n_rows - 1,
133 (filter.n_cols - 1) * dilationH + input.n_cols - 1) = input;
134
136 output, 1, 1, dilationW, dilationH);
137 }
138
139 /*
140 * Perform a convolution using 3rd order tensors.
141 *
142 * @param input Input used to perform the convolution.
143 * @param filter Filter used to perform the convolution.
144 * @param output Output data that contains the results of the convolution.
145 * @param dW Stride of filter application in the x direction.
146 * @param dH Stride of filter application in the y direction.
147 * @param dilationW The dilation factor in x direction.
148 * @param dilationH The dilation factor in y direction.
149 */
150 template<typename eT>
151 static void Convolution(const arma::Cube<eT>& input,
152 const arma::Cube<eT>& filter,
153 arma::Cube<eT>& output,
154 const size_t dW = 1,
155 const size_t dH = 1,
156 const size_t dilationW = 1,
157 const size_t dilationH = 1)
158 {
159 arma::Mat<eT> convOutput;
160 NaiveConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
161 convOutput, dW, dH, dilationW, dilationH);
162
163 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
164 input.n_slices);
165 output.slice(0) = convOutput;
166
167 for (size_t i = 1; i < input.n_slices; ++i)
168 {
169 NaiveConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
170 output.slice(i), dW, dH, dilationW, dilationH);
171 }
172 }
173
174 /*
175 * Perform a convolution using dense matrix as input and a 3rd order tensors
176 * as filter and output.
177 *
178 * @param input Input used to perform the convolution.
179 * @param filter Filter used to perform the convolution.
180 * @param output Output data that contains the results of the convolution.
181 * @param dW Stride of filter application in the x direction.
182 * @param dH Stride of filter application in the y direction.
183 * @param dilationW The dilation factor in x direction.
184 * @param dilationH The dilation factor in y direction.
185 */
186 template<typename eT>
187 static void Convolution(const arma::Mat<eT>& input,
188 const arma::Cube<eT>& filter,
189 arma::Cube<eT>& output,
190 const size_t dW = 1,
191 const size_t dH = 1,
192 const size_t dilationW = 1,
193 const size_t dilationH = 1)
194 {
195 arma::Mat<eT> convOutput;
196 NaiveConvolution<BorderMode>::Convolution(input, filter.slice(0),
197 convOutput, dW, dH, dilationW, dilationH);
198
199 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
200 filter.n_slices);
201 output.slice(0) = convOutput;
202
203 for (size_t i = 1; i < filter.n_slices; ++i)
204 {
205 NaiveConvolution<BorderMode>::Convolution(input, filter.slice(i),
206 output.slice(i), dW, dH, dilationW, dilationH);
207 }
208 }
209
210 /*
211 * Perform a convolution using a 3rd order tensors as input and output and a
212 * dense matrix as filter.
213 *
214 * @param input Input used to perform the convolution.
215 * @param filter Filter used to perform the convolution.
216 * @param output Output data that contains the results of the convolution.
217 * @param dW Stride of filter application in the x direction.
218 * @param dH Stride of filter application in the y direction.
219 * @param dilationW The dilation factor in x direction.
220 * @param dilationH The dilation factor in y direction.
221 */
222 template<typename eT>
223 static void Convolution(const arma::Cube<eT>& input,
224 const arma::Mat<eT>& filter,
225 arma::Cube<eT>& output,
226 const size_t dW = 1,
227 const size_t dH = 1,
228 const size_t dilationW = 1,
229 const size_t dilationH = 1)
230 {
231 arma::Mat<eT> convOutput;
232 NaiveConvolution<BorderMode>::Convolution(input.slice(0), filter,
233 convOutput, dW, dH, dilationW, dilationH);
234
235 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
236 input.n_slices);
237 output.slice(0) = convOutput;
238
239 for (size_t i = 1; i < input.n_slices; ++i)
240 {
241 NaiveConvolution<BorderMode>::Convolution(input.slice(i), filter,
242 output.slice(i), dW, dH, dilationW, dilationH);
243 }
244 }
245}; // class NaiveConvolution
246
247} // namespace ann
248} // namespace mlpack
249
250#endif
NaiveConvolution: computes the two-dimensional convolution by evaluating the convolution sum directly (no FFT or other acceleration).
static std::enable_if< std::is_same< Border, FullConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects: standard C++ headers and Armadillo.