
2021-10-17 5.7

import numpy as np
from collections import defaultdict


class SARSA():
    def __init__(self, env, num_episodes, discount=1.0, alpha=0.5, epsilon=0.1, n_bins=10):
        self.nA = env.action_space.n  # size of the action space
        self.nS = env.observation_space.shape[0]  # number of observation dimensions
        self.env = env  # environment
        self.num_episodes = num_episodes  # number of training episodes
        self.epsilon = epsilon  # exploration rate of the epsilon-greedy policy
        self.discount = discount  # discount factor
        self.alpha = alpha  # learning rate (step size on the TD error)
        self.n_bins = n_bins  # bins per observation dimension for discretization
        # Bin edges per dimension (assumed: equal-width bins over the observation-space
        # bounds, clipped so dimensions with effectively infinite bounds still work)
        low = np.clip(env.observation_space.low, -10.0, 10.0)
        high = np.clip(env.observation_space.high, -10.0, 10.0)
        self.bins = [np.linspace(low[i], high[i], n_bins + 1)[1:-1] for i in range(self.nS)]
        self.Q = defaultdict(lambda: np.zeros(self.nA))  # action-value function Q(s, a)

    def __epsilon_greedy_policy(self, epsilon, nA):  # epsilon-greedy policy derived from the current Q estimates
        def policy(state):
            A = np.ones(nA, dtype=float) * epsilon / nA  # every action gets probability epsilon / nA
            best_action = np.argmax(self.Q[state])
            A[best_action] += (1 - epsilon)  # the greedy action receives the remaining probability mass
            return A
        return policy

    def __next_action(self, prob):  # sample an action index from the given probabilities
        return np.random.choice(np.arange(len(prob)), p=prob)
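
    def __get_bins_states(self, observation):
        # The original post calls this helper without showing its definition. A minimal
        # sketch, assuming the intent is to map each continuous observation dimension to
        # the index of one of the n_bins equal-width bins prepared in __init__, so the
        # discretized state tuple can serve as a dictionary key into Q.
        return tuple(int(np.digitize(observation[i], self.bins[i])) for i in range(self.nS))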

    def sarsa(self):
        policy = self.__epsilon_greedy_policy(self.epsilon, self.nA)  # behavior policy derived from Q
        sumlist = []  # steps survived in each episode
        for i_episode in range(self.num_episodes):  # on-policy control loop
            step = 0
            state__ = self.env.reset()  # initial observation
            state = self.__get_bins_states(state__)  # discretize (approximates the continuous state space with bins)
            prob_actions = policy(state)  # action probabilities for the current state
            action = self.__next_action(prob_actions)  # sample the first action
            while True:
                next_state__, reward, done, info = self.env.step(action)  # next observation, reward, done flag
                if done:
                    reward = -200  # terminal penalty from the original code, applied before the update so it affects learning
                next_state = self.__get_bins_states(next_state__)
                prob_next_actions = policy(next_state)  # action probabilities for the next state
                next_action = self.__next_action(prob_next_actions)  # sample the next action
                # SARSA temporal-difference update:
                # Q(s, a) <- Q(s, a) + alpha * (reward + discount * Q(s', a') - Q(s, a))
                td_target = reward + self.discount * self.Q[next_state][next_action]
                td_delta = td_target - self.Q[state][action]
                self.Q[state][action] += self.alpha * td_delta
                step += 1
                if done:
                    sumlist.append(step)  # record the episode length
                    break
                state = next_state
                action = next_action
        return self.Q
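
For completeness, a minimal usage sketch. It assumes the classic (pre-0.26) gym reset/step API that the class itself relies on, and picks CartPole-v0 as an illustrative environment; neither choice is shown in the original post.

import gym

env = gym.make('CartPole-v0')
agent = SARSA(env, num_episodes=500, discount=1.0, alpha=0.5, epsilon=0.1, n_bins=10)
Q = agent.sarsa()
print('visited discretized states:', len(Q))

Because Q is a defaultdict keyed by the discretized state tuple, unvisited states simply default to a zero value vector, so no table covering every bin combination has to be allocated up front.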