When using the multi-agent MAPPO algorithm, I tried replacing Basic_MLP with Basic_RNN. After also setting use_rnn: True in the config file, I got the following error:
Traceback (most recent call last):
File "/Users/hawkq/Desktop/frigatebird_multi/new_run.py", line 22, in <module>
Agent.train(configs.running_steps // configs.parallels) # Train the model for numerous steps.
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/agents/core/on_policy_marl.py", line 287, in train
self.run_episodes(None, n_episodes=self.n_envs, test_mode=False)
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/agents/core/on_policy_marl.py", line 384, in run_episodes
policy_out = self.action(obs_dict=obs_dict, state=state, avail_actions_dict=avail_actions,
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/agents/multi_agent_rl/mappo_agents.py", line 141, in action
rnn_hidden_critic_new, values_out = self.policy.get_values(observation=critic_input,
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/policies/gaussian_marl.py", line 176, in get_values
outputs = self.critic_representation[key](observation[key], *rnn_hidden[key])
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/xuance/torch/representations/rnn.py", line 63, in forward
output, hn = self.rnn(mlp_output, h)
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/anaconda3/envs/xuance_marl/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 1117, in forward
raise RuntimeError(
RuntimeError: For unbatched 2-D input, hx should also be 2-D but got 3-D tensor
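This looks like a tensor-dimension problem, but I checked the documentation and found nothing that covers it, and I don't know what else in the environment code needs to change. Here are my state, observation, and action spaces: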
self.state_space = Box(-np.inf, np.inf, shape=[7 * self.num_agents, ], dtype=np.float32)
self.observation_space = {agent: Box(-np.inf, np.inf, shape=[14, ], dtype=np.float32) for agent in self.agents}
self.action_space = {agent: Box(-1, 1, shape=[2, ], dtype=np.float32) for agent in self.agents}
What else needs to be adjusted? Thanks for your help!
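For reference, the same RuntimeError can be reproduced with plain PyTorch, outside XuanCe. The sketch below is only an illustration under assumptions: it feeds a GRU directly (the real Basic_RNN applies an MLP first), and `n_envs`, `hidden`, and the variable names are made up for the example. Only the observation size of 14 comes from the spaces above. It shows that a 2-D input is read as an unbatched sequence, so a 3-D hidden state is rejected, and that adding a sequence axis of length 1 makes the shapes agree. Whether XuanCe is expected to add that axis itself, or whether a config or environment setting is missing, is not something this sketch answers.

```python
import torch
import torch.nn as nn

# Illustrative sizes: obs_dim matches the observation space above;
# n_envs and hidden are assumptions made for this example.
n_envs, obs_dim, hidden = 4, 14, 64

gru = nn.GRU(input_size=obs_dim, hidden_size=hidden,
             num_layers=1, batch_first=True)

# Hidden state in the usual 3-D layout: (num_layers, batch, hidden).
h0 = torch.zeros(1, n_envs, hidden)

# 2-D input (batch, features). PyTorch reads any 2-D input as an
# unbatched sequence (seq_len, features) and then requires a 2-D hidden state.
obs_2d = torch.randn(n_envs, obs_dim)

try:
    gru(obs_2d, h0)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)
print(error_message)

# Adding a sequence axis of length 1 gives (batch, seq_len, features),
# which is consistent with the 3-D hidden state.
obs_3d = obs_2d.unsqueeze(1)
output, hn = gru(obs_3d, h0)
print(output.shape, hn.shape)
```

If the observation reaching the critic representation in get_values is 2-D while rnn_hidden is 3-D, this is the mismatch the traceback reports, so printing both shapes just before the failing call in gaussian_marl.py may help narrow down where the sequence axis goes missing.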