A* Search Algorithm

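A note before we start: this is my first blog post. Going forward I want to write down interesting algorithms and bits of knowledge that I come across. This post is entirely original; some of the figures come from lecture slides for CS170 at UCR. Please credit the source if you repost it.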

Algorithm Overview

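A* Search is a heuristic search algorithm: it has a built-in function that expresses the difference between the current state and the goal state. Take the 8-puzzle as an example (a game whose goal is to slide every numbered tile into its designated position). The current state is shown in the figure below.

[Figure: current state of the 8-puzzle]

Goal state:

[Figure: goal state of the 8-puzzle]

If we take the number of tiles that differ between the two states as the heuristic function, then its value here is 1. (Note: the blank is excluded; the current state and the goal state in the figures differ only in the tile '8'.)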
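As a rough sketch of this count (not the author's code), the misplaced-tile heuristic can be written in a few lines. The two boards are hypothetical examples chosen so that only the tile 8 is out of place, and the function name `misplaced_tiles` is made up for illustration.

```python
# Hypothetical boards for illustration, read row by row; 0 stands for the blank.
GOAL = (1, 2, 3,
        4, 5, 6,
        7, 8, 0)
CURRENT = (1, 2, 3,
           4, 5, 6,
           7, 0, 8)

def misplaced_tiles(state, goal=GOAL):
    """Count the numbered tiles that are not in their goal position (blank excluded)."""
    return sum(1 for tile, target in zip(state, goal)
               if tile != 0 and tile != target)

print(misplaced_tiles(CURRENT))  # 1: only tile 8 is out of place
```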

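Next I will describe the A* algorithm in detail, using the 8-puzzle as a running example. Before introducing A* search, I will briefly cover two other algorithms: Uniform Cost Search and Hill Climbing Search. The heuristic function is the Manhattan distance, illustrated in the figure below (h(n) is the heuristic function). For the 8-puzzle, it is the sum, over all numbered tiles, of the number of rows plus the number of columns separating each tile from its goal position.

[Figure: the Manhattan distance heuristic h(n) on the 8-puzzle]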
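Here is a minimal sketch of the Manhattan distance for the 8-puzzle. It assumes a flat tuple read row by row, 0 for the blank, and a goal of 1 to 8 in order with the blank in the last cell; the function name `manhattan_distance` is an illustrative choice.

```python
def manhattan_distance(state, size=3):
    """Sum of row and column offsets of every numbered tile from its goal cell.

    Assumes the goal is 1..8 in reading order with the blank (0) last.
    """
    total = 0
    for index, tile in enumerate(state):
        if tile == 0:
            continue  # the blank does not count
        row, col = divmod(index, size)
        goal_row, goal_col = divmod(tile - 1, size)
        total += abs(row - goal_row) + abs(col - goal_col)
    return total

# The default puzzle used by the implementation in this post.
print(manhattan_distance((1, 2, 3, 4, 8, 0, 7, 6, 5)))  # 5
```

In that board, tile 8 is one step from home and tiles 6 and 5 are two steps each, which gives 1 + 2 + 2 = 5.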

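First, Hill Climbing Search. It is also a heuristic search algorithm, and the idea is simple: use the Manhattan distance directly, always moving to the neighboring state with the smallest h(n). In the example below, the algorithm works well. (h(n) is the heuristic function.)

[Figure: a hill climbing run on the 8-puzzle]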

The algorithm runs fast, but it has a problem: in some states it gets stuck and cannot find a solution.
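To make the failure concrete, here is a small sketch of one common variant, strict greedy descent on h(n), which stops as soon as no neighbor improves on the current state. The helper names are hypothetical. The second start board was picked by hand: it is solvable, yet every move from it raises h(n).

```python
GOAL = (1, 2, 3, 4, 5, 6, 7, 8, 0)

def manhattan(state):
    """Manhattan distance to GOAL, ignoring the blank (0)."""
    total = 0
    for index, tile in enumerate(state):
        if tile:
            row, col = divmod(index, 3)
            goal_row, goal_col = divmod(tile - 1, 3)
            total += abs(row - goal_row) + abs(col - goal_col)
    return total

def neighbors(state):
    """All states reachable by sliding one tile into the blank."""
    blank = state.index(0)
    row, col = divmod(blank, 3)
    result = []
    for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        new_row, new_col = row + d_row, col + d_col
        if 0 <= new_row < 3 and 0 <= new_col < 3:
            swap = new_row * 3 + new_col
            board = list(state)
            board[blank], board[swap] = board[swap], board[blank]
            result.append(tuple(board))
    return result

def hill_climb(state):
    """Greedy descent on h(n); returns (final_state, moves_made)."""
    moves = 0
    while state != GOAL:
        best = min(neighbors(state), key=manhattan)
        if manhattan(best) >= manhattan(state):
            break  # local minimum: no neighbor improves h(n)
        state, moves = best, moves + 1
    return state, moves

print(hill_climb((1, 2, 3, 4, 8, 0, 7, 6, 5)))  # reaches GOAL in 5 moves
print(hill_climb((2, 3, 1, 4, 5, 6, 7, 8, 0)))  # stuck at the start, 0 moves
```

Because hill climbing keeps no record of the path cost or of alternative branches, it has no way to back out of such a state.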

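The other algorithm is Uniform Cost Search, which is even simpler: each time it expands a node, it picks the cheapest one, where the cost g(n) of a node is the accumulated cost of the path leading to it. The algorithm is less efficient, but it is guaranteed to find a solution whenever one exists.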
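The expansion rule can be sketched with a priority queue ordered by g(n). This is an illustration, not the author's code: the graph, its step costs, and the function name are all hypothetical.

```python
import heapq

def uniform_cost_search(graph, start, goal):
    """Return (cost, path) of a cheapest path, or None if goal is unreachable."""
    frontier = [(0, start, [start])]  # entries are (g(n), node, path so far)
    expanded = set()
    while frontier:
        cost, node, path = heapq.heappop(frontier)  # cheapest g(n) first
        if node == goal:
            return cost, path
        if node in expanded:
            continue  # already expanded through a cheaper or equal path
        expanded.add(node)
        for successor, step_cost in graph.get(node, {}).items():
            if successor not in expanded:
                heapq.heappush(
                    frontier, (cost + step_cost, successor, path + [successor]))
    return None

# Hypothetical graph: node -> {successor: step cost}.
graph = {
    "S": {"A": 1, "B": 4},
    "A": {"B": 2, "G": 6},
    "B": {"G": 1},
}
print(uniform_cost_search(graph, "S", "G"))  # (4, ['S', 'A', 'B', 'G'])
```

Note that the direct edge from S to B (cost 4) loses to the detour through A (cost 3), which is exactly the case where ordering by accumulated cost matters.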

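That completes the background. Now for the A* algorithm, which combines Uniform Cost Search and Hill Climbing Search. It ranks each node n by the evaluation function f(n) = g(n) + h(n), where g(n) is the accumulated cost of reaching n and h(n) is the heuristic estimate of the remaining distance from n to the goal.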

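The figure below shows an example of A* on a maze problem: the initial state is the pink dot, the goal state is the blue dot, and the task is to move the pink dot onto the blue dot. The tree in the figure shows the A* search process. h(n) is again the Manhattan distance.

[Figure: A* search tree for the maze example]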
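A maze search of this kind can be sketched as follows. The maze layout is hypothetical (it is not the one from the figure), every step is assumed to cost 1, and the function returns only the length of the path found.

```python
import heapq

def a_star(grid, start, goal):
    """A* on a 4-connected grid; '#' cells are walls. Returns the path length or None."""
    def h(cell):
        # Manhattan distance from cell to goal
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    frontier = [(h(start), 0, start)]  # entries are (f(n), g(n), cell)
    best_g = {start: 0}                # cheapest known g(n) per cell
    while frontier:
        f, g, cell = heapq.heappop(frontier)  # smallest f(n) = g(n) + h(n) first
        if cell == goal:
            return g
        if g > best_g[cell]:
            continue  # stale entry: a cheaper path to this cell was found later
        row, col = cell
        for nxt in ((row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)):
            r, c = nxt
            if 0 <= r < len(grid) and 0 <= c < len(grid[0]) and grid[r][c] != "#":
                new_g = g + 1  # assumed unit step cost
                if new_g < best_g.get(nxt, float("inf")):
                    best_g[nxt] = new_g
                    heapq.heappush(frontier, (new_g + h(nxt), new_g, nxt))
    return None

# Hypothetical maze: S is the start, G the goal, '#' a wall.
maze = [
    "S.#.",
    ".##.",
    "...G",
]
print(a_star(maze, (0, 0), (2, 3)))  # 5
```

Setting h to always return 0 turns this into Uniform Cost Search, while ignoring g gives a greedy search closer in spirit to hill climbing.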

That covers the basic concepts of the A* algorithm. Below is an implementation in Python, with the 8-puzzle as the problem to solve.

Implementation

import numpy as np
import queue
import copy
DEFAULT_PUZZLE = np.array([1,2,3,4,8,0,7,6,5]).reshape((3,3))
GOAL_STATE = np.array([1,2,3,4,5,6,7,8,0]).reshape((3,3))
# States that have already been generated. A state is recorded the first time
# a move produces it, so a later path to the same state is ignored.
STATE_TABLE = dict()

class Node:
    '''
        :description: a search node: a puzzle state plus its depth g(n),
                      its total cost f(n) and the position of the blank
    '''
    def __init__(self,puzzle,depth,total_cost,blank_index_i,blank_index_j):
        self.puzzle = puzzle
        self.depth = depth
        self.total_cost = total_cost
        self.blank_index_i = blank_index_i
        self.blank_index_j = blank_index_j
    def get_depth(self):
        return self.depth
    def get_puzzle(self):
        return self.puzzle
    def get_total_cost(self):
        return self.total_cost
    def get_blank_index_i(self):
        return self.blank_index_i
    def get_blank_index_j(self):
        return self.blank_index_j
    def __lt__(self, other):
        # lets the priority queue order nodes by f(n)
        return self.total_cost < other.total_cost

def Init_input_puzzle():
    '''
    :description:
            Read the input puzzle matrix and the choice of algorithm
    :input: None
    :return:
            puzzle: the puzzle to be solved
            key:  the choice of algorithm
            blank_index_i : row index of the blank
            blank_index_j : column index of the blank
    '''
    print("Welcome to Geyeshui's 8-Puzzle solver")
    key = int(input("Enter 1 to use the default puzzle, or 2 to enter your own puzzle: \n"))
    if key == 1:
        puzzle = DEFAULT_PUZZLE
    else:
        print("Enter your puzzle, using a zero to represent the blank")
        puzzle_input = []
        for i in range(3):
            row = input("Enter row "+str(i+1)+", with SPACES between numbers: \n")
            # int() rather than eval(): never evaluate raw user input
            puzzle_input = puzzle_input + [int(num) for num in row.split()]
        puzzle = np.array(puzzle_input).reshape((3,3))

    print("Enter your choice of algorithm")
    print("1.Uniform cost search")
    print("2.A* with the Misplaced tile heuristic")
    key = int(input("3.A* with the Manhattan distance heuristic \n"))

    # find the blank index
    for i in range(3):
        for j in range(3):
            if puzzle[i][j]==0:
                blank_index_i = i
                blank_index_j = j
    return puzzle, key, blank_index_i, blank_index_j

def Get_heuristic_cost(key,puzzle):
    '''
    :description:
            return the h(n) value for the chosen algorithm
    :input:
            key : the algorithm index number
            puzzle : the puzzle whose h(n) is estimated
    :return:
            h_n : the h(n) value
    '''
    h_n = 0
    if key == 1:
        h_n=0  # uniform cost search: no heuristic
    elif key ==2:
        for i in range(3):  # count the misplaced tiles
            for j in range(3):
                if puzzle[(i,j)] != GOAL_STATE[(i,j)]:
                    if i==2 and j==2:
                        continue  # skip the cell where the goal has its blank
                    else:
                        h_n = h_n+1
    else:
        for i in range(3): # calculate the Manhattan distance
            for j in range(3):
                num = puzzle[(i,j)]
                if num==0:
                    continue
                else:
                    # goal row and column of tile num; // keeps integer division
                    index_num_i = (num-1)//3
                    index_num_j = (num-1)%3
                    h_n = h_n + (abs(i-index_num_i)+abs(j-index_num_j))
    return h_n

def Is_goal_state(puzzle):
    '''
    :description: return True if the puzzle is the goal state, False otherwise
    :input:
            puzzle: the puzzle to be checked
    :return:
            True: if puzzle is the goal state
            False: if it is not
    '''
    return np.array_equal(puzzle, GOAL_STATE)

def Move_up(puzzle,index_i,index_j):
    '''
    :description:
            move the blank up if possible.
    :param
            puzzle: the puzzle to operate on (modified in place)
            index_i, index_j: position of the blank
    :return:
            (puzzle, True): the puzzle after the move, if it yields a new state
            (None, False): if the move is illegal or the state was already seen
    '''
    if index_i>0:
        puzzle[index_i-1][index_j],puzzle[index_i][index_j] \
            = puzzle[index_i][index_j],puzzle[index_i-1][index_j]
        if STATE_TABLE.get(str(puzzle),0) == 1:
            return None,False
        else:
            STATE_TABLE[str(puzzle)] = 1
            return puzzle,True
    else:
        return None,False

def Move_down(puzzle,index_i,index_j):
    '''
    :description:
            move the blank down if possible.
    :param
            puzzle: the puzzle to operate on (modified in place)
            index_i, index_j: position of the blank
    :return:
            (puzzle, True): the puzzle after the move, if it yields a new state
            (None, False): if the move is illegal or the state was already seen
    '''
    if index_i<2:
        puzzle[index_i+1][index_j],puzzle[index_i][index_j] \
            = puzzle[index_i][index_j],puzzle[index_i+1][index_j]
        if STATE_TABLE.get(str(puzzle),0) == 1:
            return None,False
        else:
            STATE_TABLE[str(puzzle)] = 1
            return puzzle,True
    else:
        return None,False

def Move_left(puzzle,index_i,index_j):
    '''
    :description:
            move the blank left if possible.
    :param
            puzzle: the puzzle to operate on (modified in place)
            index_i, index_j: position of the blank
    :return:
            (puzzle, True): the puzzle after the move, if it yields a new state
            (None, False): if the move is illegal or the state was already seen
    '''
    if index_j>0:
        puzzle[index_i][index_j-1],puzzle[index_i][index_j] \
            = puzzle[index_i][index_j],puzzle[index_i][index_j-1]
        if STATE_TABLE.get(str(puzzle),0) == 1:
            return None,False
        else:
            STATE_TABLE[str(puzzle)] = 1
            return puzzle,True
    else:
        return None,False

def Move_right(puzzle,index_i,index_j):
    '''
    :description:
            move the blank right if possible.
    :param
            puzzle: the puzzle to operate on (modified in place)
            index_i, index_j: position of the blank
    :return:
            (puzzle, True): the puzzle after the move, if it yields a new state
            (None, False): if the move is illegal or the state was already seen
    '''
    if index_j<2:
        puzzle[index_i][index_j+1],puzzle[index_i][index_j] \
            = puzzle[index_i][index_j],puzzle[index_i][index_j+1]
        if STATE_TABLE.get(str(puzzle),0) == 1:
            return None,False
        else:
            STATE_TABLE[str(puzzle)] = 1
            return puzzle,True
    else:
        return None,False

if __name__ == '__main__':
    ans = None
    # key is the choice index of algorithm
    puzzle, key, blank_index_i, blank_index_j= Init_input_puzzle()
    STATE_TABLE[str(puzzle)] = 1
    global_step = 0 # number of nodes taken off the queue so far
    size_of_pq = 0 # maximum size reached by the priority queue
    pq = queue.PriorityQueue()
    pq.put(Node(puzzle,0,Get_heuristic_cost(key,puzzle),blank_index_i,blank_index_j))
    while not pq.empty():
        size_of_pq = max(size_of_pq,pq.qsize())
        node = pq.get()  # node with the smallest f(n) = g(n) + h(n)
        global_step = global_step + 1
        print(node.get_puzzle())
        if Is_goal_state(node.get_puzzle()):
            ans = node
            break
        else:
            blank_index_i = node.get_blank_index_i()
            blank_index_j = node.get_blank_index_j()

            # each move works on its own copy of the current puzzle
            up_puzzle, up_flag = Move_up(copy.deepcopy(node.get_puzzle()),blank_index_i,blank_index_j)
            down_puzzle, down_flag = Move_down(copy.deepcopy(node.get_puzzle()),blank_index_i,blank_index_j)
            right_puzzle, right_flag = Move_right(copy.deepcopy(node.get_puzzle()),blank_index_i,blank_index_j)
            left_puzzle, left_flag = Move_left(copy.deepcopy(node.get_puzzle()),blank_index_i,blank_index_j)

            if up_flag:
                pq.put(Node(up_puzzle,node.get_depth()+1,node.get_depth()+1+Get_heuristic_cost(key,up_puzzle),
                            blank_index_i-1,blank_index_j))

            if down_flag:
                pq.put(Node(down_puzzle,node.get_depth()+1,node.get_depth()+1+Get_heuristic_cost(key,down_puzzle),
                            blank_index_i+1,blank_index_j))

            if right_flag:
                pq.put(Node(right_puzzle,node.get_depth()+1,node.get_depth()+1+Get_heuristic_cost(key,right_puzzle),
                            blank_index_i,blank_index_j+1))

            if left_flag:
                pq.put(Node(left_puzzle,node.get_depth()+1,node.get_depth()+1+Get_heuristic_cost(key,left_puzzle),
                            blank_index_i,blank_index_j-1))
    if ans is None:
        # the queue ran empty, e.g. for an unsolvable puzzle
        print("No solution found")
    else:
        # final puzzle, solution depth, nodes expanded, max queue size
        print(ans.get_puzzle(),ans.get_depth(),global_step,size_of_pq)

                

That is the A* Search Algorithm. If anything here is wrong, corrections are welcome.
