天天看點

darknet源碼分析(三)gemm實作

上一節中我們分析了darknet卷積層的前向操作過程,darknet對卷積計算的處理實際上是:

先使用im2col将input_channel*(height*width)的輸入特征圖(實際存儲是按照行存儲的,即是1*(input_channel*height*width)的一維數組)轉化成(input_channel*kernel_size*kernel_size)*(out_height*out_width)的特征矩陣,這裡同樣是按行存儲的。

之後通過gemm函數實作通用的矩陣乘法實作卷積計算,即讓卷積核矩陣*im2col後的輸入特征矩陣。其中卷積核的大小為(kenel_channel)*(inut_channel*kernel_size*kernel_size)

最後得到kernel_channel*(out_height*out_width)即卷積輸出的最終結果。

這一節,我們來看非常重要的gemm。

首先來看gemm.h,看其中的gemm_cpu,實作的是C=ALPHA*A*B + BETA*C操作,這裡的BETA*C表示加上偏置項。通過其傳入的參數,我們會發現gemm_cpu會對通過TA和TB變量來判斷是否對卷積核矩陣A和輸入的經過im2col轉換過的特征矩陣B進行轉置操作,是以真正實作的時候會根據不同的轉置方式來采取不同的gemm方法。由于我們的A與B的存儲方式都是一維數組,是以輸入參數還要包含A與B的行與列。gemm_cpu的輸入參數解釋如下

/*

**  功能:矩陣計算,完成C = ALPHA * A * B + BETA * C,

**       輸出的C也是按行存儲(所有行并成一行)

**  輸入: A,B,C   輸入矩陣(一維數組格式,按行存儲,所有行并成一行)

**        ALPHA   系數

**        BETA    系數

**        TA,TB   是否需要對A,B做轉置操作,是為1,否為0

**        M       A,C的行數

**        N       B,C的列數

**        K       A的列數,B的行數

**        lda     A的列數(不做轉置)或者行數(做轉置)

**        ldb     B的列數(不做轉置)或者行數(做轉置)

**        ldc     C的列數


*/

void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
           

現在來看gemm到底是怎麼實作的吧,來到gemm.c中,檢視gemm_cpu()函數

void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    //printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
    int i, j;
	/*首先完成BETA * C的操作*/
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            C[i*ldc + j] *= BETA;
        }
    }
	/*根據指定的TA和TB來選擇不同的矩陣乘法方法,如gemm_nn就代表A與B都不進行轉置操作的情況,gemm_tn代表對A進行轉置*/
    if(!TA && !TB)
        gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(TA && !TB)
        gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(!TA && TB)
        gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else
        gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
}
           

這裡做一下解釋:

當TA=0,TB=0時,我們進行的是C = ALPHA * A * B + BETA * C操作

當TA=1,TB=0時,進行的是C = ALPHA * A' * B + BETA * C操作

當TA=0,TB=1時,進行的是C = ALPHA * A * B' + BETA * C操作

當TA=1,TB=1時,進行的時C = ALPHA * A' * B' + BETA * C操作

也就是說我們的矩陣乘法一定要符合矩陣乘法準則

例如A = [1, 2, 3, 2, 2, 1], B = [2, 0, 1, 1, 2, 1],C=[0,0,0,0](因為是按行存儲的,是以都是一維數組,這個輸入是打死不變的。進行矩陣乘法時要将其想象成多元進行處理)我們最後的輸出C假設是2*2的矩陣。這樣進行矩陣乘法的A與B分别是2*3和3*2,這樣如果使用的是gemm_nn,A矩陣的實際大小就是2*3,同理B矩陣的大小是3*2。如果使用的是gemm_tn,也就是說A矩陣經過轉置變成了2*3的矩陣,這樣也就是說輸入的A矩陣為3*2的即[1,2;3,2;2,1],而B矩陣沒有經過轉置,是以B矩陣為[2,0;1,1;2,1], A'與B相乘最後算出的C為[9,5;8,3]。而當使用的是gemm_tt時,也就是A與B都進行了轉置,這樣A應該就是[1,2;3,2;2,1],B矩陣應為[2,0,1;1,2,1],這樣C應為[4,9;5,7]。下面來分析一下gemm_nn的具體實作方法。

/*
** M:C的行數,因為這裡A沒有做轉置換操作,是以這裡A的行數是M
** N:C的列數,因為這裡B也沒有做轉置操作,是以這裡B的列數是N
** K:這裡都沒有轉置,是以K代表A的列數,B的行數
** lda: 不轉置時該變量是A的列數,是以A的列數是lda
** ldb: 不轉置時該變量時B的行數,是以B的行數是ldb
** ldc: C的列數
*/
void gemm_nn(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    #pragma omp parallel for
    // 周遊A的每一行
    for(i = 0; i < M; ++i){
		// 周遊A的每一列
        for(k = 0; k < K; ++k){
			// 首先将A_PART*A的操作做完
            register float A_PART = ALPHA*A[i*lda+k];
			// 使用A_PART的第i行的所有數與B第k列的所有數做乘加操作
            for(j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}
           

與之形成對比,下面來看gemm_tt的操作過程

/*
** M:C的行數,因為這裡A做轉置換操作,是以這裡A的列數是M
** N:C的列數,因為這裡B做轉置操作,是以這裡B的行數是N
** K:這裡都轉置,是以K代表A的行數,B的列數
** lda: 不轉置時該變量是A的列數,是以A的行數是lda
** ldb: 不轉置時該變量時B的行數,是以B的列數是ldb
** ldc: C的列數
*/
void gemm_tt(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    #pragma omp parallel for
	// 這裡周遊的是A的列數,C的行數
    for(i = 0; i < M; ++i){
		// 周遊的是C的列數,B的行數
        for(j = 0; j < N; ++j){
            register float sum = 0;
			// 周遊A的行和B的列
            for(k = 0; k < K; ++k){
                sum += ALPHA*A[i+k*lda]*B[k+j*ldb];
            }
            C[i*ldc+j] += sum;
        }
    }
}
           

可以看到這幾個gemm之間的差別就在于三個for循環的位置排序和最後對C中第i行第j列的數計算的位置放的位置不同。同時注意到在三個for循環之前都有一句#pragma omp parallel for 。這個是OpenMP中的一個指令,表示接下來的for循環将被多線程執行,這個語句要求幾個for循環之間不能夠有依賴。OpenMP 是 Open MultiProcessing 的縮寫。OpenMP 并不是一個簡單的函數庫,而是一個諸多編譯器支援的架構,或者說是協定吧,總之,不需要任何配置,你就可以在 Visual Studio 或者 gcc 中使用它了。OpenMP的設計們希望提供一種簡單的方式讓程式員不需要懂得建立和銷毀線程就能寫出多線程化程式。為此他們設計了一些pragma,指令和函數來讓編譯器能夠在合适的地方插入線程大多數的循環隻需要在for之前插入一個pragma就可以實作并行化。

我們這裡的三個循環操作之間沒依賴,是以完全可以使用openmp的pragma omp parallel for 來進行并行化使得操作更加快。