thu-ml/tianshou

Don't pass envpool envs where vectorenvs are needed

Open

#1096 opened on April 3, 2024

Labels: bug, good first issue, refactoring

Description

See the block comments in the test and in the Collector method quoted below. Somewhere a plain envpool env is passed where an instance of BaseVectorEnv is expected, so the interface is not followed.

This means we rely on the two interfaces coinciding more or less by accident. They already do not coincide fully: envpool envs return info as a single dict of arrays, whereas tianshou's VectorEnvs return an array of dicts.
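To make the mismatch concrete, here is a minimal sketch of the two info layouts described above. It uses only NumPy, not envpool or tianshou; the key `env_id` and the variable names are made up for illustration.

```python
import numpy as np

num_envs = 3

# Layout the issue attributes to envpool envs: one dict whose values are
# arrays with one entry per environment. (Illustrative key name.)
dict_of_arrays = {"env_id": np.arange(num_envs)}

# Layout the issue attributes to tianshou's vector envs: an object array
# holding one dict per environment.
array_of_dicts = np.array([{"env_id": i} for i in range(num_envs)], dtype=object)

# Code written against the array-of-dicts layout iterates over environments first.
per_env_ids = [info["env_id"] for info in array_of_dicts]

# Iterating the dict-of-arrays layout yields key strings instead of per-env dicts,
# so the same loop would not see one info dict per environment.
iterated = list(dict_of_arrays)
```

Any consumer that loops over `info` expecting one dict per environment therefore behaves differently depending on which kind of env it was handed.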

@Trinkle23897 this issue might be of interest to you

@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info() -> None:
    num_envs = 4
    env = VectorEnvNormObs(
        envpool.make_gymnasium("Ant-v3", num_envs=num_envs, gym_reset_return_info=True),
    )
    obs, info = env.reset()
    assert obs.shape[0] == num_envs
    # This is not actually unreachable b/c envpool does not return info in the right format
    if isinstance(info, dict):  # type: ignore[unreachable]
        for _, v in info.items():  # type: ignore[unreachable]
            if not isinstance(v, dict):
                assert v.shape[0] == num_envs
    else:
        for _info in info:
            for _, v in _info.items():
                if not isinstance(v, dict):
                    assert v.shape[0] == num_envs

# Collector method:
    def reset_env(
        self,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Reset the environments and the initial obs, info, and hidden state of the collector."""
        gym_reset_kwargs = gym_reset_kwargs or {}
        self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs)
        # TODO: hack, wrap envpool envs such that they don't return a dict
        if isinstance(self._pre_collect_info_R, dict):  # type: ignore[unreachable]
            # this can happen if the env is an envpool env. Then the thing returned by reset is a dict
            # with array entries instead of an array of dicts
            # We use Batch to turn it into an array of dicts
            self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R)  # type: ignore[unreachable]

        self._pre_collect_hidden_state_RH = None
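One possible direction for the TODO above is an adapter that normalizes the info layout before it reaches the collector. The following is only a sketch under stated assumptions: `ArrayOfDictsInfoWrapper`, `split_info`, and `FakeDictInfoEnv` are hypothetical names, the fake env stands in for envpool rather than reproducing its API, and the class does not implement the full BaseVectorEnv interface.

```python
import numpy as np


def split_info(info: dict, num_envs: int) -> np.ndarray:
    """Turn one dict of per-env arrays into an object array of per-env dicts."""

    def pick(value, i):
        # Nested dicts are split recursively; other values are indexed by env.
        if isinstance(value, dict):
            return {k: pick(v, i) for k, v in value.items()}
        return value[i]

    out = np.empty(num_envs, dtype=object)
    for i in range(num_envs):
        out[i] = {k: pick(v, i) for k, v in info.items()}
    return out


class ArrayOfDictsInfoWrapper:
    """Hypothetical adapter: reset() always returns info as an array of dicts."""

    def __init__(self, env, num_envs: int) -> None:
        self.env = env
        self.num_envs = num_envs

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        if isinstance(info, dict):
            info = split_info(info, self.num_envs)
        return obs, info


class FakeDictInfoEnv:
    """Stand-in for an env that returns info as a dict of arrays (not real envpool)."""

    def __init__(self, num_envs: int) -> None:
        self.num_envs = num_envs

    def reset(self, **kwargs):
        obs = np.zeros((self.num_envs, 2))
        info = {"env_id": np.arange(self.num_envs)}
        return obs, info


wrapped = ArrayOfDictsInfoWrapper(FakeDictInfoEnv(4), num_envs=4)
obs, info = wrapped.reset()
```

With the conversion done at the env boundary, the `isinstance(..., dict)` branch in `reset_env` would no longer be needed; a complete fix would presumably have to cover `step` as well.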
