天天看點

gym庫中from gym.wrappers import FlattenObservation的了解

看代碼的過程中看到有這樣的調用:

from gym.wrappers import FlattenObservation

if sinstance(env.observation_space, gym.spaces.Dict):

     env = FlattenObservation(env)

不是很了解這個代碼的意思。

===============================================

檢視gym源碼中類:

FlattenObservation(ObservationWrapper)

import numpy as np
import gym.spaces as spaces
from gym import ObservationWrapper


class FlattenObservation(ObservationWrapper):
    r"""Observation wrapper that flattens the observation."""
    def __init__(self, env):
        super(FlattenObservation, self).__init__(env)

        flatdim = spaces.flatdim(env.observation_space)
        self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32)

    def observation(self, observation):
        return spaces.flatten(self.env.observation_space, observation)      

從gym的狀态空間的轉換可以看出這個類是要将observation的狀态空間進行flatten操作。

具體的flatten操作調用:

spaces.flatten(self.env.observation_space, observation)      

檢視spaces.flatten源代碼:

def flatten(space, x):
    if isinstance(space, Box):
        return np.asarray(x, dtype=np.float32).flatten()
    elif isinstance(space, Discrete):
        onehot = np.zeros(space.n, dtype=np.float32)
        onehot[x] = 1.0
        return onehot
    elif isinstance(space, Tuple):
        return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
    elif isinstance(space, Dict):
        return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
    elif isinstance(space, MultiBinary):
        return np.asarray(x).flatten()
    elif isinstance(space, MultiDiscrete):
        return np.asarray(x).flatten()
    else:
        raise NotImplementedError      

可以知道如果 env.observation_space屬于Box類型,則直接調用np.array的flatten操作。

如果 env.observation_space屬于Discrete類型,則直接進行onehot編碼的方法進行flatten操作。

env.observation_space如果屬于多個Box類型或Discrete類型組合而成的,也就是屬于Tuple, Dict, 那麼需要将其中的每個類型的狀态空間都進行flatten操作後在進行拼接操作。

即:(取出組合空間中的各個子狀态空間疊代調用flatten操作進而實作對組合中的各個子observation_space進行flatten)

elif isinstance(space, Tuple):
        return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
    elif isinstance(space, Dict):
        return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])      

MultiBinary, MultiDiscrete類型直接轉為np.array類型的資料再進行flatten操作。

繼續閱讀