Descriptor for a layer normalization forward propagation primitive.
#include <dnnl.hpp>
Public Member Functions
desc (prop_kind prop_kind, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive with an explicit statistics memory descriptor.
desc (prop_kind prop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive without a statistics memory descriptor argument.
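desc (prop_kind prop_kind, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags) [inline]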
Constructs a descriptor for a layer normalization forward propagation primitive.
Inputs:
src (dnnl::primitive_desc_base::src_desc(0))
mean (dnnl::primitive_desc_base::src_desc(1)), if the dnnl::normalization_flags::use_global_stats bit-flag is set in flags
variance (dnnl::primitive_desc_base::src_desc(2)), if the dnnl::normalization_flags::use_global_stats bit-flag is set in flags
scale_and_shift (dnnl::primitive_desc_base::weights_desc(0)), if the dnnl::normalization_flags::use_scale_shift bit-flag is set in flags
Outputs:
dst (dnnl::primitive_desc_base::dst_desc(0))
mean (dnnl::primitive_desc_base::dst_desc(1)), if the dnnl::normalization_flags::use_global_stats bit-flag is not set in flags and prop_kind = dnnl::prop_kind::forward_training
variance (dnnl::primitive_desc_base::dst_desc(2)), if the dnnl::normalization_flags::use_global_stats bit-flag is not set in flags and prop_kind = dnnl::prop_kind::forward_training
Parameters:
prop_kind | Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference. |
data_desc | Source and destination memory descriptor. |
stat_desc | Statistics memory descriptor (for mean and variance). |
epsilon | Layer normalization epsilon parameter: a small constant added to the variance for numerical stability. |
flags | Layer normalization flags (dnnl::normalization_flags). |
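To make the flag-dependent inputs and outputs listed above concrete, here is a minimal plain C++ sketch of the computation a layer normalization forward pass performs: each row is normalized over its last dimension as (x - mean) / sqrt(variance + epsilon), optionally followed by a per-channel scale and shift. It does not call oneDNN. The function `layer_norm_forward_ref`, the struct `LayerNormFlags`, and the row-major rows x C layout are hypothetical choices for illustration, and the scale_and_shift buffer is assumed to hold the C scales followed by the C shifts.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical flag set mirroring the two dnnl::normalization_flags bits
// discussed above; this is NOT the oneDNN type.
struct LayerNormFlags {
    bool use_global_stats = false;  // mean/variance are read, not computed
    bool use_scale_shift = false;   // apply per-channel scale and shift
};

// Reference layer normalization forward pass over a row-major rows x C
// buffer. Hypothetical helper for illustration, not a oneDNN function.
// scale_shift is assumed to hold C scales followed by C shifts and is
// read only if flags.use_scale_shift is set.
void layer_norm_forward_ref(const std::vector<float> &src,
                            std::vector<float> &dst,
                            std::vector<float> &mean,
                            std::vector<float> &variance,
                            const std::vector<float> &scale_shift,
                            std::size_t rows, std::size_t C, float epsilon,
                            LayerNormFlags flags) {
    assert(C > 0 && src.size() == rows * C);
    assert(!flags.use_scale_shift || scale_shift.size() == 2 * C);
    if (flags.use_global_stats) {
        // Statistics are inputs: they must already hold one value per row.
        assert(mean.size() == rows && variance.size() == rows);
    } else {
        // Statistics are computed here and act as outputs.
        mean.assign(rows, 0.f);
        variance.assign(rows, 0.f);
    }
    dst.assign(rows * C, 0.f);

    for (std::size_t r = 0; r < rows; ++r) {
        const float *x = src.data() + r * C;
        if (!flags.use_global_stats) {
            float m = 0.f;
            for (std::size_t c = 0; c < C; ++c) m += x[c];
            m /= static_cast<float>(C);
            float v = 0.f;  // biased variance: divide by C, not C - 1
            for (std::size_t c = 0; c < C; ++c) v += (x[c] - m) * (x[c] - m);
            v /= static_cast<float>(C);
            mean[r] = m;
            variance[r] = v;
        }
        const float inv_std = 1.f / std::sqrt(variance[r] + epsilon);
        for (std::size_t c = 0; c < C; ++c) {
            float y = (x[c] - mean[r]) * inv_std;
            if (flags.use_scale_shift)
                y = scale_shift[c] * y + scale_shift[C + c];  // scale, then shift
            dst[r * C + c] = y;
        }
    }
}
```

When use_global_stats is set, the caller supplies mean and variance and the sketch only reads them, matching the mean/variance inputs above; otherwise it writes them, matching the mean/variance outputs produced for forward_training.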
desc (prop_kind prop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags) [inline]
Constructs a descriptor for a layer normalization forward propagation primitive.
Inputs:
src (dnnl::primitive_desc_base::src_desc(0))
mean (dnnl::primitive_desc_base::src_desc(1)), if the dnnl::normalization_flags::use_global_stats bit-flag is set in flags
variance (dnnl::primitive_desc_base::src_desc(2)), if the dnnl::normalization_flags::use_global_stats bit-flag is set in flags
scale_and_shift (dnnl::primitive_desc_base::weights_desc(0)), if the dnnl::normalization_flags::use_scale_shift bit-flag is set in flags
Outputs:
dst (dnnl::primitive_desc_base::dst_desc(0))
mean (dnnl::primitive_desc_base::dst_desc(1)), if the dnnl::normalization_flags::use_global_stats bit-flag is not set in flags and prop_kind = dnnl::prop_kind::forward_training
variance (dnnl::primitive_desc_base::dst_desc(2)), if the dnnl::normalization_flags::use_global_stats bit-flag is not set in flags and prop_kind = dnnl::prop_kind::forward_training
Parameters:
prop_kind | Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference. |
data_desc | Source and destination memory descriptor. |
epsilon | Layer normalization epsilon parameter: a small constant added to the variance for numerical stability. |
flags | Layer normalization flags (dnnl::normalization_flags). |
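This overload takes no stat_desc, so the shape of the mean and variance tensors has to follow from data_desc alone. The sketch below assumes the statistics shape is the data shape with its last (normalized) dimension dropped, giving one mean and one variance value per normalized row; this section does not spell out how the library derives the statistics descriptor, so treat that rule as an assumption. `stat_dims_from_data_dims` and `stat_element_count` are hypothetical helpers, and `std::vector<std::int64_t>` stands in for `memory::dims`.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper, not part of oneDNN. ASSUMPTION: the statistics tensor
// has one mean/variance value per position of every dimension except the
// last (normalized) one.
std::vector<std::int64_t> stat_dims_from_data_dims(
        const std::vector<std::int64_t> &data_dims) {
    // This sketch requires at least two dims so that one remains after
    // dropping the normalized dimension.
    if (data_dims.size() < 2)
        throw std::invalid_argument("expected at least 2 data dimensions");
    return std::vector<std::int64_t>(data_dims.begin(), data_dims.end() - 1);
}

// Number of mean (or variance) values implied by the statistics dims.
std::int64_t stat_element_count(const std::vector<std::int64_t> &stat_dims) {
    std::int64_t n = 1;
    for (std::int64_t d : stat_dims) n *= d;
    return n;
}
```

Under this assumption, a {10, 32, 512} source yields statistics dims {10, 32}, i.e. 320 mean values and 320 variance values.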