天天看點

MKL-DNN學習筆記 (七) Post-ops操作

MKL-DNN優化技術裡,有一個很重要的技術就是層融合(Layer Fusion)

所謂的Layer fusion, 就是把好幾層的計算合并成一層的操作裡,例如下圖左邊的計算一共包含了3層Convolution+Sum+ReLU, 每層之間都包含了輸入資料和輸出資料的讀寫。通過讀取觀察每層輸出的資料,我們也可以知道神經網絡每層到底做了些什麼,但是實際應用中我們隻關心神經網絡最開始的輸入資料和最終推理輸出的資料,這時候就可以把幾層的計算混合在一起計算,這樣就節省了大量的Memory I/O操作。

MKL-DNN學習筆記 (七) Post-ops操作

在MKL-DNN開發裡,我們可以通過設定每一層計算對象的Post-ops屬性來告訴MKL-DNN 這層計算完成後接下來會做哪些計算,這樣MKL-DNN會自動在計算中合并一些計算來提高運算效率。具體的post-ops的描述可以參考這裡的官方文檔,同時官方文檔的每一個計算Modules的描述文檔裡也講了這個Module支援哪些post-ops操作

在DRRN的計算量,大量的出現了BN+ReLU+Conv的計算順序,即BatchNorm+Scale算完接着就是ReLU計算

MKL-DNN學習筆記 (七) Post-ops操作

這時候就可以在定義BN計算的時候通過設定post-ops來幫忙帶點私活,計算ReLU,

我們把BatchNorm的代碼裡這句話

auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(bnrm_fwd_d, cpu_engine);
           

改為

mkldnn::post_ops po1;
	po1.append_eltwise(
		/* scale     = */ 1.f,
		/* alg kind  = */ mkldnn::algorithm::eltwise_relu,
		/* neg slope = */ 0.f,
		/* unused for relu */ 0.f);
	mkldnn::primitive_attr attr1;
	attr1.set_post_ops(po1);


	auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(bnrm_fwd_d, attr1, cpu_engine);
           

編譯運作

上篇文章的輸出

MKL-DNN學習筆記 (七) Post-ops操作

本篇代碼修改後的輸出

MKL-DNN學習筆記 (七) Post-ops操作

可以看到前3個不為0的項現在全為0了 :)

接下來試試Conv+Post-ops

把MKL-DNN學習筆記 (五) 實作Conv層的快速計算裡的這句

auto conv3_fast_prim_desc = convolution_forward::primitive_desc(conv3_fast_desc, cpu_engine);
           

改為

mkldnn::post_ops po;
	po.append_eltwise(
		/* scale     = */ 1.f,
		/* alg kind  = */ mkldnn::algorithm::eltwise_relu,
		/* neg slope = */ 0.f,
		/* unused for relu */ 0.f);
	mkldnn::primitive_attr attr;
	attr.set_post_ops(po);


	auto conv3_fast_prim_desc = convolution_forward::primitive_desc(conv3_fast_desc, attr, cpu_engine);
           

再修改一下image,weights,bias, 把卷積輸出的資料變複雜一下,運作...

沒有post-ops

MKL-DNN學習筆記 (七) Post-ops操作

有post-ops

MKL-DNN學習筆記 (七) Post-ops操作

搞定收工!!! 

最後代碼奉上,僅供參考

https://github.com/tisandman555/mkldnn_study/blob/master/post_ops.cpp

PS

對于BN層來說,還可以通過設定這個normalization_flags 來實作post-ops -> Relu的效果

normalization_flags flags = normalization_flags::use_global_stats | normalization_flags::use_scale_shift
           

指派時多或一個fuse_norm_relu的标志位

normalization_flags flags = normalization_flags::use_global_stats | normalization_flags::use_scale_shift | normalization_flags::fuse_norm_relu;
           

隻改這一個句,也可以達到相同的效果 

繼續閱讀