天天看點

比對追蹤算法進行圖像重建

比對追蹤的過程已經在​​比對追蹤算法(MP)簡介​​中進行了簡單介紹,下面是使用Python進行圖像重建的實踐。

MP算法Python版

MP算法原理:

算法假定輸入信号與字典庫中的原子在結構上具有一定的相關性,這種相關性通過信号與原子庫中原子的内積表示,即内積越大,表示信号與字典庫中的這個原子的相關性越大,是以可以使用這個原子來近似表示這個信号。當然這種表示會有誤差,将表示誤差稱為信号殘差,用原信号減去這個原子,得到殘差,再通過計算相關性的方式從字典庫中選出一個原子表示這個殘差。疊代進行上述步驟,随着疊代次數的增加,信号殘差将越來越小,當滿足停止條件時終止疊代,得到一組原子,及殘差,将這組原子進行線性組合就能重構輸入信号。

MP算法的執行步驟如下:

輸入:字典矩陣A,信号向量y,稀疏度k.

輸出:x的k稀疏逼近x^.

初始化:生成字典矩陣A(這裡使用離散餘弦變換基DCT),殘差r0=y,索引集Λ0=∅,t=1.

循環執行步驟1-5:

  1. 找出殘差r和字典矩陣的列Ai積中最大值所對應的值p及腳标λ,即pt=maxi=1,⋯,N|<rt−1,Ai>|.
  2. 更新索引集Λt=Λt−1∪{λt},記錄找到的字典矩陣中的重建原子集合At=[At−1,Aλt].
  3. 更新稀疏向量x^t=x^t∪{pt}.
  4. 更新殘差rt=y−Atx^t,t=t+1.
  5. 判斷是否滿足t>k,若滿足,則疊代停止;若不滿足,則繼續執行步驟1.

Python代碼實作(針對二維圖像):

import numpy as np


def bmp(mtx, codebook, threshold):
    """
    :param mtx: 原始圖像(mxn)
    :param codebook: 字典(mxk)
    :param threshold: 非零元素個數的最大值
    :return: 稀疏編碼系數
   3 """
    n = mtx.shape[1] if len(mtx.shape) > 1 else 1  # 原始圖像mtx中向量的個數
    k = codebook.shape[1]  # 字典dictionary中向量的個數
    result = np.zeros((k, n))  # 系數矩陣result中行數等于dictionary中向量的個數,列數等于mtx中向量的個數

    for i in range(n):
        indices = []  # 記錄選中字典中原子的位置
        coefficients = []  # 存儲系數向量
        residual = mtx[:, i]
        for j in range(threshold):
            projection = np.dot(codebook.T, residual)
            # 擷取内積向量中元素絕對值的最大值
            max_value = projection.max()
            if abs(projection.min()) >= abs(projection.max()):
                max_value = projection.min()
            pos = np.where(projection == max_value)[0]
            indices.append(pos.tolist()[0])  # 隻存儲在字典中的列(因為計算過程中對codebook進行了轉置,是以這裡取第一個元素)
            coefficients.append(max_value)
            residual = mtx[:, i] - np.dot(codebook[:, indices[0: j + 1]], np.array(coefficients))
            if (residual ** 2).sum() < 1e-6:
                break
        for t, s in zip(indices, coefficients):
            result[t][i] = s
    return      

基于MP的圖像重建

對于較大的圖像,進行分塊處理,使用im2col和col2im函數進行圖像的分塊和分塊後的重建(參考:​​Python中如何實作im2col和col2im函數​​)。

這樣字典矩陣的行數就僅僅和分塊矩陣的大小有關,和原始圖像的大小沒有關系了。我們可以使用規模較小的字典矩陣表征較大的圖像。

Python代碼實作:

import numpy as np
from scipy import fftpack
import math
import mahotas as mh
import matplotlib.pyplot as plt
import mp.mpalg


def dct2(mtx):
    return fftpack.dct(fftpack.dct(mtx.T, norm='ortho').T, norm='ortho')


def idct2(mtx):
    return fftpack.idct(fftpack.idct(mtx.T, norm='ortho').T, norm='ortho')


def dctmtx(n):
    basis = np.zeros((n, n))
    for i in range(n):
        c = math.sqrt(2 / n) if i != 0 else math.sqrt(1 / n)
        for j in range(n):
            basis[i, j] = c * math.cos((j + 0.5) * math.pi * i / n)
    return basis


def im2col(mtx, block_size):
    mtx_shape = mtx.shape
    sx = mtx_shape[0] - block_size[0] + 1
    sy = mtx_shape[1] - block_size[1] + 1
    # 如果設A為m×n的,對于[p q]的塊劃分,最後矩陣的行數為p×q,列數為(m−p+1)×(n−q+1)。
    result = np.empty((block_size[0] * block_size[1], sx * sy))
    # 沿着行移動,是以先保持列(i)不動,沿着行(j)走
    for i in range(sy):
        for j in range(sx):
            result[:, i * sx + j] = mtx[j:j + block_size[0], i:i + block_size[1]].ravel(order='F')
    return result


def col2im(mtx, image_size, block_size):
    p, q = block_size
    sx = image_size[0] - p + 1
    sy = image_size[1] - q + 1
    result = np.zeros(image_size)
    weight = np.zeros(image_size)  # weight記錄每個單元格的數字重複加了多少遍
    col = 0
    # 沿着行移動,是以先保持列(i)不動,沿着行(j)走
    for i in range(sy):
        for j in range(sx):
            result[j:j + p, i:i + q] += mtx[:, col].reshape(block_size, order='F')
            weight[j:j + p, i:i + q] += np.ones(block_size)
            col += 1
    return result / weight


def sparse_encode(image, block_size, codebook, threshold):
    blocks = im2col(image, block_size)
    return mp.mpalg.bmp(blocks, codebook, threshold)


def sparse_decode(coefficients, codebook, image_size, block_size):
    blocks = np.dot(codebook, coefficients)
    return col2im(blocks, image_size, block_size)


if __name__ == '__main__':
    image = mh.imread('Lenna.jpg')
    image = mh.colors.rgb2gray(image)
    image_size = image.shape
    block_size = (8, 8)

    codebook = dctmtx(block_size[0] * block_size[1])
    threshold = 30
    coefficients = sparse_encode(image, block_size, codebook, threshold)
    reconstructed = sparse_decode(coefficients, codebook, image_size, block_size)

    plt.gray()
    plt.subplot(121)
    plt.title('原始圖像')
    plt.imshow(image)
    plt.subplot(122)
    plt.title('稀疏重建')
    plt.imshow(reconstructed)
    plt.show()      

下面是分别設定threshold為10,20和30的運作結果:

比對追蹤算法進行圖像重建

稀疏系數設定為10的重建結果

比對追蹤算法進行圖像重建

稀疏系數設定為20的重建結果

比對追蹤算法進行圖像重建

參考資料

  1. ​​比對追蹤算法原理(GitHub)​​
  2. ​​比對追蹤算法原理(簡書)​​