看代碼的過程中看到有這樣的調用:
from gym.wrappers import FlattenObservation
if sinstance(env.observation_space, gym.spaces.Dict):
env = FlattenObservation(env)
不是很了解這個代碼的意思。
===============================================
檢視gym源碼中類:
FlattenObservation(ObservationWrapper)
import numpy as np
import gym.spaces as spaces
from gym import ObservationWrapper
class FlattenObservation(ObservationWrapper):
r"""Observation wrapper that flattens the observation."""
def __init__(self, env):
super(FlattenObservation, self).__init__(env)
flatdim = spaces.flatdim(env.observation_space)
self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32)
def observation(self, observation):
return spaces.flatten(self.env.observation_space, observation)
從gym的狀态空間的轉換可以看出這個類是要将observation的狀态空間進行flatten操作。
具體的flatten操作調用:
spaces.flatten(self.env.observation_space, observation)
檢視spaces.flatten源代碼:
def flatten(space, x):
if isinstance(space, Box):
return np.asarray(x, dtype=np.float32).flatten()
elif isinstance(space, Discrete):
onehot = np.zeros(space.n, dtype=np.float32)
onehot[x] = 1.0
return onehot
elif isinstance(space, Tuple):
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
elif isinstance(space, Dict):
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
elif isinstance(space, MultiBinary):
return np.asarray(x).flatten()
elif isinstance(space, MultiDiscrete):
return np.asarray(x).flatten()
else:
raise NotImplementedError
可以知道如果 env.observation_space屬于Box類型,則直接調用np.array的flatten操作。
如果 env.observation_space屬于Discrete類型,則直接進行onehot編碼的方法進行flatten操作。
env.observation_space如果屬于多個Box類型或Discrete類型組合而成的,也就是屬于Tuple, Dict, 那麼需要将其中的每個類型的狀态空間都進行flatten操作後在進行拼接操作。
即:(取出組合空間中的各個子狀态空間疊代調用flatten操作進而實作對組合中的各個子observation_space進行flatten)
elif isinstance(space, Tuple):
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
elif isinstance(space, Dict):
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
MultiBinary, MultiDiscrete類型直接轉為np.array類型的資料再進行flatten操作。