天天看點

基于圖的圖像分割

基于圖的圖像分割 Effective graph-based image segmentation

  • ​​一、前言​​
  • ​​二、算法理論​​
  • ​​2.1 建構圖​​
  • ​​2.2 分割圖​​
  • ​​2.3 算法的實作​​
  • ​​2.4 幾個分割概念​​
  • ​​2.5 算法步驟​​
  • ​​三、代碼實作​​
  • ​​numpy實作​​

一、前言

最近一段時間在複現基于區域的對比度方法(region-based contrast 簡稱RC)的顯著性檢測。​​​​​ 其中,遇到了問題。主要是用到了基于圖的圖像分割。 顯著性檢測RC算法:其前期工作就是利用Graph-Based Image Segmentation的分割算法。主要涉及了圖網絡的一些知識。​

Graph-Based Image Segmentation是2004年由Felzenszwalb發表在IJCV上的一篇文章,主要介紹了一種基于圖表示(graph-based)的圖像分割方法。圖像分割(Image Segmentation)的主要目的也就是将圖像(image)分割成若幹個特定的、具有獨特性質的區域(region),然後從中提取出感興趣的目标(object)。而圖像區域之間的邊界定義是圖像分割算法的關鍵,論文給出了一種在圖表示(graph-based)下圖像區域之間邊界的定義的判斷标準(predicate),其分割算法就是利用這個判斷标準(predicate)使用貪心選擇(greedy decision)來産生分割(segmentation)。

該算法在時間效率上,基本上與圖像(Image)的圖(Graph)表示的邊(edge)數量成線性關系,而圖像的圖表示的邊與像素點成正比,也就說圖像分割的時間效率與圖像的像素點個數成線性關系。這個算法有一個非常重要的特性,它能保持低變化(low-variability)區域(region)的細節,同時能夠忽略高變化(high-variability)區域(region)的細節。這個性質很特别也很重要,對圖像有一個很好的分割效果(能夠找出視覺上一緻的區域,簡單講就是高變化區域有一個很好聚合(grouping),能夠把它們分在同一個區域),這也是為什麼那麼多人引用該論文的原因吧。​

無論在分割領域還是顯著性檢測上,都是能夠捕捉視覺上重要的區域(perceptually important regions)。舉個栗子:在下圖1左側有個紅三角,下圖2左側有個更大的紅三角,我們可以認為圖2的紅三角更顯眼(更加的靠左側),

基于圖的圖像分割

二、算法理論

該論文主要有兩個關鍵點:

  1. 圖像(image)的圖(graph)表示;
  2. 最小生成樹(Minimun Spanning Tree)。
2.1 建構圖

圖像(image)的圖表示是指将圖像(image)表達成圖論中的圖(graph)。具體說來就是,把圖像中的每一個像素點看成一個頂點 (node或vertex),像素點之間的關系對(可以自己定義其具體關系,一般來說是指相鄰關系)構成圖的一條邊 ,這樣就建構好了一個圖 。

相鄰的兩個像素點像素顔色值的差異構成邊的權值。其中權值越小,表示像素點之間的相似度就越高,反之,相似度就越低。圖每條邊的權值是基于像素點之間的關系,可以是像素點之間的灰階值差,也可以是像素點(RGB)之間的距離:

灰階值素點之間距離:

像素點(RGB)之間的距離:

2.2 分割圖

将圖像表達成圖之後,接下來就是要如何分割這個圖。将每個節點(像素點)看成單一的區域,然後進行合并。使用最小生成樹方法合并像素點,然後構成一個個區域。大緻意思就是講圖(Graph)簡化,相似的區域在一個分支(Branch)上面(有一條最邊連接配接),大大減少了圖的邊數。

圖(Graph)分割是将

2.3 算法的實作
  1. 分割區域(Component)的内部差(internal difference)。可以先假定圖G已經簡化成了最小生成樹 MST,一個分割區域C 包含若幹個頂點 ,頂點之間通過最小生成樹的邊連接配接。這個内部差就是指分割區域C中包含的最大邊的權值。
  2. 基于圖的圖像分割
  3. 分割區域(Component)之間的差别(Difference),是指兩個分割區域之間頂點互相連接配接的最小邊的權值。
  4. 基于圖的圖像分割
  5. 如果兩個分割部分之間沒有邊連接配接,定義。分割區域的差别可以有很多種定義的方式,可以選擇中位置,或者其他的分位點(quantile,中位置是0.5分位點),但是選取其他的方式将會使得這個問題成為一個NP-hard問題。
  6. 分割區域(Component)邊界的一種判斷标準(predicate)。判斷兩個分割區域之間是否有明顯的邊界,主要是判斷兩個分割部分之間的差别Dif相對于和中較小的那個值MInt的大小,這裡引入了一個門檻值函數τ 來控制兩者之間的內插補點。下面給出這個判斷标準的定義:
  7. 基于圖的圖像分割
  8. 其中,是指最小的分割内部差,其定義如下:
  9. 基于圖的圖像分割
  10. 門檻值函數主要是為了更好的控制分割區域邊界的定義。比較直覺的了解,小分割區域的邊界定義要強于大分割區域,否則可以将小分割區域繼續合并形成大區域。在這裡給出的門檻值函數與區域的大小有關。
  11. 基于圖的圖像分割
  12. |C|是指分割部分頂點的個數(或者像素點個數),k是一個參數,可以根據不同的需求(主要根據圖像的尺寸)進行調節。
2.4 幾個分割概念
  1. 如果一個分割S,存在圖(Graph)的分割區域之間,沒有明顯的邊界,那麼就說這個分割S“太精細”(too fine)。也就是說它們之間沒有明顯的分界線,硬要把它們分割開來的話,有點過頭,也就是說分得太細。
  2. 如果一個分割S,存在一個合适的調整(refinement)S’使得S不是”太精細“,那麼就說這個分割S”太粗糙“(too coarse)。簡單來講就是,分割程度的還不夠(粒度還比較大),可以繼續分割,這樣剛開始的那個分割就是”太粗糙“(too coarse)了。

對于一個圖graph來說,都存在一個分割S,既不是”太精細“(too fine)也不是”太粗糙“(too coarse)。

2.5 算法步驟
  1. 對于圖G的所有邊,按照權值進行排序(升序)
  2. S[0]是一個原始分割,相當于每個頂點當做是一個分割區域
  3. q = 1,2,…,m 重複3的操作(m為邊的條數,也就是每次處理一條邊)
  4. 根據上次的建構。選擇一條邊o[q](vi,vj),如果vi和vj在分割的互不相交的區域中,比較這條邊的權值與這兩個分割區域之間的最小分割内部差MInt,如果o[q](vi,vj) < MInt,那麼合并這兩個區域,其他區域不變;如果否,什麼都不做。
  5. 最後得到的就是所求的分割 S = S[m]

三、代碼實作

C++實作代碼請檢視:http://cs.brown.edu/people/pfelzens/segment/

class Node:
    def __init__(self, parent, rank=0, size=1):
        self.parent = parent
        self.rank = rank
        self.size = size

    def __repr__(self):
        return '(parent=%s, rank=%s, size=%s)' % (self.parent, self.rank, self.size)

class Forest:
    def __init__(self, num_nodes):
        self.nodes = [Node(i) for i in range(num_nodes)]
        self.num_sets = num_nodes

    def size_of(self, i):
        return self.nodes[i].size

    def find(self, n):
        temp = n
        while temp != self.nodes[temp].parent:
            temp = self.nodes[temp].parent

        self.nodes[n].parent = temp
        return temp

    def merge(self, a, b):
        if self.nodes[a].rank > self.nodes[b].rank:
            self.nodes[b].parent = a
            self.nodes[a].size = self.nodes[a].size + self.nodes[b].size
        else:
            self.nodes[a].parent = b
            self.nodes[b].size = self.nodes[b].size + self.nodes[a].size

            if self.nodes[a].rank == self.nodes[b].rank:
                self.nodes[b].rank = self.nodes[b].rank + 1

        self.num_sets = self.num_sets - 1

    def print_nodes(self):
        for node in self.nodes:
            print(node)

def create_edge(img, width, x, y, x1, y1, diff):
    vertex_id = lambda x, y: y * width + x
    w = diff(img, x, y, x1, y1)
    return (vertex_id(x, y), vertex_id(x1, y1), w)

def build_graph(img, width, height, diff, neighborhood_8=False):
    graph_edges = []
    for y in range(height):
        for x in range(width):
            if x > 0:
                graph_edges.append(create_edge(img, width, x, y, x-1, y, diff))
            if y > 0:
                graph_edges.append(create_edge(img, width, x, y, x, y-1, diff))
            if neighborhood_8:
                if x > 0 and y > 0:
                    graph_edges.append(create_edge(img, width, x, y, x-1, y-1, diff))
                if x > 0 and y < height-1:
                    graph_edges.append(create_edge(img, width, x, y, x-1, y+1, diff))
    return graph_edges

def remove_small_components(forest, graph, min_size):
    for edge in graph:
        a = forest.find(edge[0])
        b = forest.find(edge[1])

        if a != b and (forest.size_of(a) < min_size or forest.size_of(b) < min_size):
            forest.merge(a, b)
    return  forest

# segment_graph(graph_edges, size[0]*size[1], K, min_comp_size, threshold)
def segment_graph(graph_edges, num_nodes, const, min_size, threshold_func):
    # Step 1: initialization
    # [(parent,rank,size) for i in range(num_nodes)]
    forest = Forest(num_nodes)

    weight = lambda edge: edge[2]
    sorted_graph = sorted(graph_edges, key=weight)
    threshold = [ threshold_func(1, const) for _ in range(num_nodes) ]

    # Step 2: merging
    for edge in sorted_graph:
        parent_a = forest.find(edge[0])
        parent_b = forest.find(edge[1])
        a_condition = weight(edge) <= threshold[parent_a]
        b_condition = weight(edge) <= threshold[parent_b]

        if parent_a != parent_b and a_condition and b_condition:
            forest.merge(parent_a, parent_b)
            a = forest.find(parent_a)
            threshold[a] = weight(edge) + threshold_func(forest.nodes[a].size, const)
    return remove_small_components(forest, sorted_graph, min_size)      
numpy實作
def segment_graph(height_width, num, edges, c=20.0, min_size=200):
    u_array = np.zeros((height_width, 3), dtype=np.int32)
    u_array[:, 1] = np.array(range(height_width), dtype=np.int32)
    u_array[:, 2] = np.ones(height_width, dtype=np.int32)
    thresholds_copy = np.full(height_width,c,dtype=np.float32)
    loop_range = range(num)

    for i in loop_range:
        edge = edges[i]
        a = edge['a']
        while a!=u_array[a,1]:
            a =edge['a']= u_array[a, 1]
        b = edge['b']
        while b!=u_array[b,1]:
            b =edge['b']= u_array[b, 1]
        if a != b:
            if edge['w'] <= thresholds_copy[a] and edge['w'] <= thresholds_copy[b]:
                if (u_array[a, 0] > u_array[b, 0]):
                    u_array[b, 1] = a
                    u_array[a, 2] += u_array[b, 2]
                else:
                    u_array[a, 1] = b
                    u_array[b, 2] += u_array[a, 2]
                    if u_array[a, 0] == u_array[b, 0]:
                        u_array[b, 0] += 1
                while a != u_array[edge['a'], 1]:
                    a = edge['a'] = u_array[edge['a'], 1]
                thresholds_copy[a] = edge['w'] + c/u_array[a,2]
    for i in loop_range:
        while (edges[i]['a'] != u_array[edges[i]['a'],1]):
            edges[i]['a'] = u_array[edges[i]['a'],1]
        while (edges[i]['b'] != u_array[edges[i]['b'],1]):
            edges[i]['b'] = u_array[edges[i]['b'],1]
        if ((edges[i]['a'] != edges[i]['b']) and ((u_array[edges[i]['a'],2] < min_size) or (u_array[edges[i]['b'],2] < min_size))):
            if (u_array[edges[i]['a'], 0] > u_array[edges[i]['b'], 0]):
                u_array[edges[i]['b'], 1] = edges[i]['a']
                u_array[edges[i]['a'], 2] += u_array[edges[i]['b'], 2]
            else:
                u_array[edges[i]['a'], 1] = edges[i]['b']
                u_array[edges[i]['b'], 2] += u_array[edges[i]['a'], 2]
                if u_array[edges[i]['a'], 0] == u_array[edges[i]['b'], 0]:
                    u_array[edges[i]['b'], 0] += 1
    return u_array


# ===========================SegmentImage==========================================
# 像素間的差異度量
def diff(img3f, x1, y1, x2, y2):
    p1 = img3f[y1, x1]
    p2 = img3f[y2, x2]
    return np.sqrt(np.sum(np.power(p1 - p2, 2)))


def SegmentImage(smImg3f, c=20.0, min_size=200):
    height, width = smImg3f.shape[:2]
    edges = np.zeros((height-1)*(width-1)*4+(height-1)+(width-1),
                     dtype={'names': ['a', 'b','w'],'formats': ['i4', 'i4','f4']})
    num = 0
    width_range = range(width)
    height_range = range(height)
    for y in height_range:
        for x in width_range:
            if x < width - 1:
                edges[num]['a'] = y * width + x
                edges[num]['b'] = y * width + (x + 1)
                edges[num]['w'] = diff(smImg3f, x, y, x + 1, y)
                num += 1
            if y < height - 1:
                edges[num]['a'] = y * width + x
                edges[num]['b'] = (y + 1) * width + x
                edges[num]['w'] = diff(smImg3f, x, y, x, y + 1)
                num += 1
            if (x < (width - 1)) and (y < (height - 1)):
                edges[num]['a'] = y * width + x
                edges[num]['b'] = (y + 1) * width + (x + 1)
                edges[num]['w'] = diff(smImg3f, x, y, x + 1, y + 1)
                num += 1
            if (x < (width - 1)) and y > 0:
                edges[num]['a'] = y * width + x
                edges[num]['b'] = (y - 1) * width + (x + 1)
                edges[num]['w'] = diff(smImg3f, x, y, x + 1, y - 1)
                num += 1
    edges = np.sort(edges, order='w')
    u_array = segment_graph(width * height, num, edges, c=20.0, min_size=200)
    marker = {}
    imgIdx = np.zeros((smImg3f.shape[0], smImg3f.shape[1]), np.int32)
    idxNum = 0
    for y in height_range:
        for x in width_range:
            comp = y * width + x
            while (comp != u_array[comp, 1]):
                comp = u_array[comp, 1]
            if comp not in marker.keys():
                marker[comp] = idxNum
                idxNum += 1
            idx = marker[comp]
            imgIdx[y, x] = idx
    return idxNum,