天天看点

【强化学习】迷宫寻宝:Sarsa和Q-Learning

前言

本篇博文通过迷宫寻宝这一实例来探究Sarsa和Q-Learning的不同。

相关代码主要参考自邹伟等人所著的《强化学习》(清华大学出版社)。.

理论基础

这里简单放一下Sarsa和Q-Learning的更新公式,更详细的内容可参看本专栏后续的知识点整理。

Sarsa:

Q-Learning:

环境介绍

【强化学习】迷宫寻宝:Sarsa和Q-Learning

环境采用可视化工具Tkinter进行绘制,效果如图:

【强化学习】迷宫寻宝:Sarsa和Q-Learning

前置知识扩充

代码中​

​Q表格​

​主要通过pandas的DataFrame数据结构来进行实现,由于笔者对该结构了解不深,特用下面的代码来做个实验,以便对DataFrame有个初步了解。

import pandas as pd
import numpy as np

table = pd.DataFrame(columns=['u', 'd', 'l', 'r'], dtype=np.float64)
table = table.append(
    pd.Series(
        [1] * 4,
        index=table.columns,
        name=1))

table = table.append(
    pd.Series(
        [0] * 4,
        index=table.columns,
        name=2))

table = table.append(
    pd.Series(
        [0] * 4,
        index=table.columns,
        name=3))
print(table)
predict = table.loc[1, "d"]
print(predict)      

输出:

u    d    l    r
1  1.0  1.0  1.0  1.0
2  0.0  0.0  0.0  0.0
3  0.0  0.0  0.0  0.0
1.0      

首先创建了一个table,u,d,l,r代表四个动作(上下左右),columns将这四个值设置为表格的列标签。

然后以Series的形式向表格内插入数据,第一个值是数据值,第二个index是列索引,第三个name是行标签,即Q表格的状态。

通过loc函数可以获得表格中的任意值,第一个是行标签,第二个是列标签。

程序流程

Sarsa和Q-Learning两种方法的流程几乎是一样的,主要区别在于Q值的更新公式不一样。下面就用语言描述一下算法流程。

Step1:初始化环境

env = Maze()      

Step2:初始化Q表格

RL = SarsaTable(actions=list(range(env.n_actions)))      

Step3:设定100幕迭代,每次迭代首先初始化状态,即将初始位置放在左上角。

observation = env.reset()      

Step4:基于当前状态选择动作,这里采用的是epsilon-贪心选择,epsilon取值为0.9,即每次有90%的概率选择当前状态的最优动作,10%的概率进行随机选择,即探索。选择前,先检查该状态是否在Q表格中存在,不存在就添加。

action = RL.choose_action(str(observation))

   def choose_action(self, observation):
        self.check_state_exist(observation)
        # 从均匀分布的[0,1)中随机采样,当小于阈值时采用选择最优行为的方式,当大于阈值选择随机行为的方式,这样人为增加随机性是为了解决陷入局部最优
        if np.random.rand() < self.epsilon:
            # 选择最优行为
            state_action = self.q_table.loc[observation, :]
            # 因为一个状态下最优行为可能会有多个,所以在碰到这种情况时,需要随机选择一个行为进行
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        else:
            # 选择随机行为
            action = np.random.choice(self.actions)
        return      

Step5:保存临时策略,策略即当前状态下的选择的动作,在程序中可以理解为一个字典,键就是当前状态,键值就是动作。

tmp_policy[state_item] =      

Step6:采取动作并获得下一个状态和回报以及是否终止信息

, reward, done, oval_flag = env.step(action)      

Step6.5(这一步只有Sarsa有,Q-Learning没有):再次获取下一个动作,由于Sarsa需要五个值,因此还需要根据下一个状态来再次选择一次动作而Q-Learning不需要再次进行动作选择(体现了离轨策略的思想)。

action_ = RL.choose_action(str(observation_))      

Step7:更新Q表格,这一步是两者区别的关键,前面提到两者的更新公式不一样,这里用程序来表达一下。

Sarsa:

# 同轨策略Sarsa
class SarsaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # 使用公式: Q_taget = r+γQ(s',a')
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r
        # 更新公式: Q(s,a)←Q(s,a)+α(r+γQ(s',a')-Q(s,a))
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)      

Q-learning:

# 离轨策略Q-learning
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # 使用公式:Q_target = r+γ  maxQ(s',a')
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
        else:
            q_target = r

        # 更新公式: Q(s,a)←Q(s,a)+α(r+γ  maxQ(s',a')-Q(s,a))
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)      

可以发现,两者的区别就在于下一时刻的动作a‘如何选择。Sarsa和第一次选择动作一样,再次进行动作选择;而Q-Learning直接基于下一个状态S’,在Q表格中选择最大价值的动作。

这里做简单的一个分析。以迷宫为例,里面存在多个陷阱。如果进行动作的epsilon-贪心选择,则有更大几率调入陷阱,从而影响第一步Q值的更新,这样就会导致智能体”畏首畏尾“。而Q-Learning第二步进行贪心选择,则不容易落入陷阱中,从而会使智能体更为路径规划更为大胆。所以从这样的直观角度理解,Q-Learning的效果应该会比Sarsa要好。

Step8:先判断是否到达终止状态,若到达,结束这一幕,并再次判断是否收敛;这里收敛的条件设为三次策略policy不变化,如果不收敛,将临时的策略进行保存;如果收敛,跳出循环,结束操作。

# 如果为终止状态,结束当前的局数
if done:
    episode_num = episode
    step_num += c
    print(policy)
    print("-" * 50)
    # 如果N次行走的策略相同,表示已经收敛
    if policy == tmp_policy and oval_flag:
        count = count + 1
        if count == N:
            flag = True
    else:
        count = 0
        policy = tmp_policy
    break      

效果展示

Sarsa结果:

【强化学习】迷宫寻宝:Sarsa和Q-Learning

这里可以发现,即使策略收敛,依旧花费了比较长的时间。而且最终的结果存在问题,运行多次,结果不稳定,有时候在100局内无法收敛。Q-Learning结果:

【强化学习】迷宫寻宝:Sarsa和Q-Learning

可以看到Q-Learning找到了最佳的路径,并且用时不长。这和前面的直观分析是吻合的。

完整代码

maze.py(迷宫环境)

import numpy as np
import time
import sys

if sys.version_info.major == 2:
    import Tkinter as tk
else:
    import tkinter as tk

UNIT = 40  # 每个格子的大小
MAZE_H = 5  # 行数
MAZE_W = 5  # 列数


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.nS = np.prod([MAZE_H, MAZE_W])
        self.n_actions = len(self.action_space)
        self.title('寻宝')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
        self._build_maze()

    def _build_maze(self):
        # 创建一个画布
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_H * UNIT,
                                width=MAZE_W * UNIT)

        # 在画布上画出列
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        # 在画布上画出行
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # 创建探险者起始位置(默认为左上角)
        origin = np.array([20, 20])

        # 陷阱1
        hell1_center = origin + np.array([UNIT, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        # 陷阱2
        hell2_center = origin + np.array([UNIT * 2, UNIT])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 15, hell2_center[1] - 15,
            hell2_center[0] + 15, hell2_center[1] + 15,
            fill='black')

        # 陷阱3
        hell3_center = origin + np.array([UNIT * 3, UNIT])
        self.hell3 = self.canvas.create_rectangle(
            hell3_center[0] - 15, hell3_center[1] - 15,
            hell3_center[0] + 15, hell3_center[1] + 15,
            fill='black')

        # 陷阱4
        hell4_center = origin + np.array([UNIT, UNIT * 3])
        self.hell4 = self.canvas.create_rectangle(
            hell4_center[0] - 15, hell4_center[1] - 15,
            hell4_center[0] + 15, hell4_center[1] + 15,
            fill='black')

        # 陷阱5
        hell5_center = origin + np.array([UNIT * 3, UNIT * 3])
        self.hell5 = self.canvas.create_rectangle(
            hell5_center[0] - 15, hell5_center[1] - 15,
            hell5_center[0] + 15, hell5_center[1] + 15,
            fill='black')

        # 陷阱6
        hell6_center = origin + np.array([0, UNIT * 4])
        self.hell6 = self.canvas.create_rectangle(
            hell6_center[0] - 15, hell6_center[1] - 15,
            hell6_center[0] + 15, hell6_center[1] + 15,
            fill='black')

        # 陷阱7
        hell7_center = origin + np.array([UNIT * 4, UNIT * 4])
        self.hell7 = self.canvas.create_rectangle(
            hell7_center[0] - 15, hell7_center[1] - 15,
            hell7_center[0] + 15, hell7_center[1] + 15,
            fill='black')

        # 宝藏位置
        oval_center = origin + np.array([UNIT * 2, UNIT * 4])
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow')

        # 将探险者用矩形表示
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')

        # 画布展示
        self.canvas.pack()

    # 根据当前的状态重置画布(为了展示动态效果)
    def reset(self):
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        return self.canvas.coords(self.rect)

    # 根据当前行为,确认下一步的位置
    def step(self, action):
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:  # 上
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:  # 下
            if s[1] < (MAZE_H - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:  # 左
            if s[0] > UNIT:
                base_action[0] -= UNIT
        elif action == 3:  # 右
            if s[0] < (MAZE_W - 1) * UNIT:
                base_action[0] += UNIT

        # 在画布上将探险者移动到下一位置
        self.canvas.move(self.rect, base_action[0], base_action[1])
        # 重新渲染整个界面
        s_ = self.canvas.coords(self.rect)
        oval_flag = False

        # 根据当前位置来获得回报值,及是否终止
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
            s_ = 'terminal'
            oval_flag = True
        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2), self.canvas.coords(self.hell3),
                    self.canvas.coords(self.hell4), self.canvas.coords(self.hell5), self.canvas.coords(self.hell6),
                    self.canvas.coords(self.hell7)]:
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done = False

        return s_, reward, done, oval_flag

    def render(self):
        time.sleep(0.1)
        self.update()

    # 根据传入策略进行界面的渲染
    def render_by_policy(self, policy):
        cal_policy = sorted(policy)

        pre_x, pre_y = 20, 20

        for state in cal_policy:
            x = (state[0] + state[2]) / 2
            y = (state[1] + state[3]) / 2

            self.canvas.create_line(pre_x, pre_y, x, y, fill="red", tags="line", width=5)

            pre_x = x
            pre_y = y

        # 连接到宝藏位置
        oval_center = [20, 20] + np.array([UNIT * 2, UNIT * 4])

        self.canvas.create_line(pre_x, pre_y, oval_center[0], oval_center[1], fill="red", tags="line", width=5)

        self.render()

    def render_by_policy_new(self, policy):
        for i in range(MAZE_W):
            rows_obj = policy[i]
            for j in range(MAZE_H):
                item_center_x, item_center_y = (j * UNIT + UNIT / 2), (i * UNIT + UNIT / 2)

                cols_obj = rows_obj[j]

                if cols_obj == -1:
                    continue

                for item in cols_obj:
                    if item == 0:
                        item_x = item_center_x
                        item_y = item_center_y - 15.0
                        self.canvas.create_line(item_center_x, item_center_y, item_x, item_y, fill="black", width=1,
                                                arrow='last')
                    elif item == 1:
                        item_x = item_center_x
                        item_y = item_center_y + 15.0
                        self.canvas.create_line(item_center_x, item_center_y, item_x, item_y, fill="black", width=1,
                                                arrow='last')
                    elif item == 2:
                        item_x = item_center_x - 15.0
                        item_y = item_center_y
                        self.canvas.create_line(item_center_x, item_center_y, item_x, item_y, fill="black", width=1,
                                                arrow='last')
                    elif item == 3:
                        item_x = item_center_x + 15.0
                        item_y = item_center_y
                        self.canvas.create_line(item_center_x, item_center_y, item_x, item_y, fill="black", width=1,
                                                arrow='last')
        self.render()      
import numpy as np
import pandas as pd


class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy

        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # 如果状态在当前的Q表中不存在,将当前状态加入Q表中
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # 从均匀分布的[0,1)中随机采样,当小于阈值时采用选择最优行为的方式,当大于阈值选择随机行为的方式,这样人为增加随机性是为了解决陷入局部最优
        if np.random.rand() < self.epsilon:
            # 选择最优行为
            state_action = self.q_table.loc[observation, :]
            # 因为一个状态下最优行为可能会有多个,所以在碰到这种情况时,需要随机选择一个行为进行
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        else:
            # # 选择随机行为
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        pass


# 离轨策略Q-learning
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # 使用公式:Q_target = r+γ  maxQ(s',a')
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
        else:
            q_target = r

        # 更新公式: Q(s,a)←Q(s,a)+α(r+γ  maxQ(s',a')-Q(s,a))
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)


# 同轨策略Sarsa
class SarsaTable(RL):

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # 使用公式: Q_taget = r+γQ(s',a')
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r
        # 更新公式: Q(s,a)←Q(s,a)+α(r+γQ(s',a')-Q(s,a))
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)      
import sys

if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.maze import Maze
from RL_brain import QLearningTable, SarsaTable
import numpy as np

# METHOD = "SARSA"
METHOD = "Q-Learning"


def get_action(q_table, state):
    # 选择最优行为
    state_action = q_table.loc[state, :]

    # 因为一个状态下最优行为可能会有多个,所以在碰到这种情况时,需要随机选择一个行为进行
    state_action_max = state_action.max()

    idxs = []

    for max_item in range(len(state_action)):
        if state_action[max_item] == state_action_max:
            idxs.append(max_item)

    sorted(idxs)
    return tuple(idxs)


def get_policy(q_table, rows=5, cols=5, pixels=40, orign=20):
    policy = []

    for i in range(rows):
        for j in range(cols):
            # 求出每个各自的状态
            item_center_x, item_center_y = (j * pixels + orign), (i * pixels + orign)
            item_state = [item_center_x - 15.0, item_center_y - 15.0, item_center_x + 15.0, item_center_y + 15.0]

            # 如果当前状态为各终止状态,则值为-1
            if item_state in [env.canvas.coords(env.hell1), env.canvas.coords(env.hell2),
                              env.canvas.coords(env.hell3), env.canvas.coords(env.hell4),
                              env.canvas.coords(env.hell5), env.canvas.coords(env.hell6),
                              env.canvas.coords(env.hell7), env.canvas.coords(env.oval)]:
                policy.append(-1)
                continue

            if str(item_state) not in q_table.index:
                policy.append((0, 1, 2, 3))
                continue

            # 选择最优行为
            item_action_max = get_action(q_table, str(item_state))

            policy.append(item_action_max)

    return policy


def update():
    for episode in range(100):
        # 初始化状态
        observation = env.reset()

        c = 0

        tmp_policy = {}

        while True:
            # 渲染当前环境
            env.render()

            # 基于当前状态选择行为
            action = RL.choose_action(str(observation))

            state_item = tuple(observation)

            tmp_policy[state_item] = action

            # 采取行为获得下一个状态和回报,及是否终止
            observation_, reward, done, oval_flag = env.step(action)

            if METHOD == "SARSA":
                # 基于下一个状态选择行为
                action_ = RL.choose_action(str(observation_))

                # 基于变化 (s, a, r, s, a)使用Sarsa进行Q的更新
                RL.learn(str(observation), action, reward, str(observation_), action_)
            elif METHOD == "Q-Learning":
                # 根据当前的变化开始更新Q
                RL.learn(str(observation), action, reward, str(observation_))

            # 改变状态和行为
            observation = observation_
            c += 1
            # 如果为终止状态,结束当前的局数
            if done:
                break
    print('游戏结束')
    # 开始输出最终的Q表
    q_table_result = RL.q_table
    # 使用Q表输出各状态的最优策略
    policy = get_policy(q_table_result)
    print("最优策略为", end=":")
    print(policy)
    print("迷宫格式为", end=":")
    policy_result = np.array(policy).reshape(5, 5)
    print(policy_result)
    print("根据求出的最优策略画出方向")
    env.render_by_policy_new(policy_result)

    # env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))
    if METHOD == "Q-Learning":
        RL = QLearningTable(actions=list(range(env.n_actions)))
    env.after(100, update)
    env.mainloop()      

继续阅读