
Finish PR #98: replay fallback checkpoint compatibility + RL test updates #166

Open

amcberkes wants to merge 38 commits into google:copybara_push from amcberkes:finish-pr98-reverb-tests

Conversation

@amcberkes
Contributor

  • Completes PR #98
  • Incorporates the replay fallback direction from Yukta's "rl module finish" (#127)
  • Replaces Reverb-only checkpoint calls with ReplayBufferManager.save_checkpoint() (see the sketch below)
  • Adds TFUniform fallback checkpoint save/restore support
  • Updates RL tests for fallback compatibility
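
A rough sketch of the checkpoint-call change; `reverb_checkpointer` is illustrative, and the manager signature is inferred from the review snippets further down, so exact details may differ:

# Before (Reverb-only; breaks where dm-reverb cannot be installed):
#   reverb_checkpointer.save()
# After (backend-agnostic; the manager dispatches to Reverb or the
# TFUniform fallback internally):
manager = ReplayBufferManager(
    data_spec=agent.collect_data_spec,
    capacity=50000,
    checkpoint_dir=checkpoint_dir,
    sequence_length=2,
)
manager.save_checkpoint()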


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a comprehensive reinforcement learning (RL) framework, including scripts for agent training, evaluation, and configuration generation. It adds support for DDPG and SAC agents, implements custom observers for trajectory recording and visualization, and provides a robust replay buffer management system with Reverb and TFUniform fallbacks. Feedback focuses on critical logic errors such as redundant training step increments, potential resource leaks from uncleaned temporary directories in the evaluation script, and the presence of dead code and hardcoded configuration values that hinder maintainability.

log_interval=args.log_interval,
checkpoint_interval=args.checkpoint_interval,
learner_iterations=args.learner_iterations,
train_step.assign_add(1)

high

Manually incrementing train_step here is redundant and likely incorrect. agent_learner.run(iterations=self.learner_iterations) already increments the train_step variable by the number of iterations performed. This double-incrementing will cause triggers (like evaluation and saving) to fire much more frequently than intended, as they are based on the global step count.
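
A minimal sketch of the fix, assuming the TF-Agents Learner API (run() advances the shared train_step internally):

loss_info = agent_learner.run(iterations=self.learner_iterations)
# train_step.assign_add(1)  # removed: run() already advanced train_step
current_step = agent_learner.train_step_numpy  # read the updated step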

Comment on lines +43 to +45
self._num_timesteps_in_episode = self._environment.pyenv.envs[
    0
]._num_timesteps_in_episode

medium

The variable self._num_timesteps_in_episode is initialized here using fragile internal access to the environment, but it is never used elsewhere in the class. This code should be removed to improve maintainability and avoid potential crashes if the environment structure changes.

Comment on lines +121 to +129
model_structure_dir = None
if os.path.exists(os.path.join(policy_dir, "greedy_policy")):
  model_structure_dir = os.path.join(policy_dir, "greedy_policy")
  logger.info("Using model structure from greedy_policy directory")
else:
  raise ValueError(
      "No policy structure directories found in"
      f" {os.path.abspath(policy_dir)}"
  )

medium

The logic currently only checks for the greedy_policy directory. As noted in the docstring, it should also support the policy directory as a fallback. If neither is found, then it should raise the ValueError.
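
Something like the following would implement the described fallback (directory names are taken from this comment; check them against the actual docstring):

model_structure_dir = None
for candidate in ("greedy_policy", "policy"):
  candidate_path = os.path.join(policy_dir, candidate)
  if os.path.exists(candidate_path):
    model_structure_dir = candidate_path
    logger.info("Using model structure from %s directory", candidate)
    break
if model_structure_dir is None:
  raise ValueError(
      "No policy structure directories found in"
      f" {os.path.abspath(policy_dir)}"
  )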

Comment on lines +374 to +376
self.temp_saved_model_policy_dirpath = create_merged_saved_model(
    self.saved_model_policy_dirpath
)

medium

The ExperimentEvaluator class creates a temporary directory via create_merged_saved_model but does not provide a mechanism to clean it up. This will lead to an accumulation of temporary directories on disk. Consider adding a cleanup method or implementing the class as a context manager to ensure shutil.rmtree is called on self.temp_saved_model_policy_dirpath.
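
One possible shape for the cleanup, sketched as a context manager; the close/__enter__/__exit__ methods are illustrative additions, not code from the PR:

import shutil

class ExperimentEvaluator:
  ...

  def close(self):
    # Remove the temp dir produced by create_merged_saved_model, if any.
    if getattr(self, "temp_saved_model_policy_dirpath", None):
      shutil.rmtree(self.temp_saved_model_policy_dirpath, ignore_errors=True)
      self.temp_saved_model_policy_dirpath = None

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.close()

Usage then becomes `with ExperimentEvaluator(...) as evaluator: ...`, and the temporary directory is removed on exit even if evaluation raises.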

Comment on lines +522 to +544
def old_main(argv: Sequence[str]):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # handle relative and absolute filepaths:
  config_filepath = FLAGS.eval_config_filepath
  if not os.path.isabs(config_filepath):
    config_filepath = os.path.join(ROOT_DIR, config_filepath)

  policy_dirpath = FLAGS.eval_policy_dirpath
  if (
      policy_dirpath is not None
      and not os.path.isabs(policy_dirpath)
      and policy_dirpath != "schedule"
  ):
    policy_dirpath = os.path.join(ROOT_DIR, policy_dirpath)

  evaluate_policy(
      experiment_name=FLAGS.eval_experiment_name,
      policy_dirpath=policy_dirpath,
      config_filepath=config_filepath,
      num_eval_episodes=FLAGS.num_eval_episodes,
  )

medium

The old_main function appears to be dead code and should be removed to improve maintainability.

if self.agent_type not in ['sac', 'ddpg']:
  raise ValueError(
      'Agent {self.agent_type} has not (yet) been implemented. Please'
      " choose one of: ['sac', 'ddpg']."

medium

The ValueError message is missing the f prefix, so {self.agent_type} will not be interpolated correctly.
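
The one-line fix, with the prefix added:

raise ValueError(
    f'Agent {self.agent_type} has not (yet) been implemented. Please'
    " choose one of: ['sac', 'ddpg']."
)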

Comment on lines +423 to +428
data_spec=self.agent.collect_data_spec,
capacity=50000, # Use default capacity
checkpoint_dir=new_buffer_path, # Use the copied buffer path
sequence_length=2,
# should we keep these defaults, or use the dynamic parameter values?
)

medium

The ReplayBufferManager should be initialized using the class attributes self.buffer_capacity and self.sequence_length instead of hardcoded defaults, to ensure it respects the user-provided configuration.
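
The suggested change, reading from the instance configuration instead of the hardcoded values (attribute names as stated in this comment; verify they exist on the class):

replay_buffer_manager = ReplayBufferManager(
    data_spec=self.agent.collect_data_spec,
    capacity=self.buffer_capacity,          # was: 50000
    checkpoint_dir=new_buffer_path,
    sequence_length=self.sequence_length,   # was: 2
)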

@amcberkes
Contributor Author

I got the RL training running end-to-end on my Mac without dm-reverb.

I switched replay handling to a TFUniform fallback path and added metadata persistence for replay capacity (replay_buffer_metadata.json), which fixed the restore mismatch I was hitting during train startup.
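
A simplified sketch of that metadata round-trip; only the filename and the capacity field come from this comment, and the actual schema in the PR may carry more fields:

import json
import os

METADATA_FILENAME = "replay_buffer_metadata.json"

def save_replay_metadata(checkpoint_dir: str, capacity: int) -> None:
  with open(os.path.join(checkpoint_dir, METADATA_FILENAME), "w") as f:
    json.dump({"capacity": capacity}, f)

def load_replay_metadata(checkpoint_dir: str) -> dict:
  with open(os.path.join(checkpoint_dir, METADATA_FILENAME)) as f:
    return json.load(f)

On restore, the saved capacity can be compared against the configured one before rebuilding the TFUniform buffer, which would avoid the kind of mismatch described above.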

@s2t2
Collaborator

s2t2 commented Apr 27, 2026

@amcberkes awesome!

Looks like there are two file conflicts that need resolution. Let me know once they are resolved and I will provide a review.
