// This example implements u8 batch normalization using the following binary operations: binary_sub(src, mean), binary_div(tmp_dst, variance), binary_mul(tmp_dst, scale), binary_add(tmp_dst, shift).
#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>
#include <vector>
#include "dnnl.hpp"
#include "example_utils.hpp"
IC = 3,
IH = 150,
IW = 150;
// Host-side staging buffers: one for the source tensor and one for each
// normalization parameter. product() is defined outside this view; it is
// expected to return the element count of the given dims.
std::vector<float> src_data(product(src_dims));
std::vector<float> mean_data(product(params_dims));
std::vector<float> variance_data(product(params_dims));
std::vector<float> scale_data(product(params_dims));
std::vector<float> shift_data(product(params_dims));
std::vector<float> oscale_data(product(params_dims));

// Fill the buffers with deterministic synthetic values. Each generator owns
// a function-local static counter that advances once per generated element.
std::generate(src_data.begin(), src_data.end(), []() {
    static int i = 0;
    return std::cos(i++ / 10.f);
});
std::generate(mean_data.begin(), mean_data.end(), []() {
    static int i = 0;
    return std::sin(i++ * 2.f);
});
std::generate(variance_data.begin(), variance_data.end(), []() {
    static int i = 0;
    // The variance is later used as a divisor, so it must be strictly
    // positive. A raw sine is 0 for the first element and negative for
    // others: take the magnitude and replace an exact zero with 1.
    const float value = std::abs(std::sin(i++ * 4.f));
    return value == 0.f ? 1.f : value;
});
std::generate(scale_data.begin(), scale_data.end(), []() {
    static int i = 0;
    return std::sin(i++ * 6.f);
});
std::generate(shift_data.begin(), shift_data.end(), []() {
    static int i = 0;
    return std::sin(i++ * 8.f);
});
// Constant output scale; a float literal avoids a double-to-float conversion
// on every element.
std::fill(oscale_data.begin(), oscale_data.end(), 0.5f);
auto mean_md =
memory::desc(params_dims, dt::f32, tag::nhwc);
auto variance_md =
memory::desc(params_dims, dt::f32, tag::nhwc);
auto scale_md =
memory::desc(params_dims, dt::f32, tag::nhwc);
auto shift_md =
memory::desc(params_dims, dt::f32, tag::nhwc);
auto oscale_md =
memory::desc(params_dims, dt::f32, tag::nhwc);
// Hand each host buffer to write_to_dnnl_memory together with the memory
// object of the same name. The helper is defined outside this view
// (presumably in example_utils.hpp) and is expected to copy the host data
// into the memory object — confirm against its definition.
write_to_dnnl_memory(src_data.data(), src_mem);
write_to_dnnl_memory(mean_data.data(), mean_mem);
write_to_dnnl_memory(variance_data.data(), variance_mem);
write_to_dnnl_memory(scale_data.data(), scale_mem);
write_to_dnnl_memory(shift_data.data(), shift_mem);
write_to_dnnl_memory(oscale_data.data(), oscale_mem);
auto binary_prim =
binary(binary_pd);
std::unordered_map<int, memory> binary_args;
binary_args.insert(
binary_args.insert(
binary_args.insert(
binary_args.insert(
binary_prim.execute(engine_stream, binary_args);
read_from_dnnl_memory(src_data.data(), src_mem);
}
int main(int argc, char **argv) {
return handle_example_errors(
bnorm_u8_via_binary_postops, parse_engine_kind(argc, argv));
}