Commit 2eb6064 (parent: 758f7e4)

Add TRPO, D3PG and SAC, minor improvements and bug fixes
53 files changed: 2,456 additions and 229 deletions

README.md (36 additions, 16 deletions)

````diff
@@ -14,33 +14,36 @@ Welcome to `actorch`, a deep reinforcement learning framework for fast prototypi
 - [REINFORCE](https://people.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)
 - [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
 - [Actor-Critic Kronecker-Factored Trust Region (ACKTR)](https://arxiv.org/abs/1708.05144)
+- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
 - [Advantage-Weighted Regression (AWR)](https://arxiv.org/abs/1910.00177)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
+- [Distributional Deep Deterministic Policy Gradient (D3PG)](https://arxiv.org/abs/1804.08617)
 - [Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/abs/1802.09477)
+- [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290)

 ---------------------------------------------------------------------------------------------------------

 ## 💡 Key features

 - Support for [OpenAI Gymnasium](https://gymnasium.farama.org/) environments
-- Support for custom observation/action spaces
-- Support for custom multimodal input multimodal output models
-- Support for recurrent models (e.g. RNNs, LSTMs, GRUs, etc.)
-- Support for custom policy/value distributions
-- Support for custom preprocessing/postprocessing pipelines
-- Support for custom exploration strategies
+- Support for **custom observation/action spaces**
+- Support for **custom multimodal input multimodal output models**
+- Support for **recurrent models** (e.g. RNNs, LSTMs, GRUs, etc.)
+- Support for **custom policy/value distributions**
+- Support for **custom preprocessing/postprocessing pipelines**
+- Support for **custom exploration strategies**
 - Support for [normalizing flows](https://arxiv.org/abs/1906.02771)
 - Batched environments (both for training and evaluation)
-- Batched trajectory replay
-- Batched and distributional value estimation (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561))
-- Data parallel and distributed data parallel multi-GPU training and evaluation
-- Automatic mixed precision training
-- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and hyperparameter tuning at any scale
-- Effortless experiment definition through Python-based configuration files
-- Built-in visualization tool to plot performance metrics
-- Modular object-oriented design
-- Detailed API documentation
+- Batched **trajectory replay**
+- Batched and **distributional value estimation** (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561))
+- Data parallel and distributed data parallel **multi-GPU training and evaluation**
+- Automatic **mixed precision training**
+- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and **hyperparameter tuning** at any scale
+- Effortless experiment definition through **Python-based configuration files**
+- Built-in **visualization tool** to plot performance metrics
+- Modular **object-oriented** design
+- Detailed **API documentation**

 ---------------------------------------------------------------------------------------------------------

@@ -161,7 +164,7 @@ experiment_params = ExperimentParams(
         enable_amp=False,
         enable_reproducibility=True,
         log_sys_usage=True,
-        suppress_warnings=False,
+        suppress_warnings=True,
     ),
 )
 ```
@@ -197,6 +200,8 @@ You can find the generated plots in `plots`.

 Congratulations, you ran your first experiment!

+See `examples` for additional configuration file examples.
+
 **HINT**: since a configuration file is a regular Python script, you can use all the
 features of the language (e.g. inheritance).

@@ -217,6 +222,21 @@ features of the language (e.g. inheritance).

 ---------------------------------------------------------------------------------------------------------

+## @ Citation
+
+```
+@misc{DellaLibera2022ACTorch,
+    author = {Luca Della Libera},
+    title = {{ACTorch}: a Deep Reinforcement Learning Framework for Fast Prototyping},
+    year = {2022},
+    publisher = {GitHub},
+    journal = {GitHub repository},
+    howpublished = {\url{https://github.com/lucadellalib/actorch}},
+}
+```
+
+---------------------------------------------------------------------------------------------------------
+
 ## 📧 Contact


````
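The `suppress_warnings` change above lands inside the README's quickstart configuration file, of which only the closing keyword arguments are visible in this hunk. For context, below is a minimal sketch of what a complete Python-based configuration file could look like after this commit; only `ExperimentParams`, the per-algorithm `Config` pattern (see the `A2C.Config` usage in the a2c.py diff further down), and the four flags marked "from the diff" come from the commit, while the algorithm choice, environment, and stopping criterion are illustrative placeholders.

```python
# Hypothetical configuration file, e.g. `sac_pendulum.py`; names not visible in
# this diff (SAC, the environment id, the stopping criterion) are assumptions.
import gymnasium as gym

from actorch import *  # assumed to re-export ExperimentParams and the algorithm classes


experiment_params = ExperimentParams(
    run_or_experiment=SAC,  # assumption: one of the algorithms added by this commit
    stop={"training_iteration": 100},  # assumption: any Ray Tune stopping criterion
    config=SAC.Config(  # per-algorithm Config, mirroring the A2C.Config usage below
        # Builder called with the env config, then wrapped in a batched env by the
        # framework (see the _build_train_env hunk in algorithm.py further down)
        train_env_builder=lambda **kwargs: gym.make("Pendulum-v1", **kwargs),
        enable_amp=False,  # from the diff
        enable_reproducibility=True,  # from the diff
        log_sys_usage=True,  # from the diff
        suppress_warnings=True,  # flipped from False to True in this commit
    ),
)
```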

actorch/algorithms/__init__.py (3 additions, 0 deletions)

```diff
@@ -20,7 +20,10 @@
 from actorch.algorithms.acktr import *
 from actorch.algorithms.algorithm import *
 from actorch.algorithms.awr import *
+from actorch.algorithms.d3pg import *
 from actorch.algorithms.ddpg import *
 from actorch.algorithms.ppo import *
 from actorch.algorithms.reinforce import *
+from actorch.algorithms.sac import *
 from actorch.algorithms.td3 import *
+from actorch.algorithms.trpo import *
```
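With these wildcard imports in place, the new algorithms should be importable directly from `actorch.algorithms` alongside the existing ones. A minimal sketch, assuming each new module exports a class named after its algorithm (`TRPO`, `D3PG`, `SAC`), which the diff itself does not show:

```python
# The class names are inferred from the module names and the commit message,
# not from the diff, so treat them as assumptions.
from actorch.algorithms import D3PG, SAC, TRPO

for algorithm_cls in (TRPO, D3PG, SAC):
    print(algorithm_cls.__name__)  # quick sanity check that the re-exports resolve
```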

actorch/algorithms/a2c.py (21 additions, 27 deletions)

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================

-"""Advantage Actor-Critic."""
+"""Advantage Actor-Critic (A2C)."""

 import contextlib
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@@ -38,7 +38,7 @@
     DistributedDataParallelREINFORCE,
     LRScheduler,
 )
-from actorch.algorithms.utils import prepare_model
+from actorch.algorithms.utils import normalize_, prepare_model
 from actorch.algorithms.value_estimation import n_step_return
 from actorch.distributions import Deterministic
 from actorch.envs import BatchedEnv
@@ -65,7 +65,7 @@


 class A2C(REINFORCE):
-    """Advantage Actor-Critic.
+    """Advantage Actor-Critic (A2C).

     References
     ----------
@@ -208,12 +208,8 @@ def setup(self, config: "Dict[str, Any]") -> "None":
         self.config = A2C.Config(**self.config)
         self.config["_accept_kwargs"] = True
         super().setup(config)
-        self._value_network = (
-            self._build_value_network().train().to(self._device, non_blocking=True)
-        )
-        self._value_network_loss = (
-            self._build_value_network_loss().train().to(self._device, non_blocking=True)
-        )
+        self._value_network = self._build_value_network()
+        self._value_network_loss = self._build_value_network_loss()
         self._value_network_optimizer = self._build_value_network_optimizer()
         self._value_network_optimizer_lr_scheduler = (
             self._build_value_network_optimizer_lr_scheduler()
@@ -324,16 +320,20 @@ def _build_value_network(self) -> "Network":
             self.value_network_normalizing_flows,
         )
         self._log_graph(value_network.wrapped_model.model, "value_network_model")
-        return value_network
+        return value_network.train().to(self._device, non_blocking=True)

     def _build_value_network_loss(self) -> "Loss":
         if self.value_network_loss_builder is None:
             self.value_network_loss_builder = torch.nn.MSELoss
         if self.value_network_loss_config is None:
             self.value_network_loss_config: "Dict[str, Any]" = {}
-        return self.value_network_loss_builder(
-            reduction="none",
-            **self.value_network_loss_config,
+        return (
+            self.value_network_loss_builder(
+                reduction="none",
+                **self.value_network_loss_config,
+            )
+            .train()
+            .to(self._device, non_blocking=True)
         )

     def _build_value_network_optimizer(self) -> "Optimizer":
@@ -374,6 +374,8 @@ def _train_step(self) -> "Dict[str, Any]":
         result = super()._train_step()
         self.num_return_steps.step()
         result["num_return_steps"] = self.num_return_steps()
+        result["entropy_coeff"] = result.pop("entropy_coeff", None)
+        result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
         return result

     # override
@@ -405,17 +407,7 @@ def _train_on_batch(
             self.num_return_steps(),
         )
         if self.normalize_advantage:
-            length = mask.sum(dim=1, keepdim=True)
-            advantages_mean = advantages.sum(dim=1, keepdim=True) / length
-            advantages -= advantages_mean
-            advantages *= mask
-            advantages_stddev = (
-                ((advantages**2).sum(dim=1, keepdim=True) / length)
-                .sqrt()
-                .clamp(min=1e-6)
-            )
-            advantages /= advantages_stddev
-            advantages *= mask
+            normalize_(advantages, dim=-1, mask=mask)

         # Discard next state value
         state_values = state_values[:, :-1]
@@ -449,10 +441,12 @@ def _train_on_batch_value_network(
         state_value = state_values[mask]
         target = targets[mask]
         loss = self._value_network_loss(state_value, target)
-        loss *= is_weight[:, None].expand_as(mask)[mask]
+        priority = None
+        if self._buffer.is_prioritized:
+            loss *= is_weight[:, None].expand_as(mask)[mask]
+            priority = loss.detach().abs().to("cpu").numpy()
         loss = loss.mean()
         optimize_result = self._optimize_value_network(loss)
-        priority = None
         result = {
             "state_value": state_value.mean().item(),
             "target": target.mean().item(),
@@ -490,7 +484,7 @@ def _get_default_value_network_preprocessor(


 class DistributedDataParallelA2C(DistributedDataParallelREINFORCE):
-    """Distributed data parallel Advantage Actor-Critic.
+    """Distributed data parallel Advantage Actor-Critic (A2C).

     See Also
     --------
```
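The main functional change in this file replaces the hand-rolled advantage normalization in `_train_on_batch` with a single call to `normalize_` from `actorch.algorithms.utils`. The deleted lines spell out the intended behaviour, so the sketch below shows what such a masked, in-place normalization helper plausibly looks like; everything beyond the `normalize_(advantages, dim=-1, mask=mask)` call visible in the diff (the default `mask`, the zeroing of padded entries up front) is an assumption.

```python
import torch
from torch import Tensor


def normalize_(input: "Tensor", dim: "int" = -1, mask: "Tensor" = None) -> "Tensor":
    """Masked in-place normalization sketch, mirroring the deleted A2C code:
    subtract the masked mean and divide by the masked standard deviation
    along `dim`, keeping padded elements at zero."""
    if mask is None:
        mask = torch.ones_like(input, dtype=torch.bool)
    input *= mask  # zero out padded elements so they do not affect the statistics
    length = mask.sum(dim=dim, keepdim=True)
    mean = input.sum(dim=dim, keepdim=True) / length
    input -= mean
    input *= mask
    stddev = ((input**2).sum(dim=dim, keepdim=True) / length).sqrt().clamp(min=1e-6)
    input /= stddev
    input *= mask
    return input


# Example: advantages of shape (batch_size, max_trajectory_length) with ragged lengths
advantages = torch.randn(4, 10)
mask = torch.arange(10)[None, :] < torch.tensor([[10], [7], [3], [9]])
normalize_(advantages, dim=-1, mask=mask)
```

Factoring the statistics into a reusable helper presumably lets the newly added TRPO reuse the same masked normalization instead of duplicating the eleven deleted lines.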

actorch/algorithms/acktr.py (3 additions, 3 deletions)

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ==============================================================================

-"""Actor-Critic Kronecker-Factored Trust Region."""
+"""Actor-Critic Kronecker-Factored Trust Region (ACKTR)."""

 from typing import Any, Callable, Dict, Optional, Union

@@ -43,7 +43,7 @@


 class ACKTR(A2C):
-    """Actor-Critic Kronecker-Factored Trust Region.
+    """Actor-Critic Kronecker-Factored Trust Region (ACKTR).

     References
     ----------
@@ -287,7 +287,7 @@ def _optimize_policy_network(self, loss: "Tensor") -> "Dict[str, Any]":


 class DistributedDataParallelACKTR(DistributedDataParallelA2C):
-    """Distributed data parallel Actor-Critic Kronecker-Factored Trust Region.
+    """Distributed data parallel Actor-Critic Kronecker-Factored Trust Region (ACKTR).

     See Also
     --------
```

actorch/algorithms/algorithm.py (21 additions, 33 deletions)

```diff
@@ -51,7 +51,6 @@
 from ray.tune.syncer import NodeSyncer
 from ray.tune.trial import ExportFormat
 from torch import Tensor
-from torch.cuda.amp import autocast
 from torch.distributions import Bernoulli, Categorical, Distribution, Normal
 from torch.profiler import profile, record_function, tensorboard_trace_handler
 from torch.utils.data import DataLoader
@@ -113,7 +112,7 @@ class Algorithm(ABC, Trainable):

     _EXPORT_FORMATS = [ExportFormat.CHECKPOINT, ExportFormat.MODEL]

-    _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True
+    _OFF_POLICY = True

     class Config(dict):
         """Keyword arguments expected in the configuration received by `setup`."""
@@ -692,12 +691,7 @@ def _build_train_env(self) -> "BatchedEnv":
         if self.train_env_config is None:
             self.train_env_config = {}

-        try:
-            train_env = self.train_env_builder(
-                **self.train_env_config,
-            )
-        except TypeError:
-            train_env = self.train_env_builder(**self.train_env_config)
+        train_env = self.train_env_builder(**self.train_env_config)
         if not isinstance(train_env, BatchedEnv):
             train_env.close()
             train_env = SerialBatchedEnv(self.train_env_builder, self.train_env_config)
@@ -866,7 +860,7 @@ def _build_policy_network(self) -> "PolicyNetwork": # noqa: C901
             self.policy_network_postprocessors,
         )
         self._log_graph(policy_network.wrapped_model.model, "policy_network_model")
-        return policy_network
+        return policy_network.train().to(self._device, non_blocking=True)

     def _build_train_agent(self) -> "Agent":
         if self.train_agent_builder is None:
@@ -989,14 +983,18 @@ def _build_dataloader(self) -> "DataLoader":
         if self.dataloader_builder is None:
             self.dataloader_builder = DataLoader
         if self.dataloader_config is None:
-            fork = torch.multiprocessing.get_start_method() == "fork"
+            use_mp = (
+                self._OFF_POLICY
+                and not self._buffer.is_prioritized
+                and torch.multiprocessing.get_start_method() == "fork"
+            )
             self.dataloader_config = {
-                "num_workers": 1 if fork else 0,
+                "num_workers": 1 if use_mp else 0,
                 "pin_memory": True,
                 "timeout": 0,
                 "worker_init_fn": None,
                 "generator": None,
-                "prefetch_factor": 1 if fork else 2,
+                "prefetch_factor": 1 if use_mp else 2,
                 "pin_memory_device": "",
             }
         if self.dataloader_config is None:
@@ -1037,20 +1035,15 @@ def _train_step(self) -> "Dict[str, Any]":
         if self.train_num_episodes_per_iter:
             train_num_episodes_per_iter = self.train_num_episodes_per_iter()
             self.train_num_episodes_per_iter.step()
-        with (
-            autocast(**self.enable_amp)
-            if self.enable_amp["enabled"]
-            else contextlib.suppress()
+        for experience, done in self._train_sampler.sample(
+            train_num_timesteps_per_iter,
+            train_num_episodes_per_iter,
         ):
-            for experience, done in self._train_sampler.sample(
-                train_num_timesteps_per_iter,
-                train_num_episodes_per_iter,
-            ):
-                self._buffer.add(experience, done)
+            self._buffer.add(experience, done)
         result = self._train_sampler.stats
         self._cumrewards += result["episode_cumreward"]

-        if not self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH:
+        if not self._OFF_POLICY:
             for schedule in self._buffer_dataset.schedules.values():
                 schedule.step()
         train_epoch_result = self._train_epoch()
@@ -1060,7 +1053,7 @@ def _train_step(self) -> "Dict[str, Any]":
                 schedule.step()
         for schedule in self._buffer.schedules.values():
             schedule.step()
-        if self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH:
+        if self._OFF_POLICY:
             for schedule in self._buffer_dataset.schedules.values():
                 schedule.step()
         return result
@@ -1113,16 +1106,11 @@ def _eval_step(self) -> "Dict[str, Any]":
             eval_num_episodes_per_iter = self.eval_num_episodes_per_iter()
             self.eval_num_episodes_per_iter.step()
         self._eval_sampler.reset()
-        with (
-            autocast(**self.enable_amp)
-            if self.enable_amp["enabled"]
-            else contextlib.suppress()
+        for _ in self._eval_sampler.sample(
+            eval_num_timesteps_per_iter,
+            eval_num_episodes_per_iter,
         ):
-            for _ in self._eval_sampler.sample(
-                eval_num_timesteps_per_iter,
-                eval_num_episodes_per_iter,
-            ):
-                pass
+            pass
         for schedule in self._eval_agent.schedules.values():
             schedule.step()
         return self._eval_sampler.stats
@@ -1353,7 +1341,7 @@ def __init__(
             Default to ``{}``.
         placement_strategy:
             The placement strategy
-            (see https://docs.ray.io/en/latest/ray-core/placement-group.html).
+            (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0).
         backend:
             The backend for distributed execution
             (see https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group).
```
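The `_build_dataloader` hunk above changes the default configuration so that worker processes are used only for off-policy algorithms with a non-prioritized buffer (and a fork start method), instead of whenever fork is available. Below is a standalone sketch of the same decision logic, with a dummy `TensorDataset` standing in for actorch's buffer dataset and hard-coded flags in place of `self._OFF_POLICY` and `self._buffer.is_prioritized`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-ins for the attributes used by Algorithm._build_dataloader.
OFF_POLICY = True  # assumption: mirrors the new _OFF_POLICY class attribute
BUFFER_IS_PRIORITIZED = False  # assumption: mirrors self._buffer.is_prioritized

# Use a worker process only under the condition introduced by this commit:
# off-policy training, non-prioritized buffer, and a fork-able start method.
use_mp = (
    OFF_POLICY
    and not BUFFER_IS_PRIORITIZED
    and torch.multiprocessing.get_start_method() == "fork"
)

dataloader_config = {
    "num_workers": 1 if use_mp else 0,
    "pin_memory": True,
    "timeout": 0,
    "worker_init_fn": None,
    "generator": None,
    "prefetch_factor": 1 if use_mp else 2,
    "pin_memory_device": "",
}
# Recent PyTorch versions reject a prefetch_factor when num_workers == 0,
# so drop it in that case for this standalone sketch.
if dataloader_config["num_workers"] == 0:
    dataloader_config.pop("prefetch_factor")

dataset = TensorDataset(torch.randn(128, 4))  # dummy data in place of the buffer dataset
dataloader = DataLoader(dataset, **dataloader_config)
print(len(dataloader))
```

A plausible reading of the condition: a prioritized buffer is re-weighted after every update, so prefetching batches in a separate worker would sample from stale priorities, and on-policy algorithms rebuild their buffer every iteration anyway.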
