mlpack 3.4.2
fft_convolution.hpp
Go to the documentation of this file.
1
13#ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP
14#define MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP
15
16#include <mlpack/prereqs.hpp>
17#include "border_modes.hpp"
18
19namespace mlpack {
20namespace ann {
21
36template<typename BorderMode = FullConvolution, const bool padLastDim = false>
38{
39 public:
40 /*
41 * Perform a convolution through fft (valid mode). This method only supports
42 * input which is even on the last dimension. In case of an odd input width, a
43 * user can manually pad the input or specify the padLastDim parameter which
44 * takes care of the padding. The filter instead can have any size. When using
45 * the valid mode the filter has to be smaller than the input.
46 *
47 * @param input Input used to perform the convolution.
48 * @param filter Filter used to perform the convolution.
49 * @param output Output data that contains the results of the convolution.
50 */
51 template<typename eT, typename Border = BorderMode>
52 static typename std::enable_if<
53 std::is_same<Border, ValidConvolution>::value, void>::type
54 Convolution(const arma::Mat<eT>& input,
55 const arma::Mat<eT>& filter,
56 arma::Mat<eT>& output)
57 {
58 arma::Mat<eT> inputPadded = input;
59 arma::Mat<eT> filterPadded = filter;
60
61 if (padLastDim)
62 inputPadded.resize(inputPadded.n_rows, inputPadded.n_cols + 1);
63
64 // Pad filter and input to the output shape.
65 filterPadded.resize(inputPadded.n_rows, inputPadded.n_cols);
66
67 arma::Mat<eT> temp = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
68 filterPadded)));
69
70 // Extract the region of interest. We don't need to handle the padLastDim in
71 // a special way we just cut it out from the output matrix.
72 output = temp.submat(filter.n_rows - 1, filter.n_cols - 1,
73 input.n_rows - 1, input.n_cols - 1);
74 }
75
76 /*
77 * Perform a convolution through fft (full mode). This method only supports
78 * input which is even on the last dimension. In case of an odd input width, a
79 * user can manually pad the input or specify the padLastDim parameter which
80 * takes care of the padding. The filter instead can have any size.
81 *
82 * @param input Input used to perform the convolution.
83 * @param filter Filter used to perform the convolution.
84 * @param output Output data that contains the results of the convolution.
85 */
86 template<typename eT, typename Border = BorderMode>
87 static typename std::enable_if<
88 std::is_same<Border, FullConvolution>::value, void>::type
89 Convolution(const arma::Mat<eT>& input,
90 const arma::Mat<eT>& filter,
91 arma::Mat<eT>& output)
92 {
93 // In case of the full convolution outputRows and outputCols doesn't
94 // represent the true output size when the padLastDim parameter is set,
95 // instead it's the working size.
96 const size_t outputRows = input.n_rows + 2 * (filter.n_rows - 1);
97 size_t outputCols = input.n_cols + 2 * (filter.n_cols - 1);
98
99 if (padLastDim)
100 outputCols++;
101
102 // Pad filter and input to the working output shape.
103 arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
104 outputCols);
105 inputPadded.submat(filter.n_rows - 1, filter.n_cols - 1,
106 filter.n_rows - 1 + input.n_rows - 1,
107 filter.n_cols - 1 + input.n_cols - 1) = input;
108
109 arma::Mat<eT> filterPadded = filter;
110 filterPadded.resize(outputRows, outputCols);
111
112 // Perform FFT and IFFT
113 arma::Mat<eT> temp = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
114 filterPadded)));
115
116 // Extract the region of interest. We don't need to handle the padLastDim
117 // parameter in a special way we just cut it out from the output matrix.
118 output = temp.submat(filter.n_rows - 1, filter.n_cols - 1,
119 2 * (filter.n_rows - 1) + input.n_rows - 1,
120 2 * (filter.n_cols - 1) + input.n_cols - 1);
121 }
122
123 /*
124 * Perform a convolution through fft using 3rd order tensors. This method only
125 * supports input which is even on the last dimension. In case of an odd input
126 * width, a user can manually pad the input or specify the padLastDim
127 * parameter which takes care of the padding. The filter instead can have any
128 * size.
129 *
130 * @param input Input used to perform the convolution.
131 * @param filter Filter used to perform the convolution.
132 * @param output Output data that contains the results of the convolution.
133 */
134 template<typename eT>
135 static void Convolution(const arma::Cube<eT>& input,
136 const arma::Cube<eT>& filter,
137 arma::Cube<eT>& output)
138 {
139 arma::Mat<eT> convOutput;
140 FFTConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
141 convOutput);
142
143 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
144 input.n_slices);
145 output.slice(0) = convOutput;
146
147 for (size_t i = 1; i < input.n_slices; ++i)
148 {
149 FFTConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
150 output.slice(i));
151 }
152 }
153
154 /*
155 * Perform a convolution through fft using dense matrix as input and a 3rd
156 * order tensors as filter and output. This method only supports input which
157 * is even on the last dimension. In case of an odd input width, a user can
158 * manually pad the input or specify the padLastDim parameter which takes care
159 * of the padding. The filter instead can have any size.
160 *
161 * @param input Input used to perform the convolution.
162 * @param filter Filter used to perform the convolution.
163 * @param output Output data that contains the results of the convolution.
164 */
165 template<typename eT>
166 static void Convolution(const arma::Mat<eT>& input,
167 const arma::Cube<eT>& filter,
168 arma::Cube<eT>& output)
169 {
170 arma::Mat<eT> convOutput;
171 FFTConvolution<BorderMode>::Convolution(input, filter.slice(0),
172 convOutput);
173
174 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
175 filter.n_slices);
176 output.slice(0) = convOutput;
177
178 for (size_t i = 1; i < filter.n_slices; ++i)
179 {
180 FFTConvolution<BorderMode>::Convolution(input, filter.slice(i),
181 output.slice(i));
182 }
183 }
184
185 /*
186 * Perform a convolution using a 3rd order tensors as input and output and a
187 * dense matrix as filter.
188 *
189 * @param input Input used to perform the convolution.
190 * @param filter Filter used to perform the convolution.
191 * @param output Output data that contains the results of the convolution.
192 */
193 template<typename eT>
194 static void Convolution(const arma::Cube<eT>& input,
195 const arma::Mat<eT>& filter,
196 arma::Cube<eT>& output)
197 {
198 arma::Mat<eT> convOutput;
199 FFTConvolution<BorderMode>::Convolution(input.slice(0), filter,
200 convOutput);
201
202 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
203 input.n_slices);
204 output.slice(0) = convOutput;
205
206 for (size_t i = 1; i < input.n_slices; ++i)
207 {
208 FFTConvolution<BorderMode>::Convolution(input.slice(i), filter,
209 output.slice(i));
210 }
211 }
212}; // class FFTConvolution
213
214} // namespace ann
215} // namespace mlpack
216
217#endif
Computes the two-dimensional convolution through fft.
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static std::enable_if< std::is_same< Border, FullConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output)
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.