
Questions about the documentation related to the RNN #122

@HawkQ


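While using the multi-agent MAPPO algorithm, I tried replacing Basic_MLP with Basic_RNN. After also setting use_rnn: True in the configuration file, I got the following error: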

Traceback (most recent call last):
  File "/Users/hawkq/Desktop/frigatebird_multi/new_run.py", line 22, in <module>
    Agent.train(configs.running_steps // configs.parallels) # Train the model for numerous steps.
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/agents/core/on_policy_marl.py", line 287, in train
    self.run_episodes(None, n_episodes=self.n_envs, test_mode=False)
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/agents/core/on_policy_marl.py", line 384, in run_episodes
    policy_out = self.action(obs_dict=obs_dict, state=state, avail_actions_dict=avail_actions,
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/agents/multi_agent_rl/mappo_agents.py", line 141, in action
    rnn_hidden_critic_new, values_out = self.policy.get_values(observation=critic_input,
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/policies/gaussian_marl.py", line 176, in get_values
    outputs = self.critic_representation[key](observation[key], *rnn_hidden[key])
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/representations/rnn.py", line 63, in forward
    output, hn = self.rnn(mlp_output, h)
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 1117, in forward
    raise RuntimeError(
RuntimeError: For unbatched 2-D input, hx should also be 2-D but got 3-D tensor
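The last frame is raised inside PyTorch's recurrent layer, not in XuanCe itself. PyTorch's RNN/GRU/LSTM modules decide whether the input is batched from its rank: a 2-D input (seq_len, input_size) is treated as a single unbatched sequence and needs a 2-D hidden state (num_layers, hidden_size), while a 3-D input needs a 3-D hidden state (num_layers, batch, hidden_size), whatever `batch_first` is set to. The sketch below restates that rule in plain Python so it runs without torch; `check_rnn_shapes` is a hypothetical helper (not a XuanCe or PyTorch API) and assumes a unidirectional RNN with an illustrative hidden size of 64. One plausible reading of the traceback is that `mlp_output` arrived as (batch, features) with no time axis, while `h` had already been built as (num_layers, batch, hidden_size).

```python
from typing import Tuple


def check_rnn_shapes(input_shape: Tuple[int, ...], hx_shape: Tuple[int, ...],
                     num_layers: int = 1, hidden_size: int = 64,
                     batch_first: bool = True) -> str:
    """Hypothetical helper mimicking the input/hx shape rule of a
    unidirectional torch.nn.RNN/GRU. Returns "unbatched" or "batched"
    if the pair is consistent, otherwise raises ValueError."""
    if len(input_shape) == 2:
        # Unbatched: input (seq_len, input_size) pairs with hx (num_layers, hidden_size).
        if len(hx_shape) != 2:
            raise ValueError("For unbatched 2-D input, hx should also be 2-D "
                             f"but got {len(hx_shape)}-D tensor")
        expected = (num_layers, hidden_size)
    elif len(input_shape) == 3:
        # Batched: batch_first only moves the batch axis of the *input*;
        # hx is always (num_layers, batch, hidden_size).
        batch = input_shape[0] if batch_first else input_shape[1]
        if len(hx_shape) != 3:
            raise ValueError("For batched 3-D input, hx should also be 3-D "
                             f"but got {len(hx_shape)}-D tensor")
        expected = (num_layers, batch, hidden_size)
    else:
        raise ValueError(f"expected 2-D or 3-D input, got {len(input_shape)}-D")
    if tuple(hx_shape) != expected:
        raise ValueError(f"expected hx shape {expected}, got {tuple(hx_shape)}")
    return "unbatched" if len(input_shape) == 2 else "batched"


if __name__ == "__main__":
    # (8, 64) has no time axis, so it is read as one sequence of length 8,
    # and a 3-D hidden state (1, 8, 64) is rejected, as in the traceback.
    try:
        check_rnn_shapes((8, 64), (1, 8, 64))
    except ValueError as err:
        print(err)  # For unbatched 2-D input, hx should also be 2-D but got 3-D tensor
    # Adding a length-1 time axis makes input and hidden state consistent.
    print(check_rnn_shapes((8, 1, 64), (1, 8, 64)))  # batched
```

In other words, either the input needs a time axis or the hidden state must drop its batch axis; for rollouts over several parallel environments the first option is the one that keeps each environment's hidden state separate.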

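This looks like a data-dimension problem, but I couldn't find any explanation of it in the documentation, and I don't know what else in my environment code needs to change. Here are my action, state, and observation spaces: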

        self.state_space = Box(-np.inf, np.inf, shape=[7 * self.num_agents, ], dtype=np.float32)
        self.observation_space = {agent: Box(-np.inf, np.inf, shape=[14, ], dtype=np.float32) for agent in self.agents}
        self.action_space = {agent: Box(-1, 1, shape=[2, ], dtype=np.float32) for agent in self.agents}
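To make the shapes implied by these spaces concrete, the NumPy sketch below builds one rollout step for a batch of parallel environments and adds a length-1 time axis so that a batch-first recurrent layer would see 3-D inputs matching 3-D hidden states. The traceback shows the failure in the critic's `get_values` call, so the critic input is shown alongside the per-agent observations. Everything except the 14-dim observation and the 7-per-agent state is a hypothetical value chosen for illustration (`num_agents`, `n_envs`, `hidden_size`, the agent names, and the helper `to_rnn_step_input`). Whether XuanCe expects the environment to return this layout or inserts the time axis itself is not confirmed here; the sketch only spells out the shapes that must line up.

```python
import numpy as np

# Hypothetical sizes; the issue does not state them.
num_agents = 3
n_envs = 4            # number of parallel environments
hidden_size = 64      # recurrent hidden size
num_layers = 1
agents = [f"agent_{i}" for i in range(num_agents)]  # hypothetical agent names


def to_rnn_step_input(batch: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (n_envs, features) -> (n_envs, 1, features),
    i.e. add a length-1 time axis in batch-first layout."""
    return batch[:, np.newaxis, :]


# One step of observations: each agent sees a 14-dim vector (observation_space above).
obs_dict = {a: np.zeros((n_envs, 14), dtype=np.float32) for a in agents}
# Global state for a centralized critic: 7 features per agent (state_space above).
state = np.zeros((n_envs, 7 * num_agents), dtype=np.float32)

actor_inputs = {a: to_rnn_step_input(o) for a, o in obs_dict.items()}
critic_input = to_rnn_step_input(state)

# Matching 3-D hidden states: (num_layers, batch, hidden_size).
h_actor = {a: np.zeros((num_layers, n_envs, hidden_size), dtype=np.float32) for a in agents}
h_critic = np.zeros((num_layers, n_envs, hidden_size), dtype=np.float32)

if __name__ == "__main__":
    print(actor_inputs[agents[0]].shape)  # (4, 1, 14)
    print(critic_input.shape)             # (4, 1, 21)
    print(h_critic.shape)                 # (1, 4, 64)
```

Because a linear layer acts only on the last dimension, an MLP applied before the RNN (as in the `mlp_output` line of the traceback) keeps the (n_envs, 1, ·) layout, so the recurrent layer would receive a 3-D tensor.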

What else do I need to adjust? Thanks in advance for your help!
