Finish PR #98: replay fallback checkpoint compatibility + RL test updates #166
amcberkes wants to merge 38 commits into google:copybara_push from
Conversation
amcberkes commented Apr 26, 2026
- Completes PR #98
- Incorporates the replay fallback direction from Yukta's rl module finish #127
- Replaces Reverb-only checkpoint calls with ReplayBufferManager.save_checkpoint() (see the sketch after this list)
- Adds TFUniform fallback checkpoint save/restore support
- Updates RL tests for fallback compatibility
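A rough sketch of the dispatch this implies. Only ReplayBufferManager.save_checkpoint() is named in this PR; the attributes and the _write_metadata() helper below are illustrative assumptions:

```python
class ReplayBufferManager:
  """Hypothetical sketch; only save_checkpoint() is named in this PR."""

  def save_checkpoint(self):
    if self._reverb_server is not None:
      # Reverb persists its own tables through the server's client.
      self._reverb_server.localhost_client().checkpoint()
    else:
      # TFUniform fallback: checkpoint the buffer's variables and persist
      # construction metadata so restore can rebuild a matching buffer
      # (see the replay_buffer_metadata.json discussion below).
      self._tf_uniform_checkpointer.save()
      self._write_metadata()
```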
Code Review
This pull request introduces a comprehensive reinforcement learning (RL) framework, including scripts for agent training, evaluation, and configuration generation. It adds support for DDPG and SAC agents, implements custom observers for trajectory recording and visualization, and provides a replay buffer management system with a Reverb backend and a TFUniform fallback. Feedback focuses on critical logic errors such as a redundant training-step increment, potential resource leaks from uncleaned temporary directories in the evaluation script, and dead code and hardcoded configuration values that hinder maintainability.
```python
log_interval=args.log_interval,
checkpoint_interval=args.checkpoint_interval,
learner_iterations=args.learner_iterations,
...
train_step.assign_add(1)
```
Manually incrementing train_step here is redundant and likely incorrect. agent_learner.run(iterations=self.learner_iterations) already increments the train_step variable by the number of iterations performed. This double-incrementing will cause triggers (like evaluation and saving) to fire much more frequently than intended, as they are based on the global step count.
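A minimal sketch of the suggested fix, assuming the tf_agents Learner API; agent_learner and learner_iterations come from the hunk above, the loop condition and total_train_steps are illustrative:

```python
while train_step.numpy() < total_train_steps:
  # Learner.run() already advances the shared train_step variable by
  # `iterations`, and the interval triggers (eval, checkpoint) read that
  # same counter -- so no manual train_step.assign_add(1) afterwards.
  loss_info = agent_learner.run(iterations=self.learner_iterations)
```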
```python
self._num_timesteps_in_episode = self._environment.pyenv.envs[
    0
]._num_timesteps_in_episode
```
```python
model_structure_dir = None
if os.path.exists(os.path.join(policy_dir, "greedy_policy")):
  model_structure_dir = os.path.join(policy_dir, "greedy_policy")
  logger.info("Using model structure from greedy_policy directory")
else:
  raise ValueError(
      "No policy structure directories found in"
      f" {os.path.abspath(policy_dir)}"
  )
```
```python
self.temp_saved_model_policy_dirpath = create_merged_saved_model(
    self.saved_model_policy_dirpath
)
```
The ExperimentEvaluator class creates a temporary directory via create_merged_saved_model but does not provide a mechanism to clean it up. This will lead to an accumulation of temporary directories on disk. Consider adding a cleanup method or implementing the class as a context manager to ensure shutil.rmtree is called on self.temp_saved_model_policy_dirpath.
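A minimal sketch of that suggestion. The attribute name comes from the hunk above; the cleanup() method and the context-manager wiring are illustrative:

```python
import shutil


class ExperimentEvaluator:
  ...

  def cleanup(self):
    # Remove the merged SavedModel copy created by create_merged_saved_model.
    if self.temp_saved_model_policy_dirpath is not None:
      shutil.rmtree(self.temp_saved_model_policy_dirpath, ignore_errors=True)
      self.temp_saved_model_policy_dirpath = None

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.cleanup()
    return False  # never swallow exceptions
```

Used as `with ExperimentEvaluator(...) as evaluator:`, this guarantees the temporary directory is removed even if evaluation raises.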
```python
def old_main(argv: Sequence[str]):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # handle relative and absolute filepaths:
  config_filepath = FLAGS.eval_config_filepath
  if not os.path.isabs(config_filepath):
    config_filepath = os.path.join(ROOT_DIR, config_filepath)

  policy_dirpath = FLAGS.eval_policy_dirpath
  if (
      policy_dirpath is not None
      and not os.path.isabs(policy_dirpath)
      and policy_dirpath != "schedule"
  ):
    policy_dirpath = os.path.join(ROOT_DIR, policy_dirpath)

  evaluate_policy(
      experiment_name=FLAGS.eval_experiment_name,
      policy_dirpath=policy_dirpath,
      config_filepath=config_filepath,
      num_eval_episodes=FLAGS.num_eval_episodes,
  )
```
```python
if self.agent_type not in ['sac', 'ddpg']:
  raise ValueError(
      f'Agent {self.agent_type} has not (yet) been implemented. Please'
      " choose one of: ['sac', 'ddpg']."
  )
```
```python
data_spec=self.agent.collect_data_spec,
capacity=50000,  # Use default capacity
checkpoint_dir=new_buffer_path,  # Use the copied buffer path
sequence_length=2,
# should we keep these defaults, or use the dynamic parameter values?
)
```
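On the hardcoded defaults: since the follow-up comment below mentions a persisted replay_buffer_metadata.json, one option is to read the construction parameters back instead of pinning capacity=50000 and sequence_length=2. A hypothetical helper (the function name and JSON keys are assumptions):

```python
import json
import os


def load_replay_metadata(buffer_dir, default_capacity=50000,
                         default_sequence_length=2):
  """Hypothetical helper: read back buffer params persisted at save time."""
  meta_path = os.path.join(buffer_dir, "replay_buffer_metadata.json")
  if not os.path.exists(meta_path):
    return default_capacity, default_sequence_length
  with open(meta_path) as f:
    meta = json.load(f)
  # Rebuilding the buffer with the checkpointed capacity avoids the
  # save/restore shape mismatch the hardcoded values can cause.
  return (
      meta.get("capacity", default_capacity),
      meta.get("sequence_length", default_sequence_length),
  )
```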
I got the RL training running end-to-end on my Mac without dm-reverb. I switched replay handling to a TFUniform fallback path and added metadata persistence for replay capacity (replay_buffer_metadata.json), which fixed the restore mismatch I was hitting during train startup.
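A rough sketch of what that fallback path could look like. Only TFUniformReplayBuffer and the replay_buffer_metadata.json filename come from the comment above; the function names and the tf.train.Checkpoint scaffolding are assumptions:

```python
import json
import os

import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer


def save_tf_uniform_checkpoint(buffer, checkpoint_dir, capacity):
  """Checkpoint a TFUniformReplayBuffer plus the metadata to rebuild it."""
  ckpt = tf.train.Checkpoint(replay_buffer=buffer)
  ckpt.save(os.path.join(checkpoint_dir, "replay_buffer"))
  # Persist construction parameters alongside the variables so restore can
  # build an identically shaped buffer before loading the checkpoint.
  with open(os.path.join(checkpoint_dir, "replay_buffer_metadata.json"), "w") as f:
    json.dump({"capacity": capacity}, f)


def restore_tf_uniform_buffer(data_spec, checkpoint_dir):
  with open(os.path.join(checkpoint_dir, "replay_buffer_metadata.json")) as f:
    meta = json.load(f)
  # Rebuild with the persisted capacity, then load the saved variables.
  buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=data_spec,
      batch_size=1,
      max_length=meta["capacity"],
  )
  ckpt = tf.train.Checkpoint(replay_buffer=buffer)
  ckpt.restore(tf.train.latest_checkpoint(checkpoint_dir))
  return buffer
```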
@amcberkes awesome! Looks like there are two file conflicts that need resolution. Let me know once they are resolved and I will provide a review.