

import numpy as np
from collections import defaultdict


def create_random_policy(nA):
    A = np.ones(nA, dtype=float) / nA  # uniform probabilities over the nA actions

    def policy_fn(observation):  # policy function: the same distribution for every observation
        return A
    return policy_fn


def create_greedy_policy(Q):
    def policy_fn(state):  # greedy policy: all probability on the best action
        A = np.zeros_like(Q[state], dtype=float)
        best_action = np.argmax(Q[state])
        A[best_action] = 1.0
        return A
    return policy_fn
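
A quick sanity check of the two factories above (a minimal sketch; the state key and Q values are made-up numbers for illustration):

random_policy = create_random_policy(2)
print(random_policy(None))          # [0.5 0.5] -- uniform over the 2 actions

Q_demo = {0: np.array([1.0, 3.0])}  # hypothetical action values for state 0
greedy_policy = create_greedy_policy(Q_demo)
print(greedy_policy(0))             # [0. 1.] -- all probability on the argmax action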


def mc_control_importance_sampling(env, num_episode, behavior_policy, discount_factor=1.0):
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    C = defaultdict(lambda: np.zeros(env.action_space.n))
    target_policy = create_greedy_policy(Q)  # target policy: greedy with respect to Q
    for i_episode in range(1, num_episode+1):
        episode = []
        state = env.reset()
        while True:
            probs = behavior_policy(state)  # action probabilities for the current state under the behavior policy
            action = np.random.choice(np.arange(len(probs)), p=probs)  # sample an action from those probabilities
            next_state, reward, done, _ = env.step(action)  # take the action, observe the next state and reward
            episode.append((state, action, reward))
            if done:
                break
            state = next_state
        G = 0.0  # discounted return
        W = 1.0  # importance sampling weight
        for t in range(len(episode))[::-1]:  # walk the episode backwards from the last time step
            state, action, reward = episode[t]  # transition at time step t
            G = discount_factor * G + reward  # update the return
            C[state][action] += W  # update the cumulative sum of weights
            Q[state][action] += (W / C[state][action]) * (G - Q[state][action])  # weighted importance sampling update of the action-value function
            if action != np.argmax(target_policy(state)):
                break  # the greedy target policy never takes this action, so the importance ratio is 0 from here on
            W = W * 1. / behavior_policy(state)[action]  # numerator is 1 because the target policy is deterministic
    return Q, target_policy
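
A minimal usage sketch, assuming OpenAI Gym's Blackjack-v0 environment (whose old step API returns the four-tuple that the sampling loop above expects); the environment name and episode count are illustrative, not from the original code:

import gym

env = gym.make("Blackjack-v0")
behavior_policy = create_random_policy(env.action_space.n)  # uniform behavior policy
Q, target_policy = mc_control_importance_sampling(
    env, num_episode=500000, behavior_policy=behavior_policy)

# Greedy state values implied by the learned Q
V = {state: np.max(action_values) for state, action_values in Q.items()}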