diff --git a/README.md b/README.md
index b5c4263..793330d 100755
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Choose an example below to get started. Each example includes step-by-step instr
 | **[VLM Multi-Turn Math](docs/vlm_geo3k_multiturn.md)** | geometry 3k math problem solving with tool calling | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/r39htm2o?nw=nwuserzhusq20) |
 | **[LLM Gomoku Agent](docs/gomoku_multiturn.md)** | A multi-turn gomoku agent | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/7a7ggkw3?nw=nwuserzhusq20) |
 | **[LLM AlfWorld Agent](docs/alfworld_multiturn.md)** | A multi-turn alfworld agent | [wandb](https://wandb.ai/1125027232/opentinker-public/runs/3jrlolk7?nw=nwuser1125027232) |
+| **[LLM Android World Agent](docs/android_world_multiturn.md)** | A multi-turn AndroidWorld agent | |
 
 ## šŸ“¦ Installation
diff --git a/docs/android_world_multiturn.md b/docs/android_world_multiturn.md
new file mode 100644
index 0000000..b8eeeb2
--- /dev/null
+++ b/docs/android_world_multiturn.md
@@ -0,0 +1,232 @@
+# LLM Android World Agent (AndroidWorld Multi-Turn)
+
+This example demonstrates training a language model to complete tasks in the Android operating system using the AndroidWorld environment.
+
+## Overview
+
+**AndroidWorld** is a dynamic benchmarking environment in which autonomous agents interact with the Android operating system. The agent perceives the screen via a list of UI elements and interacts by performing actions like clicking, typing, and scrolling.
+
+Tasks include:
+- Adding contacts
+- Managing settings
+- Browsing information
+- Sending messages
+- And more...
+
+## Prerequisites
+
+1. Complete the [Installation](../README.md#-installation) steps.
+2. **Environment Setup**: You must install the Android SDK and run an Android Emulator. See the **[Detailed Environment Setup](#detailed-environment-setup)** section below for instructions.
+3. 
Get your IP address: `hostname -I`
+
+## Step 1: Start the Scheduler (Server Side)
+
+```bash
+bash opentinker/scripts/launch_scheduler.sh --scheduler-port <scheduler_port>
+```
+
+## Step 2: Start the AndroidWorld Environment (Server Side)
+
+Before starting the environment server, ensure your Android Emulator is running (see setup below).
+
+```bash
+python -m opentinker.environment.android_world.android_world_server \
+    --port 8092 \
+    --max_steps 50 \
+    --split train
+```
+
+**Server Options:**
+
+- `--port`: Server port (default: 8082; 8092 is recommended to match the client config)
+- `--max_steps`: Max steps per episode (default: 50)
+- `--split`: Dataset split (`train`, `eval_in_distribution`, `eval_out_of_distribution`)
+- `--shards`: Number of parallel server instances (for parallel training)
+
+## Step 3: Run Training
+
+```bash
+python opentinker/client/android_world_rl.py \
+    tokenizer_path=Qwen/Qwen2.5-3B-Instruct \
+    batch_size=4 \
+    val_batch_size=50 \
+    num_steps=1000 \
+    save_freq=20000 \
+    test_freq=10 \
+    scheduler_url=http://<scheduler_ip>:<scheduler_port> \
+    interaction.config.env_port=8092 \
+    interaction.config.env_host=<env_server_ip>
+```
+
+**Training Parameters:**
+
+- `num_steps`: Total training steps (alternative: use `num_epochs`)
+- `batch_size`: Training batch size
+- `val_batch_size`: Validation samples per evaluation
+- `test_freq`: Validation frequency (every N steps)
+- `adv_estimator`: Advantage estimator (`gae`, `grpo`, `grpo_per_step`)
+
+## Reward Structure
+
+| Event            | Reward |
+| :--------------- | ------ |
+| Task Success     | +10.0  |
+| Task Failure     | -1.0   |
+| Per Step Penalty | -0.01  |
+| Invalid Action   | -0.1   |
+
+## Example Actions
+
+The agent interacts with the environment by outputting JSON commands referencing UI element indices:
+
+- **Click**: `{"action_type": "click", "index": 4}`
+- **Type**: `{"action_type": "input_text", "text": "Alice", "index": 2}`
+- **Scroll**: `{"action_type": "scroll", "direction": "down"}`
+- **Open App**: `{"action_type": "open_app", "app_name": 
"Settings"}` +- **Navigate Home**: `{"action_type": "navigate_home"}` +- **Navigate Back**: `{"action_type": "navigate_back"}` +- **Answer Question**: `{"action_type": "answer", "text": "It is 5 PM."}` +- **Finish Task**: `{"action_type": "status", "goal_status": "complete"}` + +## Configuration Reference + +See [`opentinker/client/client_config/android_world_param.yaml`](../opentinker/client/client_config/android_world_param.yaml) for full configuration options. + +--- + +## Detailed Environment Setup + +### 1. Android SDK & Command Line Tools + +If you do not have Android Studio installed, you can set up the command-line tools manually. + +1. **Create Directory Structure:** + ```bash + mkdir -p /usr/local/android-sdk/cmdline-tools + cd /usr/local/android-sdk/cmdline-tools + ``` + +2. **Download Command Line Tools:** + ```bash + wget https://dl.google.com/android/repository/commandlinetools-linux-11076708_latest.zip -O cmdline-tools.zip + unzip cmdline-tools.zip + mv cmdline-tools latest + rm cmdline-tools.zip + ``` + +3. **Install SDK Components:** + ```bash + export ANDROID_HOME=/usr/local/android-sdk + export PATH=$ANDROID_HOME/cmdline-tools/latest/bin:$PATH + + # Accept licenses + yes | sdkmanager --licenses --sdk_root=$ANDROID_HOME + + # Install Platform Tools (adb), Android 33 Platform, and Build Tools + sdkmanager "platform-tools" "platforms;android-33" "build-tools;34.0.0" "emulator" --sdk_root=$ANDROID_HOME + ``` + +4. **Configure Environment Variables:** + Add the following to your shell configuration file (`~/.bashrc` or `~/.zshrc`): + ```bash + export JAVA_HOME="/usr/local/android-studio/jbr" # Or your JDK path + export ANDROID_HOME="/usr/local/android-sdk" + export PATH="$JAVA_HOME/bin:$ANDROID_HOME/cmdline-tools/latest/bin:$ANDROID_HOME/platform-tools:$ANDROID_HOME/emulator:$PATH" + ``` + +### 2. Create Android Virtual Device (AVD) + +Create an AVD named `AndroidWorldAvd` targeting Android 13 (Tiramisu, API 33). + +1. 
**Install System Image:** + * For x86_64 (Standard PC): + ```bash + sdkmanager "system-images;android-33;google_apis;x86_64" --sdk_root=$ANDROID_HOME + ``` + * For ARM64 (Apple Silicon or Software Emulation on x86): + ```bash + sdkmanager "system-images;android-33;google_apis;arm64-v8a" --sdk_root=$ANDROID_HOME + ``` + +2. **Create AVD:** + ```bash + echo "no" | avdmanager create avd --name AndroidWorldAvd --package "system-images;android-33;google_apis;x86_64" --device "pixel_6" + ``` + *(Replace `x86_64` with `arm64-v8a` if applicable)* + +### 3. Launch Emulator + +Start the emulator in a separate terminal or background process using the `sg` command to ensure correct group permissions (e.g., `kvm`). + +* **Standard Launch (with GUI):** + ```bash + sg kvm -c "emulator -avd AndroidWorldAvd -no-snapshot -grpc 8554" + ``` + +* **Headless Launch (Server/Docker):** + ```bash + sg kvm -c "emulator -avd AndroidWorldAvd -no-snapshot -grpc 8554 -no-window -no-audio" + ``` + +* **Software Emulation (No KVM):** + If hardware acceleration is unavailable, add `-accel off`. **Warning: Performance will be very low.** + ```bash + emulator -avd AndroidWorldAvd -no-snapshot -grpc 8554 -no-window -no-audio -accel off + ``` + +## Quick Start with `run_android.sh` + +For multi-emulator parallel training, we provide an all-in-one launcher script [`opentinker/scripts/run_android.sh`](../opentinker/scripts/run_android.sh) that automates AVD creation, emulator startup, environment server, and training client. 
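Step 3 of the launcher flow depends on every emulator having fully booted before the environment server starts. A standard way to detect this is to poll `adb shell getprop sys.boot_completed`, which prints `1` once Android finishes booting. Below is a minimal sketch of such a readiness check; the helper names (`is_booted`, `wait_for_emulators`) and the serial numbers are illustrative, not part of `run_android.sh`:

```python
"""Poll emulators until they report a completed boot (illustrative sketch)."""
import subprocess
import time


def is_booted(getprop_output: str) -> bool:
    # `getprop sys.boot_completed` prints "1" once boot finishes;
    # empty output or anything else means the device is not ready yet.
    return getprop_output.strip() == "1"


def wait_for_emulators(serials, timeout_s=600, poll_s=5):
    """Block until every emulator serial reports sys.boot_completed == 1."""
    deadline = time.time() + timeout_s
    pending = set(serials)
    while pending and time.time() < deadline:
        for serial in sorted(pending):
            try:
                out = subprocess.run(
                    ["adb", "-s", serial, "shell", "getprop", "sys.boot_completed"],
                    capture_output=True, text=True, timeout=10,
                ).stdout
            except (subprocess.TimeoutExpired, OSError):
                out = ""  # adb missing or device unresponsive: treat as not booted
            if is_booted(out):
                pending.discard(serial)
        if pending:
            time.sleep(poll_s)
    if pending:
        raise TimeoutError(f"Emulators not booted in time: {sorted(pending)}")
```

For two parallel emulators one would call `wait_for_emulators(["emulator-5554", "emulator-5556"])` before launching the environment server.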
+ +### Usage + +Run each step in a **separate terminal**: + +```bash +# Step 0 (one-time): Create N AVDs for parallel training +bash opentinker/scripts/run_android.sh setup-avds + +# Step 1: Start the scheduler +bash opentinker/scripts/run_android.sh scheduler + +# Step 2: Start N Android emulators in parallel +bash opentinker/scripts/run_android.sh simulator + +# Step 3: Start the sharded environment server (after emulators fully boot) +bash opentinker/scripts/run_android.sh env + +# Step 4: Launch RL training +bash opentinker/scripts/run_android.sh client +``` + +### Environment Variables + +All settings are configurable via environment variables: + +| Variable | Default | Description | +| :------- | :------ | :---------- | +| `NUM_EMULATORS` | `4` | Number of parallel emulators | +| `NUM_GPUS` | `4` | Number of GPUs for model parallelism | +| `GPUS` | `[0,1,2,3]` | GPU device list | +| `MODEL_PATH` | `Qwen/Qwen2.5-3B-Instruct` | Model path or HuggingFace ID | +| `AVD_NAME` | `AndroidWorldAvd` | AVD name prefix (creates `{AVD_NAME}_0`, `{AVD_NAME}_1`, ...) | +| `EMULATOR_HEADLESS` | `1` | Set `0` to show emulator GUI | +| `EMULATOR_NO_KVM` | `0` | Set `1` for software emulation (slow) | +| `SCHEDULER_PORT` | `9780` | Scheduler listen port | +| `ENV_PORT` | `9092` | Environment server base port | + +**Example** — scale to 8 emulators on 8 GPUs: + +```bash +NUM_EMULATORS=8 NUM_GPUS=8 GPUS="[0,1,2,3,4,5,6,7]" bash opentinker/scripts/run_android.sh setup-avds +# Then run scheduler / simulator / env / client with the same env vars +``` + +--- + +## Troubleshooting + +* **"KVM is not found"**: Ensure virtualization is enabled in your BIOS/Hypervisor. On Linux, check permissions for `/dev/kvm`. If in a container, run with `--device /dev/kvm`. +* **Emulator crashes immediately**: Check logs. If running x86_64 image on ARM or vice-versa, the emulator will fail. Use the correct system image for your host architecture. 
+* **"ADB command not found"**: Ensure `platform-tools` is in your `$PATH`. +* **"Process system isn't responding"**: Common in software emulation (`-accel off`). Wait for the system to stabilize or dismiss the dialog. \ No newline at end of file diff --git a/opentinker/backend_patch/verl/experimental/__init__.py b/opentinker/backend_patch/verl/experimental/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/opentinker/backend_patch/verl/experimental/agent_loop/__init__.py b/opentinker/backend_patch/verl/experimental/agent_loop/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/opentinker/backend_patch/verl/experimental/agent_loop/per_turn_agent_loop.py b/opentinker/backend_patch/verl/experimental/agent_loop/per_turn_agent_loop.py new file mode 100644 index 0000000..63aeb69 --- /dev/null +++ b/opentinker/backend_patch/verl/experimental/agent_loop/per_turn_agent_loop.py @@ -0,0 +1,457 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Per-Turn Agent Loop Worker and Manager. + +This module provides patched versions of AgentLoopWorkerBase and AgentLoopManager +that support expanding multi-turn rollout outputs into individual per-turn training +samples. This avoids context length issues from concatenating all turns into one +long sequence and aligns training context with inference context. 
+ +When an agent loop (e.g., AndroidAgentLoop) stores per-turn data in +extra_fields['per_turn_outputs'], this worker expands each turn into +a separate training sample with its own prompt, response, mask, and reward. +""" + +import asyncio +import logging +import os +from typing import Any, Optional + +import hydra +import numpy as np +import ray +import torch +from omegaconf import DictConfig +from tensordict import TensorDict + +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopBase, + AgentLoopManager, + AgentLoopOutput, + AgentLoopWorkerBase, + AsyncLLMServerManager, + _DummyConfig, + _InternalAgentLoopOutput, + _agent_loop_registry, + get_trajectory_info, +) +from verl.protocol import DataProto +from verl.utils.model import compute_position_id_with_mask +from verl.utils.rollout_trace import rollout_trace_attr +from verl.utils.transferqueue_utils import tqbridge + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class PerTurnAgentLoopWorkerBase(AgentLoopWorkerBase): + """Agent loop worker that expands per-turn outputs into individual training samples. + + When the agent loop returns per-turn data in extra_fields['per_turn_outputs'], + each turn is converted into a separate training sample with its own prompt, + response, mask, and reward. For agent loops that don't produce per-turn outputs, + the behavior is identical to the base AgentLoopWorkerBase. + + This solves two problems: + 1. Context length: Concatenated multi-turn sequences quickly exceed the model's + context window, especially for environments with long observations. + 2. Train/inference mismatch: During rollout, AndroidAgentLoop generates each turn + using only [system + latest observation] as context. But the concatenated + training sequence exposes the model to ALL previous turns, creating a + distribution mismatch. Per-turn training aligns training and inference contexts. 
+ """ + + async def _pad_single_output( + self, + output: AgentLoopOutput, + ) -> _InternalAgentLoopOutput: + """Pad and convert a single AgentLoopOutput to _InternalAgentLoopOutput. + + This extracts the padding/conversion logic from the parent's _run_agent_loop + so it can be reused for each per-turn output. + """ + prompt_length = self.config.actor_rollout_ref.rollout.prompt_length + response_length = self.config.actor_rollout_ref.rollout.response_length + + # Truncate prompt_ids if they exceed prompt_length (left-truncate to keep recent context) + prompt_ids = output.prompt_ids + if len(prompt_ids) > prompt_length: + logger.warning( + f"[PerTurnAgentLoop] Truncating per-turn prompt from {len(prompt_ids)} to {prompt_length} tokens" + ) + prompt_ids = prompt_ids[-prompt_length:] + + # Truncate response_ids if they exceed response_length + response_ids = output.response_ids[:response_length] + response_mask = output.response_mask[:response_length] + + # Left-pad prompt + self.tokenizer.padding_side = "left" + prompt_output = self.tokenizer.pad( + {"input_ids": prompt_ids}, + padding="max_length", + max_length=prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + if prompt_output["input_ids"].dim() == 1: + prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0) + prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0) + + # Right-pad response + self.tokenizer.padding_side = "right" + response_output = self.tokenizer.pad( + {"input_ids": response_ids}, + padding="max_length", + max_length=response_length, + return_tensors="pt", + return_attention_mask=True, + ) + if response_output["input_ids"].dim() == 1: + response_output["input_ids"] = response_output["input_ids"].unsqueeze(0) + response_output["attention_mask"] = response_output["attention_mask"].unsqueeze(0) + + # Pad response mask + response_mask_output = self.tokenizer.pad( + {"input_ids": response_mask}, + padding="max_length", + 
max_length=response_length, + return_tensors="pt", + return_attention_mask=False, + ) + if response_mask_output["input_ids"].dim() == 1: + response_mask_output["input_ids"] = response_mask_output["input_ids"].unsqueeze(0) + + # Pad logprobs + response_logprobs_tensor = None + if output.response_logprobs is not None: + logprobs = output.response_logprobs[:response_length] + pad_size = response_length - len(logprobs) + response_logprobs_tensor = torch.tensor(logprobs + [0.0] * pad_size).unsqueeze(0) + + response_mask_final = response_mask_output["input_ids"] * response_output["attention_mask"] + attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1) + input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1) + + # Handle multi-modal inputs and position_ids calculation + multi_modal_inputs = None + if ( + self.processor is not None + and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__ + ): + from verl.models.transformers.qwen2_vl import get_rope_index + + images = getattr(output, "multi_modal_data", {}).get("image", None) + current_text = self.tokenizer.decode(input_ids.squeeze(0), skip_special_tokens=True) + multi_modal_inputs = self.processor(text=[current_text], images=images, return_tensors="pt") + multi_modal_inputs.pop("input_ids", None) + multi_modal_inputs.pop("attention_mask", None) + multi_modal_inputs = dict(multi_modal_inputs) + + image_grid_thw = multi_modal_inputs.get("image_grid_thw") + video_grid_thw = multi_modal_inputs.get("video_grid_thw") + second_per_grid_ts = multi_modal_inputs.get("second_per_grid_ts") + + vision_position_ids = get_rope_index( + self.processor, + input_ids=input_ids.squeeze(0), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask.squeeze(0), + ).unsqueeze(0) + + valid_mask = attention_mask[0].bool() + text_position_ids = torch.ones((1, 
len(input_ids[0])), dtype=torch.long) + text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item()) + text_position_ids = text_position_ids.unsqueeze(0) + position_ids = torch.cat((text_position_ids, vision_position_ids), dim=1) + else: + position_ids = compute_position_id_with_mask(attention_mask) + + return _InternalAgentLoopOutput( + prompt_ids=prompt_output["input_ids"], + response_ids=response_output["input_ids"], + input_ids=input_ids, + position_ids=position_ids, + response_mask=response_mask_final, + attention_mask=attention_mask, + response_logprobs=response_logprobs_tensor, + multi_modal_inputs=multi_modal_inputs, + multi_modal_data=output.multi_modal_data, + reward_score=output.reward_score, + num_turns=output.num_turns, + metrics=output.metrics, + extra_fields=output.extra_fields, + ) + + async def _run_agent_loop( + self, + sampling_params: dict[str, Any], + trajectory: dict[str, Any], + *, + agent_name: str, + **kwargs, + ) -> _InternalAgentLoopOutput | list[_InternalAgentLoopOutput]: + """Run agent loop and optionally expand per-turn outputs. + + If the agent loop produces per-turn outputs (via extra_fields['per_turn_outputs']), + each turn is converted into a separate _InternalAgentLoopOutput. Otherwise, + the behavior is identical to the base class. + + Returns: + Single _InternalAgentLoopOutput or list of _InternalAgentLoopOutput (one per turn). 
+ """ + with rollout_trace_attr( + step=trajectory["step"], + sample_index=trajectory["sample_index"], + rollout_n=trajectory["rollout_n"], + validate=trajectory["validate"], + name="agent_loop", + ): + assert agent_name in _agent_loop_registry, ( + f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" + ) + + agent_loop_config = _agent_loop_registry[agent_name] + agent_loop = hydra.utils.instantiate( + config=agent_loop_config, + trainer_config=_DummyConfig(config=self.config), + server_manager=self.server_manager, + tokenizer=self.tokenizer, + processor=self.processor, + ) + output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs) + + # Check for per-turn outputs from agent loops that support per-turn training + per_turn_outputs = output.extra_fields.pop('per_turn_outputs', None) + + if per_turn_outputs and len(per_turn_outputs) > 0: + # Expand per-turn outputs into separate training samples + results = [] + for turn_output in per_turn_outputs: + internal = await self._pad_single_output(turn_output) + results.append(internal) + + logger.info( + f"[PerTurnAgentLoop] Expanded 1 episode into {len(results)} per-turn training samples" + ) + return results + else: + # Standard single output processing (same as parent class) + enable_async_reward = ( + self.reward_router_address is not None and self.config.reward_model.enable_resource_pool + ) or not self.config.reward_model.enable + + internal = await self._pad_single_output(output) + + if output.reward_score is None and enable_async_reward: + batch = TensorDict( + { + "prompts": internal.prompt_ids, + "responses": internal.response_ids, + "attention_mask": internal.attention_mask, + "input_ids": internal.input_ids, + "position_ids": internal.position_ids, + }, + batch_size=1, + ) + non_tensor_batch = { + **{k: np.array([v]) for k, v in kwargs.items()}, + "__num_turns__": np.array([output.num_turns]), + "tool_extra_fields": np.array([output.extra_fields], 
dtype=object), + } + + data = DataProto( + batch=batch, + non_tensor_batch=non_tensor_batch, + ) + result = await self.reward_manager_worker.compute_score.remote(data) + output.reward_score = result["reward_score"] + output.extra_fields["reward_extra_info"] = result["reward_extra_info"] + internal.reward_score = output.reward_score + internal.extra_fields = output.extra_fields + + return internal + + @tqbridge() + async def generate_sequences(self, batch: DataProto) -> DataProto: + """Generate sequences with per-turn expansion support. + + This overrides the parent's generate_sequences to flatten per-turn outputs + from _run_agent_loop before passing them to _postprocess. When _run_agent_loop + returns a list (per-turn mode), the lists are flattened into a single list of + training samples. The effective batch size may increase (e.g., 4 episodes * 15 + turns/episode = 60 training samples). + + Args: + batch (DataProto): Input batch. + + Returns: + DataProto: Output batch with per-turn expanded samples. 
+ """ + config = self.config.actor_rollout_ref.rollout + sampling_params = dict( + temperature=config.temperature, + top_p=config.top_p, + repetition_penalty=1.0, + logprobs=config.calculate_log_probs, + ) + + # Override sampling params for validation + if batch.meta_info.get("validate", False): + sampling_params["top_p"] = config.val_kwargs.top_p + sampling_params["temperature"] = config.val_kwargs.temperature + + # By default, assume single turn agent + if "agent_name" not in batch.non_tensor_batch: + default_agent_loop = config.agent.default_agent_loop + batch.non_tensor_batch["agent_name"] = np.array([default_agent_loop] * len(batch), dtype=object) + + if "index" in batch.non_tensor_batch: + index = batch.non_tensor_batch["index"] + else: + index = np.arange(len(batch)) + + trajectory_info = await get_trajectory_info( + batch.meta_info.get("global_steps", -1), index.tolist(), batch.meta_info.get("validate", False) + ) + + tasks = [] + for i in range(len(batch)): + kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()} + tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs))) + outputs = await asyncio.gather(*tasks) + + # Flatten per-turn outputs: _run_agent_loop may return a list (per-turn) or single output + # Also build an expansion index that maps each flat output back to its source episode. + # This allows the training server to expand the original batch (with reward fields etc.) + # to match the expanded batch size. + flat_outputs = [] + expansion_index = [] # expansion_index[j] = i means flat output j came from episode i + for i, o in enumerate(outputs): + if isinstance(o, list): + flat_outputs.extend(o) + expansion_index.extend([i] * len(o)) + else: + flat_outputs.append(o) + expansion_index.append(i) + + output = self._postprocess(flat_outputs) + + if len(flat_outputs) != len(outputs): + # Per-turn expansion happened: store the index so the training server + # can expand the original batch to match. 
+ output.meta_info['per_turn_expansion_index'] = expansion_index + logger.info( + f"[PerTurnAgentLoop] Expanded {len(outputs)} episodes into {len(flat_outputs)} per-turn training samples" + ) + + return output + + +@ray.remote +class PerTurnAgentLoopWorker(PerTurnAgentLoopWorkerBase): + """Ray actor wrapper for PerTurnAgentLoopWorkerBase.""" + + def __init__( + self, + config: DictConfig, + server_handles: list[ray.actor.ActorHandle], + reward_router_address: str = None, + ): + super().__init__(config, server_handles, reward_router_address) + + +class PerTurnAgentLoopManager(AgentLoopManager): + """Agent loop manager that uses PerTurnAgentLoopWorker for per-turn training support. + + Drop-in replacement for AgentLoopManager. The only difference is that this manager + creates PerTurnAgentLoopWorker instances instead of AgentLoopWorker instances, + enabling per-turn output expansion when agent loops produce per-turn data. + + When per_turn_training is disabled in the config (or the agent loop doesn't produce + per-turn outputs), the behavior is identical to the standard AgentLoopManager. + """ + + def __init__(self, config: DictConfig, worker_group=None, rm_wg=None): + # Set the worker class BEFORE calling super().__init__(), which calls _init_agent_loop_workers + self.agent_loop_workers_class = PerTurnAgentLoopWorker + super().__init__(config, worker_group, rm_wg) + + def generate_sequences(self, prompts: DataProto) -> DataProto: + """Override to handle per-turn expansion indices across workers. + + Each worker may produce per_turn_expansion_index in meta_info, but + DataProto.concat() requires matching meta_info values. We extract the + per-worker expansion indices, adjust them to global offsets, remove + them from meta_info before concat, and attach the combined global + index to the final output. 
+ """ + if self.config.actor_rollout_ref.rollout.free_cache_engine: + self.wake_up() + if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine: + self.reward_model_manager.wake_up() + + chunks = prompts.chunk(len(self.agent_loop_workers)) + outputs = ray.get( + [ + worker.generate_sequences.remote(chunk) + for worker, chunk in zip(self.agent_loop_workers, chunks, strict=True) + ] + ) + + # Collect per-turn expansion indices from each worker and build a + # global expansion index before concat (which would fail on conflicting + # meta_info values). + chunk_sizes = [len(c) for c in chunks] + has_expansion = any('per_turn_expansion_index' in o.meta_info for o in outputs) + global_expansion_index = None + + if has_expansion: + global_expansion_index = [] + offset = 0 + for o, cs in zip(outputs, chunk_sizes): + local_idx = o.meta_info.pop('per_turn_expansion_index', None) + if local_idx is not None: + global_expansion_index.extend([idx + offset for idx in local_idx]) + else: + # No expansion for this worker — identity mapping + global_expansion_index.extend(range(offset, offset + len(o))) + offset += cs + else: + for o in outputs: + o.meta_info.pop('per_turn_expansion_index', None) + + output = DataProto.concat(outputs) + + if self.config.actor_rollout_ref.rollout.free_cache_engine: + self.sleep() + if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine: + self.reward_model_manager.sleep() + + # Calculate performance metrics (same as parent) + metrics = [output.meta_info.pop("metrics") for output in outputs] + timing = self._performance_metrics(metrics, output) + + output.meta_info = {"timing": timing, **outputs[0].meta_info} + + if global_expansion_index is not None: + output.meta_info['per_turn_expansion_index'] = global_expansion_index + logger.info( + f"[PerTurnAgentLoop] Global expansion: {sum(chunk_sizes)} episodes -> " + f"{len(global_expansion_index)} per-turn training samples" + ) + + return output diff 
--git a/opentinker/backend_patch/verl/trainer/ppo/ray_trainer.py b/opentinker/backend_patch/verl/trainer/ppo/ray_trainer.py index 9f355ee..23c4a8f 100755 --- a/opentinker/backend_patch/verl/trainer/ppo/ray_trainer.py +++ b/opentinker/backend_patch/verl/trainer/ppo/ray_trainer.py @@ -982,10 +982,15 @@ def init_workers(self): # create async rollout manager and request scheduler self.async_rollout_mode = False if self.config.actor_rollout_ref.rollout.mode == "async": - from verl.experimental.agent_loop import AgentLoopManager + # Use PerTurnAgentLoopManager which supports expanding multi-turn rollouts + # into individual per-turn training samples (when per_turn_training=True). + # Falls back to standard behavior when per_turn_training is disabled. + from opentinker.backend_patch.verl.experimental.agent_loop.per_turn_agent_loop import ( + PerTurnAgentLoopManager, + ) self.async_rollout_mode = True - self.async_rollout_manager = AgentLoopManager( + self.async_rollout_manager = PerTurnAgentLoopManager( config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg ) diff --git a/opentinker/backend_patch/verl/workers/config/rollout.py b/opentinker/backend_patch/verl/workers/config/rollout.py index 54434aa..4928eb0 100755 --- a/opentinker/backend_patch/verl/workers/config/rollout.py +++ b/opentinker/backend_patch/verl/workers/config/rollout.py @@ -48,6 +48,8 @@ class MultiTurnConfig(BaseConfig): "max_tokens_per_turn", "weave_project", "experiment_name", + "per_turn_training", + "per_turn_reward_gamma", } enable: bool = False @@ -68,6 +70,17 @@ class MultiTurnConfig(BaseConfig): # Per-turn token limit (optional, None = no limit) max_tokens_per_turn: Optional[int] = None + # Per-turn training mode: each interaction turn becomes a separate training sample + # instead of concatenating all turns into one long sequence. 
This avoids context
+    # length limits and aligns training context with inference context (each turn only
+    # sees system + latest observation, matching the generation-time context).
+    per_turn_training: bool = False
+
+    # Per-turn reward gamma: when > 0, use discounted cumulative returns instead of
+    # immediate rewards for per-turn training. Each turn's training reward becomes
+    # G_i = r_i + gamma * G_{i+1}, propagating final outcome signal backwards.
+    per_turn_reward_gamma: float = 0.0
+
     # Weave tracing (server-side)
     weave_project: Optional[str] = None
     experiment_name: Optional[str] = None
diff --git a/opentinker/client/android_world_rl.py b/opentinker/client/android_world_rl.py
new file mode 100644
index 0000000..122b02f
--- /dev/null
+++ b/opentinker/client/android_world_rl.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+"""AndroidWorld RL Training Client.
+
+This script trains an LLM agent to complete tasks in AndroidWorld.
+
+Usage:
+    # Start AndroidWorld server first (in another terminal):
+    python -m opentinker.environment.android_world.android_world_server --port 8082
+
+    # Run training:
+    python android_world_rl.py scheduler_url=http://localhost:8780 num_gpus=2
+"""
+
+from omegaconf import OmegaConf
+import hydra
+
+from utils.http_training_client import ServiceClient, SchedulerClient
+from opentinker.environment.base_game_environment import GameEnvironment
+from opentinker.environment.android_world import AndroidWorldGame
+from opentinker.environment.game_stats_client import GameStatsClient
+from utils.utils import resolve_paths_in_config
+from utils.scheduler_client_lifecycle import get_lifecycle_manager
+
+
+@hydra.main(config_path="client_config", config_name="android_world_param.yaml")
+def main(args):
+    # Resolve paths to support both absolute and relative paths
+    args = resolve_paths_in_config(args)
+
+    # Get the lifecycle manager (this automatically enables cleanup handlers)
+    lifecycle = get_lifecycle_manager()
+
+    # Initialize Weave tracing (optional)
+    enable_tracing = args.get("enable_tracing", False)
+    if enable_tracing:
+        try:
+            from opentinker.utils.rollout_trace_saver import init_weave_tracing
+
+            weave_project = args.get("weave_project", "android-world-training")
+            init_weave_tracing(
+                project_name=weave_project,
+                experiment_name=args.experiment_name,
+                token2text=True,
+            )
+        except Exception as e:
+            print(f"⚠ Failed to initialize Weave tracing: {e}")
+
+    print("=" * 60)
+    print("Training with AndroidWorld Environment")
+    print("=" * 60)
+
+    # 1. Connect to scheduler and submit job
+    scheduler_url = args.get("scheduler_url", "http://localhost:8780")
+    scheduler_api_key = args.get("scheduler_api_key", None)
+
+    print(f"\nConnecting to scheduler at {scheduler_url}")
+    if scheduler_api_key:
+        print("āœ“ Using API key for authentication")
+    else:
+        print(
+            "⚠ No API key provided - authentication may fail if scheduler requires it"
+        )
+
+    scheduler_client = SchedulerClient(
+        scheduler_url=scheduler_url, api_key=scheduler_api_key
+    )
+
+    # Submit job with configuration
+    print("\nSubmitting training job to scheduler...")
+    job_result = scheduler_client.submit_job(
+        config=OmegaConf.to_container(args, resolve=True),
+        enable_agent_loop=True,  # REQUIRED for GenericAgentLoop
+        wandb_key=args.get("wandb_key"),
+        num_gpus=args.get("num_gpus"),
+    )
+
+    job_id = job_result["job_id"]
+    server_url = job_result["server_url"]
+
+    # Register job for automatic cleanup
+    lifecycle.register_job(scheduler_client, job_id)
+
+    print(f"\nāœ“ Job {job_id} allocated!")
+    print(f"  Server URL: {server_url}")
+    print(f"  GPUs: {job_result.get('gpu_ids')}")
+    print(f"  Port: {job_result.get('port')}")
+    print("=" * 60)
+
+    # 2. Setup GameEnvironment with AndroidWorldGame
+    interaction_config = args.interaction.config
+    game_kwargs = {
+        "max_steps": interaction_config.get("max_total_steps", 50),
+        "split": interaction_config.get("split", "train"),
+    }
+
+    env_endpoint = interaction_config.env_endpoint
+
+    print("\nSetting up GameEnvironment with AndroidWorldGame...")
+    print(f"  Environment endpoint: {env_endpoint}")
+    print(f"  Max steps: {game_kwargs['max_steps']}")
+    print(f"  Split: {game_kwargs['split']}")
+    print(f"  Job ID for stats: {job_id}")
+
+    env = GameEnvironment(
+        game_class=AndroidWorldGame,
+        config=args,
+        game_kwargs=game_kwargs,
+        job_id=job_id,  # Pass job_id directly
+    )
+
+    print("āœ“ Environment created")
+    print(f"  Interaction config path: {env.get_interaction_config_path()}")
+
+    # 3. Setup GameStatsClient for per-step metrics
+    game_stats = GameStatsClient(env_endpoint, job_id=env.job_id)
+    if game_stats.health_check():
+        print(f"āœ“ Connected to AndroidWorld server for metrics at {env_endpoint}")
+        game_stats.reset_all()  # Reset all stats before training
+    else:
+        print(f"⚠ AndroidWorld server at {env_endpoint} not responding - metrics disabled")
+        game_stats = None
+
+    # 4. Connect to allocated server
+    print(f"\nConnecting to allocated server at {server_url}")
+    client = ServiceClient(
+        server_url=server_url,
+        project_name=args.project_name,
+        experiment_name=args.experiment_name,
+        logger_backends=args.logger_backends,
+    )
+
+    # Set configuration on server
+    client.set_config(args, env)
+
+    # 5. Train with game stats tracking
+    num_steps = args.get("num_steps", None)
+    num_epochs = args.get("num_epochs", None)
+
+    if num_steps:
+        print(f"\nStarting training for {num_steps} steps...")
+    elif num_epochs:
+        print(f"\nStarting training for {num_epochs} epochs...")
+    else:
+        print("\nStarting training (1 epoch default)...")
+
+    print(f"Checkpoint save frequency: {args.save_freq}")
+    print(f"Validation frequency: {args.test_freq}")
+    print("=" * 60)
+
+    try:
+        # Train with game stats tracking
+        final_metrics = client.fit(
+            env=env,
+            num_epochs=num_epochs,
+            num_steps=num_steps,
+            save_freq=args.save_freq,
+            test_freq=args.test_freq,
+            verbose=True,
+            validate_before_training=True,
+            game_stats_client=game_stats,
+        )
+
+        print("\n" + "=" * 60)
+        print("Training completed!")
+        print(f"Final training metrics: {final_metrics}")
+
+        # Display final cumulative game stats
+        if game_stats:
+            print("\n" + "-" * 40)
+            print("Final Game Statistics:")
+            cumulative = game_stats.get_all_stats()
+            if cumulative:
+                print(f"  Total episodes: {cumulative.get('total_games', 0):.0f}")
+                print(f"  Success rate: {cumulative.get('cumulative_win_rate', 0):.1%}")
+                print(f"  Total successes: {cumulative.get('total_wins', 0):.0f}")
+                print(f"  Total failures: {cumulative.get('total_losses', 0):.0f}")
+        print("=" * 60)
+
+    finally:
+        # Clean up temporary files
+        env.cleanup()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/opentinker/client/client_config/android_world_param.yaml b/opentinker/client/client_config/android_world_param.yaml
new file mode 100644
index 0000000..7d7cb7a
--- /dev/null
+++ b/opentinker/client/client_config/android_world_param.yaml
@@ -0,0 +1,89 @@
+# AndroidWorld Training Configuration
+# Use with: python android_world_rl.py

+# Project settings
+project_name: opentinker
+experiment_name: android_world_ppo_training
+
+# Logging
+logger_backends: ["console", "wandb"]
+
+# Tracing (optional)
+enable_tracing: true
+weave_project: null
+
+# WandB (optional)
+wandb_key: null
+
+# Model and tokenizer
+tokenizer_path: null
+
+# Training parameters
+batch_size: 4
+num_workers: 4
+# Training duration - set ONE of these (num_steps takes precedence if both set)
+num_epochs: null  # Number of epochs (null = use num_steps)
+num_steps: 1000   # Total training steps (null = use num_epochs)
+save_freq: 20000
+test_freq: 10     # Validation frequency (every N steps)
+
+# Validation parameters
+val_batch_size: 4  # Total validation samples (matches batch_size for 4 emulators)
+
+# Generation parameters
+temperature: 1  # Sampling temperature (lower values give more focused responses)
+top_p: 1
+max_new_tokens: 8192  # TOTAL response budget for entire multi-turn trajectory (NOT per-turn!)
+max_prompt_tokens: 4096
+
+# Algorithm (must be agent_loop for multi-turn)
+algorithm: "agent_loop"
+
+# RL Algorithm settings (passed to server via scheduler)
+# adv_estimator options:
+#   - "grpo"          : Standard GRPO (outcome-only advantage, requires rollout_n > 1)
+#   - "grpo_per_step" : Per-step GRPO with return-based advantages (for multi-turn tasks)
+#   - "gae"           : Generalized Advantage Estimation (for PPO, works with rollout_n = 1)
+# NOTE: For AndroidWorld with a single emulator, use "gae" (PPO) since GRPO requires
+# multiple rollouts, which would conflict on the single emulator.
+adv_estimator: "gae"
+# rollout_n: For PPO (gae), rollout_n is always 1 (forced by server)
+rollout_n: 1
+
+# Server-side agent loop workers (should match number of emulators)
+agent_num_workers: 4
+
+# Interaction configuration
+interaction:
+  name: android_world
+  class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction
+  config:
+    env_host: 0.0.0.0
+    env_port: 8092
+    env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port}
+    # Number of env server shards (should match --shards N and number of emulators)
+    env_shards: 4
+    # Bind each worker to a specific endpoint (1-to-1 worker <-> endpoint)
+    # Requires agent_num_workers == env_shards. Each worker gets its own emulator.
+    bind_worker_to_endpoint: true
+    max_steps: 20        # AndroidWorld episodes max steps
+    max_total_steps: 20  # Max environment step calls (controls rollout turns)
+    observation_template: "{observation}"
+    # AndroidWorld specific settings
+    split: train  # train, eval_in_distribution, eval_out_of_distribution
+
+multi_turn:
+  max_user_turns: ${interaction.config.max_total_steps}
+  max_assistant_turns: ${interaction.config.max_total_steps}
+  max_tokens_per_turn: 512  # Per-turn response limit (optional, null for no limit)
+  # Weave tracing (optional - runs on SERVER side)
+  weave_project: null
+  experiment_name: "android_world_interaction"
+
+# Scheduler settings
+scheduler_url: "http://0.0.0.0:8780"
+scheduler_api_key: null
+
+# GPU settings
+num_gpus: 4
diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py
old mode 100755
new mode 100644
index 58c5fd4..08f65e9
--- a/opentinker/client/utils/http_training_client.py
+++ b/opentinker/client/utils/http_training_client.py
@@ -638,6 +638,19 @@ def set_config(self, args: DictConfig, env=None):
                 f"[ServiceClient] Passing multi_turn config to server: {multi_turn_cfg}"
             )
 
+        # Override agent num_workers if specified (important for single-resource envs like AndroidWorld)
+        agent_num_workers = getattr(args, "agent_num_workers", None)
+        if agent_num_workers is not None:
+            server_cfg = OmegaConf.merge(
+                server_cfg,
+                OmegaConf.create(
+                    {"actor_rollout_ref": {"rollout": {"agent": {"num_workers": agent_num_workers}}}}
+                ),
+            )
+            print(
+                f"[ServiceClient] Overriding agent num_workers to: {agent_num_workers}"
+            )
+
         generation_config = {
             "temperature": args.temperature,
             "top_p": args.top_p,
diff --git a/opentinker/environment/android_world/__init__.py b/opentinker/environment/android_world/__init__.py
new file mode 100644
index 0000000..9e04be1
--- /dev/null
+++ b/opentinker/environment/android_world/__init__.py
@@ -0,0 +1 @@
+from .android_world_game import AndroidWorldGame
diff --git a/opentinker/environment/android_world/android_world_game.py b/opentinker/environment/android_world/android_world_game.py
new file mode 100644
index 0000000..51b30d4
--- /dev/null
+++ b/opentinker/environment/android_world/android_world_game.py
@@ -0,0 +1,494 @@
+import os
+import sys
+import random
+import re
+import threading
+import json
+import ast
+from typing import Any, Dict, List, Optional
+import logging
+
+from opentinker.environment.base_game import AbstractGame, StepResult
+from opentinker.environment.android_world import prompts
+
+logger = logging.getLogger(__name__)
+
+# Ensure android_world is in path if it's in the current directory
+if os.path.exists("android_world"):
+    sys.path.append(os.path.abspath("android_world"))
+
+# Android emulator connection defaults (overridable via env vars)
+os.environ.setdefault("ADB_PATH", "adb")
+os.environ.setdefault("ANDROID_CONSOLE_PORT", "5556")
+os.environ.setdefault("ANDROID_GRPC_PORT", "8554")
+
+# AndroidWorld imports
+try:
+    import android_world
+    from android_world import registry
+    from android_world.env import env_launcher
+    from android_world.env import interface
+    from android_world.env import json_action
+    from android_world.env import representation_utils
+    from android_world.agents import m3a_utils
+    ANDROID_WORLD_AVAILABLE = True
+except ImportError:
+    ANDROID_WORLD_AVAILABLE = False
+    logger.warning("android_world not installed or not found.")
+
+
+class AndroidWorldGame(AbstractGame):
+    """AndroidWorld environment game implementation.
+
+    This implementation wraps the AndroidWorld environment for LLM RL training.
+    It adopts the T3A agent's prompting and observation style.
+
+    NOTE: No global emulator lock is taken here. With PPO (rollout_n=1),
+    reset() and step() operations on each emulator are naturally serialized,
+    and the env server's asyncio lock handles any remaining concurrency at
+    the HTTP level.
+    """
+
+    # Reward constants
+    REWARD_SUCCESS = 10.0
+    REWARD_FAILURE = -1.0
+    REWARD_STEP = -0.01
+    REWARD_INVALID_ACTION = -0.1
+
+    # Limits
+    DEFAULT_MAX_STEPS = 20
+
+    # Task types in AndroidWorld
+    ALL_TASK_TYPES = [
+        "ContactsAddContact",
+    ]
+
+    _shared_envs: dict = {}
+    _use_shared_env: bool = False
+
+    _cached_game_paths: Dict[str, List[str]] = {}
+    _cache_lock = threading.Lock()
+
+    def __init__(
+        self,
+        config_path: Optional[str] = None,
+        max_steps: int = DEFAULT_MAX_STEPS,
+        task_types: Optional[List[str]] = None,
+        split: str = "train",
+        num_games: int = -1,
+        use_shared_env: bool = False,
+        emulator_console_port: Optional[int] = None,
+        emulator_grpc_port: Optional[int] = None,
+    ):
+        """Initialize AndroidWorld game.
+
+        Args:
+            emulator_console_port: Emulator console port (default: from ANDROID_CONSOLE_PORT env or 5556)
+            emulator_grpc_port: Emulator gRPC port (default: from ANDROID_GRPC_PORT env or 8554)
+        """
+        self.config_path = config_path
+        self.max_steps = max_steps
+        self.task_types = task_types or self.ALL_TASK_TYPES
+        self.split = split
+        self.num_games = num_games
+        self._use_shared_env = use_shared_env
+
+        # Emulator ports (can be set per-shard for multi-emulator support)
+        self._emulator_console_port = emulator_console_port
+        self._emulator_grpc_port = emulator_grpc_port
+
+        # Game state
+        self._env: Optional[interface.AsyncEnv] = None
+        self._task = None
+        self._current_obs = None
+        self._step_count = 0
+        self._task_desc = ""
+        self._done = False
+        self._initialized = False
+        self._history = []  # List of step summaries for T3A prompt
+
+    def _init_env(self):
+        """Initialize AndroidWorld environment."""
+        if self._initialized:
+            return
+
+        if ANDROID_WORLD_AVAILABLE:
+            adb_path = os.environ.get("ADB_PATH", None)
+            # Use instance port if set, otherwise fall back to env var
+            console_port = self._emulator_console_port or int(os.environ.get("ANDROID_CONSOLE_PORT", 5556))
+            grpc_port = self._emulator_grpc_port or int(os.environ.get("ANDROID_GRPC_PORT", 8554))
+
+            if not adb_path:
+                try:
+                    from android_world.env import android_world_controller
+                    adb_path = android_world_controller.DEFAULT_ADB_PATH
+                except ImportError:
+                    adb_path = "adb"
+
+            logger.info(f"Initializing AndroidEnv (console_port={console_port}, grpc_port={grpc_port})")
+            try:
+                self._env = env_launcher.load_and_setup_env(
+                    console_port=console_port,
+                    grpc_port=grpc_port,
+                    adb_path=adb_path,
+                    emulator_setup=False,
+                    freeze_datetime=True
+                )
+            except Exception as e:
+                logger.error(f"Failed to initialize AndroidWorld env: {e}")
+                logger.warning("Falling back to MOCK mode.")
+                self._env = None
+        else:
+            logger.info("Running in MOCK mode (android_world not installed).")
+
+        self._initialized = True
+
+    def reset(
+        self,
+        task_type: Optional[str] = None,
+        seed: Optional[int] = None,
+        **kwargs
+    ) -> str:
+        """Reset the game to a new episode."""
+        self._init_env()
+
+        if seed is not None:
+            random.seed(seed)
+
+        self._step_count = 0
+        self._done = False
+        self._history = []  # Reset history
+
+        if ANDROID_WORLD_AVAILABLE and self._env:
+            task_name = task_type
+            if not task_name:
+                task_name = random.choice(self.task_types)
+
+            try:
+                task_registry = registry.TaskRegistry()
+                all_tasks = task_registry.get_registry(registry.TaskRegistry.ANDROID_WORLD_FAMILY)
+
+                if task_name not in all_tasks:
+                    logger.warning(f"Task {task_name} not found in registry. Using default.")
+                    task_class = all_tasks.get("ContactsAddContact")
+                else:
+                    task_class = all_tasks[task_name]
+
+                params = task_class.generate_random_params()
+                if seed is not None:
+                    params['seed'] = seed
+
+                self._task = task_class(params)
+                logger.info(f"Resetting task: {self._task.name}")
+
+                self._task.initialize_task(self._env)
+                self._env.hide_automation_ui()
+                state = self._env.reset(go_home=True)
+
+                self._task_desc = self._task.goal
+                self._process_state(state)
+
+            except Exception as e:
+                logger.error(f"Failed to reset AndroidWorld task {task_name}: {e}", exc_info=True)
+                self._set_mock_state(task_name)
+        else:
+            task_name = task_type or random.choice(self.task_types)
+            self._set_mock_state(task_name)
+
+        return self._format_observation(include_instructions=True)
+
+    def _set_mock_state(self, task_name: str):
+        """Set mock state for testing without the library."""
+        self._task_desc = f"Mock Task: {task_name}. Interact with the device."
+        self._current_obs = "UI element 0: Button 'Settings' at [100, 200]\nUI element 1: Icon 'Chrome' at [300, 200]"
+        self._done = False
+
+    def _validate_ui_element(self, ui_element: Any, screen_size: tuple[int, int]) -> bool:
+        """Filter out invalid UI elements (invisible or invalid bbox)."""
+        screen_width, screen_height = screen_size
+        if not ui_element.is_visible:
+            return False
+        if ui_element.bbox_pixels:
+            x_min = ui_element.bbox_pixels.x_min
+            x_max = ui_element.bbox_pixels.x_max
+            y_min = ui_element.bbox_pixels.y_min
+            y_max = ui_element.bbox_pixels.y_max
+            if (x_min >= x_max or x_min >= screen_width or x_max <= 0 or
+                    y_min >= y_max or y_min >= screen_height or y_max <= 0):
+                return False
+        return True
+
+    def _process_state(self, state: Any):
+        """Process AndroidWorld State into text representation."""
+        if hasattr(state, "ui_elements") and self._env:
+            logical_screen_size = self._env.logical_screen_size
+            elements_text = ""
+            valid_count = 0
+            for index, ui_element in enumerate(state.ui_elements):
+                if self._validate_ui_element(ui_element, logical_screen_size):
+                    # Use str(ui_element) which typically provides a good summary
+                    elements_text += f"UI element {index}: {str(ui_element)}\n"
+                    valid_count += 1
+
+            logger.debug(f"_process_state: {len(state.ui_elements)} total elements, {valid_count} valid")
+            self._current_obs = elements_text if elements_text else "No visible interactable elements."
+        else:
+            self._current_obs = "Screen updated (No UI elements info)"
+
+    def _format_observation(self, obs: Optional[str] = None, include_instructions: bool = False) -> str:
+        """Format observation using prompt template for agent loop.
+
+        Args:
+            obs: Optional observation string. If None, uses self._current_obs.
+            include_instructions: If True, include PROMPT_PREFIX and GUIDANCE (for initial prompt).
+                If False, use simplified template (for subsequent turns).
+        """
+        if obs is None:
+            obs = self._current_obs
+
+        if include_instructions or self._step_count == 0:
+            # Full template with PROMPT_PREFIX and GUIDANCE for initial prompt
+            history_str = '\n'.join(self._history) if self._history else 'You just started, no action has been performed yet.'
+
+            # We put everything in the user message/observation for the agent
+            return prompts.INITIAL_ACTION_SELECTION_PROMPT_TEMPLATE.format(
+                goal=self._task_desc,
+                history=history_str,
+                ui_elements_description=obs,
+                additional_guidelines=""
+            )
+        else:
+            # Simplified template: only goal and UI elements description for subsequent turns
+            return prompts.ACTION_SELECTION_PROMPT_TEMPLATE.format(
+                goal=self._task_desc,
+                ui_elements_description=obs,
+            )
+
+    def step(self, action: str) -> StepResult:
+        """Execute an action in the environment."""
+        if self._done:
+            return StepResult(observation="Episode finished.", reward=0.0, done=True, info={})
+
+        self._step_count += 1
+
+        # Parse Reason/Action
+        reason, json_str = self._parse_reason_action(action)
+
+        reward = self.REWARD_STEP
+        step_summary = "Action failed."
+
+        if ANDROID_WORLD_AVAILABLE and self._env and self._task:
+            if json_str:
+                try:
+                    env_action = self._parse_json_to_env_action(json_str)
+
+                    if env_action:
+                        # Handle special 'status' action for completion
+                        if env_action.action_type == 'status':
+                            if env_action.goal_status == 'complete':
+                                reward = self.REWARD_SUCCESS
+                                self._done = True
+                                if self._task.is_successful(self._env) > 0.0:
+                                    step_summary = "Agent finished. Task Successful."
+                                else:
+                                    reward = self.REWARD_FAILURE
+                                    step_summary = "Agent finished. Task Failed (Condition not met)."
+                            else:
+                                self._done = True
+                                reward = self.REWARD_FAILURE
+                                step_summary = "Agent declared task infeasible."
+
+                        elif env_action.action_type == 'answer':
+                            step_summary = f"Agent answered: {env_action.text}"
+
+                        else:
+                            # Execute physical action
+                            self._env.execute_action(env_action)
+                            state = self._env.get_state(wait_to_stabilize=True)
+                            self._process_state(state)
+                            reason_str = reason[:50] if reason else ""
+                            step_summary = f"Executed {env_action.action_type}. {reason_str}..."
+                    else:
+                        reward = self.REWARD_INVALID_ACTION
+                        step_summary = "Invalid JSON action format."
+                except Exception as e:
+                    logger.error(f"Error executing action: {e}", exc_info=True)
+                    reward = self.REWARD_INVALID_ACTION
+                    step_summary = f"Execution error: {str(e)}"
+            else:
+                reward = self.REWARD_INVALID_ACTION
+                step_summary = "Could not parse Action JSON."
+        else:
+            # Mock mode
+            step_summary = f"Mock execute: {action}"
+            self._current_obs = "Screen updated."
+
+        # Update History
+        self._history.append(f"Step {self._step_count}: {step_summary}")
+
+        if self._step_count >= self.max_steps and not self._done:
+            self._done = True
+            reward = self.REWARD_FAILURE
+            step_summary += " (Timeout)"
+
+        return StepResult(
+            observation=self._format_observation(),
+            reward=reward,
+            done=self._done,
+            info={
+                "raw_reward": float(reward),
+                "action_taken": action,
+                "task": self._task_desc,
+            },
+        )
+
+    def _parse_reason_action(self, text: str) -> tuple[Optional[str], Optional[str]]:
+        """Parse 'Reason: ... Action: {...}' output."""
+        reason_match = re.search(r'Reason:(.*)Action:', text, flags=re.DOTALL)
+        reason = reason_match.group(1).strip() if reason_match else None
+
+        action_match = re.search(r'Action:(.*)', text, flags=re.DOTALL)
+        action_json = action_match.group(1).strip() if action_match else None
+
+        # If strict format fails, try to just find the last JSON blob
+        if not action_json:
+            json_candidates = re.findall(r'\{.*?\}', text, re.DOTALL)
+            if json_candidates:
+                action_json = json_candidates[-1]
+
+        return reason, action_json
+
+    def _parse_json_to_env_action(self, json_str: str) -> Any:
+        """Convert JSON string to AndroidWorld JSONAction."""
+        # Clean up json_str - extract just the first dict/JSON object
+        # The string may contain extra text like "Your Answer:" after the JSON
+
+        # First, try to find a complete JSON object with double quotes
+        json_match = re.search(r'\{[^{}]*"action_type"[^{}]*\}', json_str)
+        if json_match:
+            try:
+                action_dict = json.loads(json_match.group())
+                if isinstance(action_dict, dict):
+                    return json_action.JSONAction(**action_dict)
+            except (json.JSONDecodeError, TypeError):
+                pass
+
+        # Try to find a complete dict with single quotes (Python style)
+        dict_match = re.search(r"\{[^{}]*'action_type'[^{}]*\}", json_str)
+        if dict_match:
+            try:
+                action_dict = ast.literal_eval(dict_match.group())
+                if isinstance(action_dict, dict):
+                    return json_action.JSONAction(**action_dict)
+            except (ValueError, SyntaxError, TypeError):
+                pass
+
+        # Fallback: try parsing the whole string
+        try:
+            action_dict = json.loads(json_str)
+            if isinstance(action_dict, dict):
+                return json_action.JSONAction(**action_dict)
+        except json.JSONDecodeError:
+            try:
+                action_dict = ast.literal_eval(json_str)
+                if isinstance(action_dict, dict):
+                    return json_action.JSONAction(**action_dict)
+            except (ValueError, SyntaxError, TypeError):
+                pass
+
+        return None
+
+    def get_system_prompt(self) -> str:
+        """Return the system prompt."""
+        # The T3A prompt template in get_user_message_with_state contains the instructions,
+        # so a generic system prompt is sufficient here.
+        return "You are a helpful AI assistant capable of operating an Android device."
+
+    def get_initial_user_message(self) -> str:
+        """Return the initial user message."""
+        # Used if environment is already reset
+        return self._format_observation(include_instructions=True)
+
+    def get_state(self) -> Dict[str, Any]:
+        """Return the current game state."""
+        return {
+            "observation": self._current_obs,
+            "task": self._task_desc,
+            "step_count": self._step_count,
+            "max_steps": self.max_steps,
+            "done": self._done,
+        }
+
+    def generate_initial_state(self) -> Dict[str, Any]:
+        """Generate a random initial state for training data."""
+        task_type = random.choice(self.task_types)
+        return {
+            "task_type": task_type,
+            "seed": random.randint(0, 1000000),
+        }
+
+    def get_user_message_with_state(
+        self, task_type: Optional[str] = None, **kwargs
+    ) -> str:
+        """Generate a user message with the rendered initial state for the prompt.
+
+        NOTE: This is called by Client (DataLoader) to generate the initial prompt.
+        We do NOT call reset() here because:
+        1. Server will call reset() again via start_interaction() -> /reset
+        2. Both would operate the same Android emulator, causing conflicts
+
+        Instead, we return a properly formatted prompt with placeholder UI elements.
+        The real reset() and UI elements come from Server's /reset call.
+        """
+        task_name = task_type or "ContactsAddContact"
+        seed = kwargs.get("seed")
+
+        # Generate task goal based on task type (same logic as in reset)
+        # This ensures consistency between client prompt and server reset
+        if ANDROID_WORLD_AVAILABLE:
+            try:
+                task_registry = registry.TaskRegistry()
+                all_tasks = task_registry.get_registry(registry.TaskRegistry.ANDROID_WORLD_FAMILY)
+                if task_name in all_tasks:
+                    task_class = all_tasks[task_name]
+                    params = task_class.generate_random_params()
+                    if seed is not None:
+                        params['seed'] = seed
+                    temp_task = task_class(params)
+                    goal = temp_task.goal
+                else:
+                    goal = f"Complete the {task_name} task."
+            except Exception:
+                goal = f"Complete the {task_name} task."
+        else:
+            goal = f"Complete the {task_name} task."
+
+        # Return formatted prompt with real Home Screen UI elements
+        # All AndroidWorld tasks start from Home Screen (go_home=True in reset)
+        # This UI is consistent across all tasks, so we can hardcode it
+        home_screen_ui = (
+            "UI element 0: UIElement(text=None, content_description=None, class_name='android.widget.ScrollView', is_clickable=False, is_scrollable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 1: UIElement(text=None, content_description='Home', class_name='android.view.View', is_clickable=False, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 2: UIElement(text='Phone', content_description='Phone', class_name='android.widget.TextView', is_clickable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 3: UIElement(text='Messages', content_description='Messages', class_name='android.widget.TextView', is_clickable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 4: UIElement(text='Chrome', content_description='Chrome', class_name='android.widget.TextView', is_clickable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 5: UIElement(text='Gmail', content_description='Gmail', class_name='android.widget.TextView', is_clickable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 6: UIElement(text=None, content_description='Search', class_name='android.widget.FrameLayout', is_clickable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 7: UIElement(text='Photos', content_description='Photos', class_name='android.widget.TextView', is_clickable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 8: UIElement(text='YouTube', content_description='YouTube', class_name='android.widget.TextView', is_clickable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+            "UI element 9: UIElement(text=None, content_description='Google app', class_name='android.widget.ImageView', is_clickable=True, is_visible=True, package_name='com.google.android.apps.nexuslauncher')\n"
+        )
+        history = "You just started, no action has been performed yet."
+
+        return prompts.INITIAL_ACTION_SELECTION_PROMPT_TEMPLATE.format(
+            goal=goal,
+            history=history,
+            ui_elements_description=home_screen_ui,
+            additional_guidelines=""
+        )
+
+    def get_interaction_name(self) -> str:
+        """Return the interaction name."""
+        return "android_world"
diff --git a/opentinker/environment/android_world/android_world_server.py b/opentinker/environment/android_world/android_world_server.py
new file mode 100644
index 0000000..54fa114
--- /dev/null
+++ b/opentinker/environment/android_world/android_world_server.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python3
+"""AndroidWorld Environment Server.
+
+This script starts an AndroidWorld game server using the generic base_game_server.
+
+Usage:
+    python android_world_server.py
+    python android_world_server.py --port 8082 --max_steps 50
+    python android_world_server.py --port 8091 --shards 8
+"""
+
+import os
+import argparse
+import subprocess
+import sys
+import time
+
+# Disable GPU if not needed, similar to ALFWorld
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+def main():
+    parser = argparse.ArgumentParser(description="AndroidWorld Game Server")
+    parser.add_argument("--host", default="0.0.0.0", help="Server host")
+    parser.add_argument("--port", type=int, default=8082, help="Server port")
+    parser.add_argument(
+        "--shards",
+        type=int,
+        default=8,
+        help="Number of independent server processes to launch on consecutive ports.",
+    )
+    parser.add_argument(
+        "--config_path", type=str, default=None, help="Path to config file"
+    )
+    parser.add_argument(
+        "--max_steps", type=int, default=50, help="Max steps per episode"
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        default="train",
+        choices=["train", "eval_in_distribution", "eval_out_of_distribution"],
+        help="Dataset split to use",
+    )
+    parser.add_argument(
+        "--num_games",
+        type=int,
+        default=-1,
+        help="Number of games to load",
+    )
+    # Multi-emulator support: each shard connects to a different emulator
+    parser.add_argument(
+        "--emulator_base_console_port",
+        type=int,
+        default=5556,
+        help="Base console port for emulators. Shard i uses port base+i*2 (e.g., 5556, 5558, 5560, 5562)",
+    )
+    parser.add_argument(
+        "--emulator_base_grpc_port",
+        type=int,
+        default=8554,
+        help="Base gRPC port for emulators. Shard i uses port base+i (e.g., 8554, 8555, 8556, 8557)",
+    )
+    # Per-shard emulator ports (set automatically when launching shards)
+    parser.add_argument("--emulator_console_port", type=int, default=None, help="(Internal) Console port for this shard")
+    parser.add_argument("--emulator_grpc_port", type=int, default=None, help="(Internal) gRPC port for this shard")
+    args = parser.parse_args()
+
+    # Import here to avoid issues with multiprocessing
+    from opentinker.environment.android_world.android_world_game import AndroidWorldGame
+
+    print("\nAndroidWorld Game Configuration:")
+    print(f"  Max steps: {args.max_steps}")
+    print(f"  Split: {args.split}")
+    print(f"  Num games: {args.num_games if args.num_games > 0 else 'all'}")
+    print(f"  Shards: {args.shards}")
+    print(f"  Config: {args.config_path or 'default'}")
+    if args.shards > 1:
+        print(f"  Emulator base ports: console={args.emulator_base_console_port}, grpc={args.emulator_base_grpc_port}")
+
+    if args.shards and args.shards > 1:
+        print(
+            f"\nStarting sharded mode: {args.shards} shards on ports {args.port}..{args.port + args.shards - 1}"
+        )
+        print("Each shard connects to a different emulator:")
+        for i in range(args.shards):
+            console_port = args.emulator_base_console_port + i * 2  # 5556, 5558, 5560, 5562
+            grpc_port = args.emulator_base_grpc_port + i  # 8554, 8555, 8556, 8557
+            print(f"  Shard {i}: server port {args.port + i}, emulator console={console_port}, grpc={grpc_port}")
+
+        children: list[subprocess.Popen] = []
+        try:
+            for i in range(args.shards):
+                port_i = args.port + i
+                console_port_i = args.emulator_base_console_port + i * 2
+                grpc_port_i = args.emulator_base_grpc_port + i
+                cmd = [
+                    sys.executable,
+                    os.path.abspath(__file__),
+                    "--host",
+                    args.host,
+                    "--port",
+                    str(port_i),
+                    "--shards",
+                    "1",
+                    "--max_steps",
+                    str(args.max_steps),
+                    "--split",
+                    args.split,
+                    "--num_games",
+                    str(args.num_games),
+                    "--emulator_console_port",
+                    str(console_port_i),
+                    "--emulator_grpc_port",
+                    str(grpc_port_i),
+                ]
+                if args.config_path is not None:
+                    cmd.extend(["--config_path", args.config_path])
+
+                children.append(subprocess.Popen(cmd))
+                time.sleep(0.2)
+
+            print("Shards started. Press Ctrl+C to stop all shards.")
+            while True:
+                for p in children:
+                    rc = p.poll()
+                    if rc is not None:
+                        raise RuntimeError(
+                            f"Shard process exited early with code {rc}: pid={p.pid}"
+                        )
+                time.sleep(1.0)
+        except KeyboardInterrupt:
+            pass
+        finally:
+            for p in children:
+                try:
+                    p.terminate()
+                except Exception:
+                    pass
+            for p in children:
+                try:
+                    p.wait(timeout=5)
+                except Exception:
+                    try:
+                        p.kill()
+                    except Exception:
+                        pass
+        return
+
+    from opentinker.environment.base_game_server import run_game_server
+
+    # For single shard mode, use explicit ports if provided, else fall back to base ports
+    console_port = args.emulator_console_port or args.emulator_base_console_port
+    grpc_port = args.emulator_grpc_port or args.emulator_base_grpc_port
+    print(f"  Emulator: console={console_port}, grpc={grpc_port}")
+
+    run_game_server(
+        game_class=AndroidWorldGame,
+        host=args.host,
+        port=args.port,
+        stats_class=None,
+        config_path=args.config_path,
+        max_steps=args.max_steps,
+        split=args.split,
+        num_games=args.num_games,
+        emulator_console_port=console_port,
+        emulator_grpc_port=grpc_port,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/opentinker/environment/android_world/prompts.py b/opentinker/environment/android_world/prompts.py
new file mode 100644
index 0000000..a89c184
--- /dev/null
+++ b/opentinker/environment/android_world/prompts.py
@@ -0,0 +1,61 @@
+# Prompts for AndroidWorld T3A Agent
+# Extracted and adapted from android_world/agents/t3a.py
+
+PROMPT_PREFIX = """You are an agent who can operate an Android phone on behalf of a user. Based on user's goal/request, you may
+- Answer back if the request/goal is a question (or a chat message), like user asks "What is my schedule for today?".
+- Complete some tasks described in the requests/goals by performing actions (step by step) on the phone.
+
+When given a user request, you will try to complete it step by step. At each step, a list of descriptions for most UI elements on the current screen will be given to you (each element can be specified by an index), together with a history of what you have done in previous steps. Based on these pieces of information and the goal, you must choose to perform one of the actions in the following list (action description followed by the JSON format) by outputting the action in the correct JSON format.
+- If you think the task has been completed, finish the task by using the status action with complete as goal_status: {{"action_type": "status", "goal_status": "complete"}}
+- If you think the task is not feasible (including cases like you don't have enough information or can not perform some necessary actions), finish by using the `status` action with infeasible as goal_status: {{"action_type": "status", "goal_status": "infeasible"}}
+- Answer user's question: {{"action_type": "answer", "text": "<text>"}}
+- Click/tap on a UI element (specified by its index) on the screen: {{"action_type": "click", "index": <index>}}.
+- Long press on a UI element (specified by its index) on the screen: {{"action_type": "long_press", "index": <index>}}.
+- Type text into an editable text field (specified by its index), this action includes clicking the text field, typing in the text, and pressing Enter, so there is no need to click on the target field first: {{"action_type": "input_text", "text": , "index": }} +- Press the Enter key: {{"action_type": "keyboard_enter"}} +- Navigate to the home screen: {{"action_type": "navigate_home"}} +- Navigate back: {{"action_type": "navigate_back"}} +- Scroll the screen or a scrollable UI element in one of the four directions, use the same numeric index as above if you want to scroll a specific UI element, leave it empty when scrolling the whole screen: {{"action_type": "scroll", "direction": , "index": }} +- Open an app (nothing will happen if the app is not installed): {{"action_type": "open_app", "app_name": }} +- Wait for the screen to update: {{"action_type": "wait"}} +""" + +GUIDANCE = """Here are some useful guidelines you need to follow: +General +- Usually there will be multiple ways to complete a task, pick the easiest one. Also when something does not work as expected (due to various reasons), sometimes a simple retry can solve the problem, but if it doesn't (you can see that from the history), try to switch to other solutions. +- Sometimes you may need to navigate the phone to gather information needed to complete the task, for example if the user asks "what is my schedule tomorrow", then you may want to open the calendar app (using the `open_app` action), look up information there, answer the user's question (using the `answer` action) and finish (using the `status` action with complete as goal_status). +- For requests that are questions (or chat messages), remember to use the `answer` action to reply to the user explicitly before finishing! Merely displaying the answer on the screen is NOT sufficient (unless the goal is something like "show me ..."). +- If the desired state is already achieved (e.g., enabling Wi-Fi when it's already on), you can just complete the task.
+Action Related +- Use the `open_app` action whenever you want to open an app (nothing will happen if the app is not installed), do not use the app drawer to open an app unless all other ways have failed. +- Use the `input_text` action whenever you want to type something (including passwords) instead of clicking characters on the keyboard one by one. Sometimes there is some default text in the text field you want to type in, remember to delete it before typing. +- For `click`, `long_press` and `input_text`, the index parameter you pick must be VISIBLE in the screenshot and also in the UI element list given to you (some elements in the list may NOT be visible on the screen so you can not interact with them). +- Consider exploring the screen by using the `scroll` action with different directions to reveal additional content. +- The direction parameter for the `scroll` action can be confusing sometimes as it's opposite to swipe, for example, to view content at the bottom, the `scroll` direction should be set to "down". It has been observed that you have difficulties in choosing the correct direction, so if one does not work, try the opposite as well. +Text Related Operations +- Normally to select some text on the screen: Enter text selection mode by long pressing the area where the text is, then some of the words near the long press point will be selected (highlighted with two pointers indicating the range) and usually a text selection bar will also appear with options like `copy`, `paste`, `select all`, etc. Select the exact text you need. Usually the text selected from the previous step is NOT the one you want, you need to adjust the range by dragging the two pointers. If you want to select all text in the text field, simply click the `select all` button in the bar. +- At this point, you don't have the ability to drag something around the screen, so in general you can not select arbitrary text.
+- To delete some text: the most traditional way is to place the cursor at the right place and use the backspace button in the keyboard to delete the characters one by one (can long press the backspace to accelerate if there are many to delete). Another approach is to first select the text you want to delete, then click the backspace button in the keyboard. +- To copy some text: first select the exact text you want to copy, which usually also brings up the text selection bar, then click the `copy` button in the bar. +- To paste text into a text box, first long press the text box, then usually the text selection bar will appear with a `paste` button in it. +- When typing into a text field, sometimes an auto-complete dropdown list will appear. This usually indicates this is an enum field and you should try to select the best match by clicking the corresponding one in the list. +""" + +# Full template for initial prompt (used in raw_prompt) +INITIAL_ACTION_SELECTION_PROMPT_TEMPLATE = ( + PROMPT_PREFIX + + "\nThe current user goal/request is: {goal}" + + "\n\nHere is a history of what you have done so far:\n{history}" + + "\n\nHere is a list of descriptions for some UI elements on the current screen:\n{ui_elements_description}\n" + + GUIDANCE + + "{additional_guidelines}" + + "\n\nNow output an action from the above list in the correct JSON format, following the reason why you do that.
Your answer should look like:\n" + "Reason: ...\nAction: {{'action_type':...}}\n\n" + "Your Answer:\n" +) + +# Simplified template for subsequent turns in multi-turn conversation +ACTION_SELECTION_PROMPT_TEMPLATE = ( + "\nThe current user goal/request is: {goal}" + + "\n\nHere is a list of descriptions for some UI elements on the current screen:\n{ui_elements_description}\n" +) \ No newline at end of file diff --git a/opentinker/environment/base_game.py b/opentinker/environment/base_game.py index c115b96..7d8e650 100755 --- a/opentinker/environment/base_game.py +++ b/opentinker/environment/base_game.py @@ -72,6 +72,10 @@ def generate_initial_state(self): return {"initial_moves": self._random_moves()} """ + # Agent loop name for server rollout (used when building get_config()). + # Override to "android_agent" etc. for task-specific agent loops. + agent_loop_name: str = "generic_agent" + # ========================================================================= # REQUIRED: Game Logic Methods # ========================================================================= diff --git a/opentinker/environment/base_game_environment.py b/opentinker/environment/base_game_environment.py index f3c189d..149efd0 100755 --- a/opentinker/environment/base_game_environment.py +++ b/opentinker/environment/base_game_environment.py @@ -278,6 +278,9 @@ def get_config(self) -> Dict[str, Any]: config = {} if self._interaction_config_path: + default_agent_loop = getattr( + self.game_class, "agent_loop_name", "generic_agent" + ) config["actor_rollout_ref"] = { "rollout": { "multi_turn": { @@ -285,7 +288,7 @@ def get_config(self) -> Dict[str, Any]: # Include content for cross-node transmission "interaction_config_content": self._interaction_config_content, }, - "agent": {"default_agent_loop": "generic_agent"}, + "agent": {"default_agent_loop": default_agent_loop}, }, } diff --git a/opentinker/environment/gym_environment_interaction.py b/opentinker/environment/gym_environment_interaction.py old 
mode 100755 new mode 100644 index 0042c07..b6b56c0 --- a/opentinker/environment/gym_environment_interaction.py +++ b/opentinker/environment/gym_environment_interaction.py @@ -22,8 +22,11 @@ The environment can be either local (imported directly) or remote (via HTTP API). """ +import asyncio import logging import os +import re +import threading import zlib from urllib.parse import urlparse, urlunparse from typing import Any, Optional, Callable @@ -33,9 +36,21 @@ from verl.interactions.base import BaseInteraction +# Try to import Ray for worker ID detection +try: + import ray + RAY_AVAILABLE = True +except ImportError: + RAY_AVAILABLE = False + logger = logging.getLogger(__name__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +# Module-level cache for worker_id and bound_endpoint (persists across all instances in the same process) +# Key: (env_shards, base_port) tuple, Value: (worker_id, bound_endpoint) tuple +_worker_endpoint_cache: dict[tuple[int, int], tuple[int, str]] = {} +_cache_lock = threading.Lock() + class GymEnvironmentInteraction(BaseInteraction): """Interaction class for OpenAI Gym-like environments. @@ -47,6 +62,11 @@ class GymEnvironmentInteraction(BaseInteraction): Configuration options: - env_endpoint: HTTP endpoint for remote environment API - env_shards: Number of shards (servers on consecutive ports) + - bind_worker_to_endpoint: If True, each worker is bound to a specific endpoint + (1-to-1 worker <-> endpoint). Requires worker count == env_shards. Worker ID is + detected from Ray actor name (e.g., "agent_loop_worker_0" -> worker_id=0). + Worker ID and endpoint are cached at module level, persisting across all + instances in the same process. 
- env_factory: Callable that creates a local environment instance - max_steps: Maximum number of steps per episode - observation_template: Template for formatting observations as messages @@ -63,7 +83,7 @@ class GymEnvironmentInteraction(BaseInteraction): env_shards: 8 # Will use ports 8091..8098 max_steps: 100 """ - + def __init__(self, config: dict): super().__init__(config) self.env_endpoint: Optional[str] = config.get("env_endpoint") @@ -76,6 +96,13 @@ def __init__(self, config: dict): # Job ID for statistics isolation when using shared game servers self.job_id: str = config.get("job_id", "default") + # When True: bind this worker to a specific endpoint (1-to-1 worker <-> endpoint) + self.bind_worker_to_endpoint: bool = bool(config.get("bind_worker_to_endpoint", False)) + + # Worker-bound endpoint (set if bind_worker_to_endpoint is True) + self._bound_endpoint: Optional[str] = None + self._worker_id: Optional[int] = None + # Generate sharded endpoints if env_shards > 1 self.env_endpoints: Optional[list[str]] = None if self.env_endpoint and self.env_shards > 1: @@ -90,15 +117,96 @@ def __init__(self, config: dict): f"[GymEnvironmentInteraction] Sharded mode: {self.env_shards} shards " f"on ports {base_port}..{base_port + self.env_shards - 1}" ) + + # Bind worker to endpoint if requested + if self.bind_worker_to_endpoint: + if not self.env_endpoints: + raise ValueError( + "[GymEnvironmentInteraction] bind_worker_to_endpoint=True requires env_shards > 1" + ) + # Use module-level cache keyed by (env_shards, base_port) to persist across all instances + parsed = urlparse(self.env_endpoint) + base_port = parsed.port if parsed.port else 0 + cache_key = (self.env_shards, base_port) + + with _cache_lock: + if cache_key not in _worker_endpoint_cache: + # First time: detect worker_id and compute bound_endpoint + detected_id = self._detect_worker_id() + + if detected_id is None: + raise RuntimeError( + "[GymEnvironmentInteraction] bind_worker_to_endpoint=True but could not 
detect worker_id. " + "Make sure this is running in a Ray actor with name like 'agent_loop_worker_0', " + "or set RAY_WORKER_ID environment variable." + ) + if detected_id < 0 or detected_id >= len(self.env_endpoints): + raise ValueError( + f"[GymEnvironmentInteraction] worker_id={detected_id} is out of range " + f"for {len(self.env_endpoints)} endpoints. Ensure agent_num_workers == env_shards." + ) + bound_endpoint = self.env_endpoints[detected_id] + _worker_endpoint_cache[cache_key] = (detected_id, bound_endpoint) + logger.info( + f"[GymEnvironmentInteraction] Detected worker_id={detected_id} " + f"bound to endpoint={bound_endpoint} " + f"(module-level cache, persists across all instances in this process)" + ) + # Use cached values for this instance + self._worker_id, self._bound_endpoint = _worker_endpoint_cache[cache_key] # Session storage: maps instance_id to environment state self._instance_dict: dict[str, dict[str, Any]] = {} # For local environments, we can store the env objects self._local_envs: dict[str, Any] = {} + + def _detect_worker_id(self) -> Optional[int]: + """Detect worker ID from Ray actor name using get_actor_name() method. + + This uses the correct Ray API: runtime_context.get_actor_name() + + Returns: + Worker ID (0-based) if detected, None otherwise. 
+ """ + if not RAY_AVAILABLE: + return None + + try: + runtime_context = ray.get_runtime_context() + + # Use get_actor_name() method - this is the correct Ray API + actor_name = runtime_context.get_actor_name() + if actor_name: + match = re.search(r"agent_loop_worker_(\d+)", str(actor_name)) + if match: + worker_id = int(match.group(1)) + logger.info(f"[GymEnvironmentInteraction] Detected worker_id={worker_id} from actor_name={actor_name}") + return worker_id + + # Fallback: Try environment variable + worker_id_env = os.environ.get("RAY_WORKER_ID") + if worker_id_env: + try: + worker_id = int(worker_id_env) + logger.info(f"[GymEnvironmentInteraction] Using worker_id={worker_id} from RAY_WORKER_ID env var") + return worker_id + except ValueError: + logger.warning(f"[GymEnvironmentInteraction] Invalid RAY_WORKER_ID value: {worker_id_env}") + + except Exception as e: + logger.debug(f"[GymEnvironmentInteraction] Failed to detect worker_id: {e}") + + return None def _get_endpoint(self, instance_id: str) -> str: - """Get the endpoint for this instance_id (supports sharding).""" + """Get the endpoint for this instance_id (supports sharding). + + If bind_worker_to_endpoint is True: return the bound endpoint for this worker. + Otherwise: hash instance_id to pick shard (may collide). 
+ """ + if self.bind_worker_to_endpoint and self._bound_endpoint: + return self._bound_endpoint if self.env_endpoints: # Sharded mode: hash instance_id to pick shard idx = zlib.crc32(instance_id.encode("utf-8")) % len(self.env_endpoints) diff --git a/opentinker/scheduler/config/scheduler.yaml b/opentinker/scheduler/config/scheduler.yaml old mode 100755 new mode 100644 index dbd2c60..247f558 --- a/opentinker/scheduler/config/scheduler.yaml +++ b/opentinker/scheduler/config/scheduler.yaml @@ -16,6 +16,9 @@ num_ports: 50 # Port for the scheduler server itself scheduler_port: 8765 +# Directory for training job stdout/stderr logs +logs_dir: ./scheduler_logs + # Authentication settings enable_auth: false # Set to false to disable authentication diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 3680053..2a1824d 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -1037,6 +1037,8 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: env = os.environ.copy() # Set CUDA_VISIBLE_DEVICES to comma-separated list of GPU IDs env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, job.gpu_ids)) + # Pass job_id to agent loop for per-client trace subdirectory isolation + env["ROLLOUT_TRACE_JOB_ID"] = job.job_id # Build command line arguments from config cmd = [ @@ -1078,6 +1080,29 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: f"Job {job.job_id}: Ignoring rollout_n={rollout_n} (not in GRPO mode)" ) + # Forward KL divergence parameters from client config + kl_config = job.config.get("kl", {}) + if kl_config: + use_kl_in_reward = kl_config.get("use_kl_in_reward") + if use_kl_in_reward is not None: + cmd.append(f"algorithm.use_kl_in_reward={str(use_kl_in_reward).lower()}") + logger.info(f"Job {job.job_id}: āœ“ KL use_kl_in_reward={use_kl_in_reward}") + + use_kl_loss = kl_config.get("use_kl_loss") + if use_kl_loss is not None: + 
cmd.append(f"actor_rollout_ref.actor.use_kl_loss={str(use_kl_loss).lower()}") + logger.info(f"Job {job.job_id}: āœ“ KL use_kl_loss={use_kl_loss}") + + kl_loss_coef = kl_config.get("kl_loss_coef") + if kl_loss_coef is not None: + cmd.append(f"actor_rollout_ref.actor.kl_loss_coef={kl_loss_coef}") + logger.info(f"Job {job.job_id}: āœ“ KL kl_loss_coef={kl_loss_coef}") + + kl_loss_type = kl_config.get("kl_loss_type") + if kl_loss_type is not None: + cmd.append(f"actor_rollout_ref.actor.kl_loss_type={kl_loss_type}") + logger.info(f"Job {job.job_id}: āœ“ KL kl_loss_type={kl_loss_type}") + # Forward LoRA parameters if enabled (lora_rank > 0) lora_config = job.config.get("lora", {}) lora_rank = lora_config.get("lora_rank", 0) diff --git a/opentinker/scripts/launch_scheduler.sh b/opentinker/scripts/launch_scheduler.sh old mode 100755 new mode 100644 index 581e7cf..1e1d602 --- a/opentinker/scripts/launch_scheduler.sh +++ b/opentinker/scripts/launch_scheduler.sh @@ -1,18 +1,19 @@ #!/bin/bash # Convenience script to launch the job scheduler -# Set CUDA 12.8 environment explicitly -export CUDA_HOME=$HOME/local/cuda-12.8 -export PATH=$CUDA_HOME/bin:$PATH -export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH +# Set CUDA environment explicitly (adjusted for current user) +# export CUDA_HOME=/usr/local/cuda +# export PATH=$CUDA_HOME/bin:$PATH +# export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH -export ROLLOUT_TRACE_DIR="/home/haofeiy2/OpenTinker/traces" -export NVCC_EXECUTABLE=$CUDA_HOME/bin/nvcc +export ROLLOUT_TRACE_DIR="${ROLLOUT_TRACE_DIR:-./traces}" +mkdir -p "$ROLLOUT_TRACE_DIR" +# export NVCC_EXECUTABLE=$CUDA_HOME/bin/nvcc export TORCH_CUDA_ARCH_LIST="9.0" export FLASHINFER_HOMOGENEOUS_MS=1 # Default configuration -AVAILABLE_GPUS="[0,1,2,3,4,5,6,7,8,9]" +AVAILABLE_GPUS="[0,1,2,3]" PORT_RANGE="null" # Set to null for auto-detection NUM_PORTS=200 SCHEDULER_PORT=8780 diff --git a/opentinker/scripts/run_android.sh b/opentinker/scripts/run_android.sh new file mode 
100755 index 0000000..e48d2d5 --- /dev/null +++ b/opentinker/scripts/run_android.sh @@ -0,0 +1,257 @@ +#!/bin/bash +# AndroidWorld Training Script (Multi-Turn, Multi-Emulator) +# +# This script runs AndroidWorld RL training with OpenTinker. +# You need to run these steps in SEPARATE terminals. +# +# For Training (4 terminals): +# Terminal 1: bash run_android.sh scheduler +# Terminal 2: bash run_android.sh simulator # Android Emulator (start BEFORE env) +# Terminal 3: bash run_android.sh env +# Terminal 4: bash run_android.sh client +# +# Prerequisites: +# - Android SDK, AVD "AndroidWorldAvd" (or set AVD_NAME), and emulator in PATH +# - See docs/android_world_multiturn.md for environment setup + +# ============================================================================= +# Configuration +# ============================================================================= +SCHEDULER_PORT=9780 +ENV_PORT=9092 +GPUS="${GPUS:-[0,1,2,3]}" +NUM_GPUS="${NUM_GPUS:-4}" # For tensor_model_parallel_size (model spans N GPUs) + +# Multi-emulator configuration +# Set NUM_EMULATORS to match NUM_GPUS for true parallelism +NUM_EMULATORS="${NUM_EMULATORS:-4}" + +# Emulator (simulator) base ports +AVD_NAME="${AVD_NAME:-AndroidWorldAvd}" +# Console ports: 5556, 5558, 5560, 5562 (each +2 because ADB uses console+1) +EMULATOR_BASE_CONSOLE_PORT="${EMULATOR_BASE_CONSOLE_PORT:-5556}" +# gRPC ports: 8554, 8555, 8556, 8557 +EMULATOR_BASE_GRPC_PORT="${EMULATOR_BASE_GRPC_PORT:-8554}" +# EMULATOR_HEADLESS=1 -> -no-window -no-audio +EMULATOR_HEADLESS="${EMULATOR_HEADLESS:-1}" +# EMULATOR_NO_KVM=1 -> no "sg kvm", add -accel off (slow, for hosts without KVM) +EMULATOR_NO_KVM="${EMULATOR_NO_KVM:-0}" + +# Fix vLLM v1 cumem allocator issue +export VLLM_DISABLE_SLEEP_MODE=1 + +# Model path (set to your model path) +MODEL_PATH="${MODEL_PATH:-Qwen/Qwen2.5-3B-Instruct}" + +# OpenTinker root (relative to this script: opentinker/scripts/run_android.sh) +OPENTINKER_ROOT="$(cd "$(dirname 
"${BASH_SOURCE[0]}")/../.." && pwd)" + +# Activate conda environment (adjust to your setup) +if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/anaconda3/etc/profile.d/conda.sh" +elif [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +fi +# conda activate + +# Change to OpenTinker directory +cd "$OPENTINKER_ROOT" + +# ============================================================================= +# Step Selection +# ============================================================================= +case "$1" in + setup-avds) + echo "========================================" + echo "Creating $NUM_EMULATORS AVDs for parallel training" + echo "========================================" + echo "" + echo "This will create AVDs named: ${AVD_NAME}_0, ${AVD_NAME}_1, ..., ${AVD_NAME}_$((NUM_EMULATORS-1))" + echo "" + + # Detect system image (x86_64 or arm64-v8a) + SYSTEM_IMAGE="${SYSTEM_IMAGE:-system-images;android-33;google_apis;x86_64}" + echo "Using system image: $SYSTEM_IMAGE" + echo "" + + for i in $(seq 0 $((NUM_EMULATORS - 1))); do + AVD_NAME_I="${AVD_NAME}_${i}" + echo "Creating AVD: $AVD_NAME_I" + + # Check if AVD already exists + if avdmanager list avd | grep -q "Name: $AVD_NAME_I"; then + echo " AVD $AVD_NAME_I already exists, skipping..." + else + echo "no" | avdmanager create avd \ + --name "$AVD_NAME_I" \ + --package "$SYSTEM_IMAGE" \ + --device "pixel_6" \ + --force + echo " Created $AVD_NAME_I" + fi + done + + echo "" + echo "Done! Created $NUM_EMULATORS AVDs." 
+ echo "You can now run: bash run_android.sh simulator" + ;; + + scheduler|1) + echo "========================================" + echo "Step 1: Starting Scheduler on port $SCHEDULER_PORT" + echo "========================================" + bash opentinker/scripts/launch_scheduler.sh \ + --scheduler-port $SCHEDULER_PORT \ + --gpus "$GPUS" + ;; + + simulator|2) + echo "========================================" + echo "Step 2: Starting $NUM_EMULATORS Android Emulators" + echo " AVD base name: ${AVD_NAME}_0 ... ${AVD_NAME}_$((NUM_EMULATORS-1))" + echo " Base Console Port=$EMULATOR_BASE_CONSOLE_PORT" + echo " Base gRPC Port=$EMULATOR_BASE_GRPC_PORT" + echo "========================================" + echo "" + echo "IMPORTANT: Before starting, ensure NO other emulators are running!" + echo " Check: adb devices" + echo " Kill all emulators: adb emu kill (or close emulator windows)" + echo "" + echo "Starting $NUM_EMULATORS emulators:" + for i in $(seq 0 $((NUM_EMULATORS - 1))); do + CONSOLE_PORT=$((EMULATOR_BASE_CONSOLE_PORT + i * 2)) + GRPC_PORT=$((EMULATOR_BASE_GRPC_PORT + i)) + echo " Emulator $i: console=$CONSOLE_PORT (ADB=$((CONSOLE_PORT + 1))), grpc=$GRPC_PORT" + done + echo "" + echo "Ensure the env server is started AFTER all emulators are fully booted." + echo "" + + # Check if AVDs exist + echo "Checking AVDs..." + for i in $(seq 0 $((NUM_EMULATORS - 1))); do + AVD_NAME_I="${AVD_NAME}_${i}" + if ! avdmanager list avd 2>/dev/null | grep -q "Name: $AVD_NAME_I"; then + echo "ERROR: AVD '$AVD_NAME_I' not found!" + echo "Run 'bash run_android.sh setup-avds' first to create the AVDs." + exit 1 + fi + done + echo "All AVDs found." 
+ echo "" + + # Start all emulators in background + PIDS=() + for i in $(seq 0 $((NUM_EMULATORS - 1))); do + AVD_NAME_I="${AVD_NAME}_${i}" + CONSOLE_PORT=$((EMULATOR_BASE_CONSOLE_PORT + i * 2)) + GRPC_PORT=$((EMULATOR_BASE_GRPC_PORT + i)) + + BASE="emulator -avd $AVD_NAME_I -no-snapshot -port $CONSOLE_PORT -grpc $GRPC_PORT" + if [ "$EMULATOR_NO_KVM" = "1" ]; then + CMD="$BASE -no-window -no-audio -accel off" + elif [ "$EMULATOR_HEADLESS" = "1" ]; then + CMD="$BASE -no-window -no-audio" + else + CMD="$BASE" + fi + + echo "Starting emulator $i ($AVD_NAME_I): $CMD" + if [ "$EMULATOR_NO_KVM" = "1" ]; then + $CMD & + else + sg kvm -c "$CMD" & + fi + PIDS+=($!) + sleep 2 # Wait a bit between emulator starts + done + + echo "" + echo "All $NUM_EMULATORS emulators started. PIDs: ${PIDS[*]}" + echo "Waiting for all emulators... Press Ctrl+C to stop." + wait + ;; + + env|3) + echo "========================================" + echo "Step 3: Starting AndroidWorld Environment Server" + echo " Shards: $NUM_EMULATORS (ports $ENV_PORT..$((ENV_PORT + NUM_EMULATORS - 1)))" + echo " Emulator base ports: console=$EMULATOR_BASE_CONSOLE_PORT, grpc=$EMULATOR_BASE_GRPC_PORT" + echo "========================================" + echo "Make sure all $NUM_EMULATORS Android Emulators are running first." 
+ echo "" + python opentinker/environment/android_world/android_world_server.py \ + --port $ENV_PORT \ + --shards $NUM_EMULATORS \ + --emulator_base_console_port $EMULATOR_BASE_CONSOLE_PORT \ + --emulator_base_grpc_port $EMULATOR_BASE_GRPC_PORT \ + --split train \ + --max_steps 50 + ;; + + client|4) + echo "========================================" + echo "Step 4: Running AndroidWorld RL Client" + echo " Emulators: $NUM_EMULATORS (parallel rollouts)" + echo " GPUs: $NUM_GPUS (tensor parallelism)" + echo "========================================" + # Multi-emulator parallel training: + # - batch_size=NUM_GPUS: satisfies batch_size >= num_gpus for data partitioning + # - agent_num_workers=NUM_EMULATORS: parallel rollouts (one per emulator) + # - env_shards=NUM_EMULATORS: routes requests to different env servers/emulators + # - num_gpus=NUM_GPUS: model tensor parallelism (solves OOM) + python opentinker/client/android_world_rl.py \ + tokenizer_path=$MODEL_PATH \ + batch_size=$NUM_GPUS \ + val_batch_size=$NUM_GPUS \ + rollout_n=1 \ + adv_estimator=gae \ + agent_num_workers=$NUM_EMULATORS \ + num_steps=1000 \ + save_freq=50 \ + test_freq=10 \ + num_gpus=$NUM_GPUS \ + scheduler_url=http://0.0.0.0:$SCHEDULER_PORT \ + interaction.config.env_port=$ENV_PORT \ + interaction.config.env_host=0.0.0.0 \ + interaction.config.env_shards=$NUM_EMULATORS + ;; + + *) + echo "AndroidWorld Training Script (Multi-Turn, Multi-Emulator)" + echo "" + echo "Usage: $0 {setup-avds|scheduler|simulator|env|client}" + echo " $0 {1|2|3|4}" + echo "" + echo "=== First Time Setup ===" + echo " $0 setup-avds # Create $NUM_EMULATORS AVDs for parallel training" + echo "" + echo "=== For Training (4 terminals) ===" + echo " Terminal 1: $0 scheduler # Start scheduler (port $SCHEDULER_PORT)" + echo " Terminal 2: $0 simulator # Start $NUM_EMULATORS Android Emulators (start BEFORE env)" + echo " Terminal 3: $0 env # Start $NUM_EMULATORS env server shards (ports $ENV_PORT..$((ENV_PORT+NUM_EMULATORS-1)))" + 
echo " Terminal 4: $0 client # Start RL training client" + echo "" + echo "Multi-Emulator Configuration (env vars):" + echo " NUM_EMULATORS=$NUM_EMULATORS # Number of parallel emulators" + echo " AVD_NAME=$AVD_NAME # AVD base name (creates ${AVD_NAME}_0, ${AVD_NAME}_1, ...)" + echo " EMULATOR_BASE_CONSOLE_PORT=$EMULATOR_BASE_CONSOLE_PORT # Base console port" + echo " EMULATOR_BASE_GRPC_PORT=$EMULATOR_BASE_GRPC_PORT # Base gRPC port" + echo " EMULATOR_HEADLESS=1 # Headless: -no-window -no-audio" + echo " EMULATOR_NO_KVM=1 # No KVM: -accel off (slow, for containers/no KVM)" + echo "" + echo "IMPORTANT: Before running, ensure no other emulators are running!" + echo " Check: adb devices" + echo " Kill all: adb emu kill" + echo "" + echo "Configuration:" + echo " SCHEDULER_PORT=$SCHEDULER_PORT" + echo " ENV_PORT=$ENV_PORT" + echo " NUM_EMULATORS=$NUM_EMULATORS" + echo " GPUS=$GPUS" + echo " NUM_GPUS=$NUM_GPUS" + echo " MODEL_PATH=$MODEL_PATH" + echo "" + echo "See docs/android_world_multiturn.md for Android SDK and AVD setup." + ;; +esac diff --git a/opentinker/server/agent.yaml b/opentinker/server/agent.yaml index adabd5a..aaafc1b 100755 --- a/opentinker/server/agent.yaml +++ b/opentinker/server/agent.yaml @@ -1,2 +1,5 @@ - name: generic_agent _target_: opentinker.server.generic_agent_loop.GenericAgentLoop + +- name: android_agent + _target_: opentinker.server.android_agent_loop.AndroidAgentLoop diff --git a/opentinker/server/android_agent_loop.py b/opentinker/server/android_agent_loop.py new file mode 100755 index 0000000..05b84d0 --- /dev/null +++ b/opentinker/server/android_agent_loop.py @@ -0,0 +1,1160 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Android Agent Loop for LLM-Environment Interaction. + +This module provides a simplified agent loop for multi-turn interactions +between LLMs and external environments (e.g., OpenAI Gym-like APIs). +Unlike ToolAgentLoop, this does not handle tool calls - the external +environment is treated as a conversational API that returns observations. +""" + +import copy +import json +import logging +import os +import re +import threading +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Optional +from uuid import uuid4 + +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput +from verl.interactions.base import BaseInteraction +from verl.interactions.utils.interaction_registry import ( + initialize_interactions_from_config, +) +from verl.utils.profiler import simple_timer +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def _deserialize_images(image_data): + """Deserialize PIL Images if they're still in serialized dict form. + + This handles the case where images arrive as {'__type__': 'PIL.Image', '__data__': base64...} + instead of actual PIL Image objects. 
+ + Args: + image_data: List of images (either PIL Images or serialized dicts) + + Returns: + List of PIL Image objects + """ + if not image_data: + return image_data + + from PIL import Image + import base64 + import io + + result = [] + for img in image_data: + if isinstance(img, dict) and img.get("__type__") == "PIL.Image": + # Deserialize from base64-encoded PNG + data_bytes = base64.b64decode(img["__data__"]) + buffer = io.BytesIO(data_bytes) + result.append(Image.open(buffer).copy()) + elif hasattr(img, "save") and hasattr(img, "mode"): + # Already a PIL Image + result.append(img) + else: + # Unknown type, keep as is + result.append(img) + return result + + +class AndroidAgentState(Enum): + """States for the Android agent loop.""" + + PENDING = "pending" # Initial state, preparing the prompt + GENERATING = "generating" # LLM is generating response + INTERACTING = "interacting" # Interacting with external environment + TERMINATED = "terminated" # Rollout complete + + +class AndroidAgentData: + """Encapsulates all state variables for the Android agent loop. + + This is similar to AgentData in tool_agent_loop.py but without tool-specific fields. + """ + + def __init__( + self, + messages: list[dict[str, Any]], + metrics: dict[str, Any], + request_id: str, + interaction: Optional[BaseInteraction] = None, + interaction_kwargs: Optional[dict[str, Any]] = None, + image_data: Optional[list[Any]] = None, + ): + self.messages = messages + self.metrics = metrics + self.request_id = request_id + self.interaction = interaction + self.interaction_kwargs = interaction_kwargs or {} + + # Multimodal data (images/videos for VL models) + self.image_data = image_data + + # Token sequences + # prompt_ids: full accumulated sequence (system, user, assistant, user, ...) 
for final output/loss + self.prompt_ids: list[int] = [] + # generation_prompt_ids: what the model sees each time (system + latest user only) + self.generation_prompt_ids: list[int] = [] + self.response_ids: list[int] = [] + self.response_mask: list[int] = [] + self.response_logprobs: list[float] = [] + + # Turn tracking + self.user_turns = 0 + self.assistant_turns = 0 + + # Reward tracking (for turn-level rewards, accumulated for final reward) + self.turn_scores: list[float] = [] + + # Per-turn data for per-turn training mode + # Each entry: {generation_prompt_ids, response_ids, response_logprobs, reward} + self.per_turn_data: list[dict[str, Any]] = [] + + # Extra fields for additional data + self.extra_fields: dict[str, Any] = {} + + +# @register("Android_agent") +class AndroidAgentLoop(AgentLoopBase): + """Android agent loop for LLM-environment interaction. + + This agent loop handles multi-turn conversations between an LLM and + an external environment. The environment is accessed through a + BaseInteraction subclass that implements the generate_response method. + + State Machine: + PENDING -> GENERATING -> INTERACTING -> GENERATING -> ... 
-> TERMINATED + + Response Mask: + - mask=1: LLM generated tokens (included in loss computation) + - mask=0: Environment observations, system prompt, padding (excluded from loss) + + Reward Attribution: + - Final Reward: The cumulative reward is placed at the last response token position + """ + + # Trace saving configuration + _trace_output_dir: Optional[str] = None + _trace_count: int = 0 + _trace_lock = threading.Lock() # Thread-safe counter (within single process) + _save_traces: bool = False + _process_id: Optional[int] = None # Track process ID for multi-process training + + @classmethod + def init_class(cls, config, tokenizer, processor, **kwargs): + if cls._class_initialized: + return + cls._class_initialized = True + print("Performing class-level AndroidAgentLoop initialization") + + cls.tokenizer = tokenizer + cls.processor = processor + cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns + cls.max_assistant_turns = ( + config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns + ) + + cls.apply_chat_template_kwargs = config.data.get( + "apply_chat_template_kwargs", {} + ) + cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length + cls.response_length = config.actor_rollout_ref.rollout.response_length + + # Per-turn token limit (optional, None means no per-turn limit) + cls.max_tokens_per_turn = config.actor_rollout_ref.rollout.multi_turn.get( + "max_tokens_per_turn", None + ) + + # Per-turn training mode: each interaction turn becomes a separate training sample + # instead of concatenating all turns into one long sequence + cls.per_turn_training = config.actor_rollout_ref.rollout.multi_turn.get( + "per_turn_training", False + ) + # Per-turn reward gamma: when > 0, replace immediate turn rewards with + # discounted cumulative returns so earlier turns can sense final outcome + cls.per_turn_reward_gamma = config.actor_rollout_ref.rollout.multi_turn.get( + "per_turn_reward_gamma", 0.0 + ) + if 
cls.per_turn_training: + print( + f"[AndroidAgentLoop] Per-turn training mode ENABLED: each turn becomes a separate training sample" + f" (reward_gamma={cls.per_turn_reward_gamma})" + ) + + # Pre-compute system prompt tokens for later stripping + cls.system_prompt = tokenizer.apply_chat_template( + [{}], + add_generation_prompt=False, + tokenize=True, + **cls.apply_chat_template_kwargs, + ) + + # Initialize interactions from config + # CROSS-NODE FIX: If interaction_config_content is available, recreate the temp file locally + # because the original path may point to a file on a different node's /tmp/ + cls.interaction_config_file = ( + config.actor_rollout_ref.rollout.multi_turn.interaction_config_path + ) + interaction_config_content = config.actor_rollout_ref.rollout.multi_turn.get( + "interaction_config_content", None + ) + + if interaction_config_content: + import tempfile + + # Create a local temp file with the content on THIS worker's node + fd, local_path = tempfile.mkstemp( + suffix=".yaml", prefix="interaction_config_worker_" + ) + with os.fdopen(fd, "w") as f: + f.write(interaction_config_content) + cls.interaction_config_file = local_path + print( + f"[AndroidAgentLoop] Created local interaction config from content: {local_path}" + ) + + if cls.interaction_config_file: + cls.interaction_map: dict[str, BaseInteraction] = ( + cls._initialize_interactions(cls.interaction_config_file) + ) + else: + cls.interaction_map = {} + + # Initialize trace saving + cls._trace_output_dir = os.environ.get("ROLLOUT_TRACE_DIR", None) + if cls._trace_output_dir: + # Create per-job subdirectory to isolate traces from different client tasks + job_id = os.environ.get("ROLLOUT_TRACE_JOB_ID", None) + if job_id: + cls._trace_output_dir = str(Path(cls._trace_output_dir) / f"job_{job_id}") + cls._save_traces = True + cls._process_id = os.getpid() # Store process ID for unique trace naming + Path(cls._trace_output_dir).mkdir(parents=True, exist_ok=True) + print( + 
f"[AndroidAgentLoop] Rollout trace saving ENABLED: {cls._trace_output_dir} (PID: {cls._process_id})" + ) + else: + cls._save_traces = False + print( + "[AndroidAgentLoop] Rollout trace saving DISABLED (set ROLLOUT_TRACE_DIR to enable)" + ) + + # Initialize Weave tracing on server side + # Enabled via WEAVE_PROJECT env var (e.g., "opentinker/Android-env") + # or via config.actor_rollout_ref.rollout.multi_turn.weave_project + weave_project = os.environ.get("WEAVE_PROJECT", None) + if weave_project is None: + weave_project = config.actor_rollout_ref.rollout.multi_turn.get( + "weave_project", None + ) + + if weave_project: + try: + from verl.utils.rollout_trace import RolloutTraceConfig + + experiment_name = config.actor_rollout_ref.rollout.multi_turn.get( + "experiment_name", "default" + ) + RolloutTraceConfig.init( + project_name=weave_project, + experiment_name=experiment_name, + backend="weave", + token2text=True, + ) + print( + f"[AndroidAgentLoop] Weave tracing ENABLED: project={weave_project}, experiment={experiment_name}" + ) + except ImportError: + print( + "[AndroidAgentLoop] WARNING: Weave not installed (pip install weave)" + ) + except Exception as e: + logger.warning(f"Failed to init Weave: {e}") + + @rollout_trace_op + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + """Run the agent loop for a single trajectory. + + Args: + sampling_params: LLM sampling parameters (temperature, top_p, etc.) + **kwargs: Dataset fields including 'raw_prompt', 'extra_info', etc. + + Returns: + AgentLoopOutput containing prompt_ids, response_ids, response_mask, etc. + """ + # breakpoint() + # Extract step if available (for trace naming) + step = kwargs.get("step", None) + # Extract messages from kwargs + if "raw_prompt" not in kwargs: + raise KeyError("raw_prompt is required in kwargs for agent loop") + + raw_prompt_value = kwargs["raw_prompt"] + # CRITICAL: Deep copy to prevent GRPO n-sample message accumulation bug! 
+ # When GRPO samples the same prompt N times, each rollout MUST have its own + # independent messages list. Without deepcopy, all N rollouts share the same + # list reference, causing conversation history from all samples to accumulate. + if isinstance(raw_prompt_value, list): + messages = copy.deepcopy(raw_prompt_value) + elif isinstance(raw_prompt_value, dict): + messages = [copy.deepcopy(raw_prompt_value)] + else: + raise TypeError( + f"raw_prompt must be a list or dict, got {type(raw_prompt_value)}" + ) + + metrics = {} + + # Use a stable request_id if available to allow environment reuse on the server. + # extra_info.sample_id is often a stable index for the worker/sample. + stable_id = kwargs.get("extra_info", {}).get("sample_id") + if stable_id is not None: + request_id = str(stable_id) + else: + request_id = uuid4().hex + + # CRITICAL: Extract multimodal data (images) from kwargs for VL models + # This follows the verl pattern from single_turn_agent_loop.py + multi_modal_data_raw = kwargs.get("multi_modal_data") + print( + f"[AndroidAgentLoop DEBUG] multi_modal_data type: {type(multi_modal_data_raw)}, value: {multi_modal_data_raw!r:.200}" + ) + + if isinstance(multi_modal_data_raw, dict): + image_data = copy.deepcopy(multi_modal_data_raw.get("image", None)) + else: + image_data = None + + # Deserialize images if they're still in serialized dict form + # This handles cases where HTTP deserialization didn't fully complete + if image_data: + image_data = _deserialize_images(image_data) + + print( + f"[AndroidAgentLoop DEBUG] image_data type: {type(image_data)}, is_list: {isinstance(image_data, list)}" + ) + if image_data: + print( + f"[AndroidAgentLoop DEBUG] image_data[0] type: {type(image_data[0]) if len(image_data) > 0 else 'empty'}" + ) + + # Debug: Save images if SAVE_DEBUG_IMAGES is set + if image_data and os.environ.get("SAVE_DEBUG_IMAGES"): + await self._save_debug_images(image_data, request_id) + + # Initialize interaction if configured + interaction 
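The deepcopy requirement described above can be reproduced in isolation: when N rollouts are built from one `raw_prompt` without copying, they all alias a single list object, and one rollout's appended turns leak into its siblings. A small self-contained sketch of the bug and the fix:

```python
import copy

raw_prompt = [{"role": "user", "content": "solve the task"}]

# Without deepcopy: all N rollouts alias the same list object
rollouts = [raw_prompt] * 3
rollouts[0].append({"role": "assistant", "content": "turn from rollout 0"})
assert len(rollouts[1]) == 2  # history leaked into every sibling rollout

# With deepcopy: each rollout owns an independent message history
raw_prompt = [{"role": "user", "content": "solve the task"}]
rollouts = [copy.deepcopy(raw_prompt) for _ in range(3)]
rollouts[0].append({"role": "assistant", "content": "turn from rollout 0"})
assert len(rollouts[1]) == 1  # siblings unaffected
```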
= None + interaction_kwargs = {} + if self.interaction_config_file: + interaction_kwargs = kwargs.get("extra_info", {}).get( + "interaction_kwargs", {} + ) + if not interaction_kwargs: + interaction_kwargs = kwargs.get("interaction_kwargs", {}) + + # Get interaction name - use default if not provided in data + if "name" not in interaction_kwargs: + # Use the first interaction from config as default + if self.interaction_map: + default_interaction_name = list(self.interaction_map.keys())[0] + interaction_kwargs["name"] = default_interaction_name + logger.info( + f"Using default interaction: {default_interaction_name}" + ) + else: + raise ValueError( + "No interactions configured in interaction_config_file" + ) + + interaction_name = interaction_kwargs["name"] + if interaction_name not in self.interaction_map: + raise ValueError( + f"Interaction '{interaction_name}' not found. Available: {list(self.interaction_map.keys())}" + ) + interaction = self.interaction_map[interaction_name] + await interaction.start_interaction(request_id, **interaction_kwargs) + + # Capture initial board state ONLY for Gomoku environment (not other environments) + initial_board_state = None + if interaction_name == "gomoku": # Only for Gomoku + if ( + hasattr(interaction, "_instance_dict") + and request_id in interaction._instance_dict + ): + initial_board_state = interaction._instance_dict[request_id].get( + "initial_board_state" + ) + + # Create agent data to track state + agent_data = AndroidAgentData( + messages=messages, + metrics=metrics, + request_id=request_id, + interaction=interaction, + interaction_kwargs=interaction_kwargs, + image_data=image_data, + ) + + # breakpoint() + + # Store initial board state if available + if initial_board_state: + agent_data.extra_fields["initial_board_state"] = initial_board_state + + # State machine loop + state = AndroidAgentState.PENDING + try: + while state != AndroidAgentState.TERMINATED: + if state == AndroidAgentState.PENDING: + state = await 
self._handle_pending_state( + agent_data, sampling_params + ) + elif state == AndroidAgentState.GENERATING: + state = await self._handle_generating_state( + agent_data, sampling_params + ) + elif state == AndroidAgentState.INTERACTING: + state = await self._handle_interacting_state(agent_data) + else: + logger.error(f"Invalid state: {state}") + state = AndroidAgentState.TERMINATED + finally: + # CRITICAL: Always finalize interaction to release resources + if agent_data.interaction is not None: + await agent_data.interaction.finalize_interaction(agent_data.request_id) + + # Finalize output + response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :] + prompt_ids = agent_data.prompt_ids[ + : len(agent_data.prompt_ids) - len(agent_data.response_mask) + ] + + # Truncate prompt_ids if they exceed steps * prompt_length (left truncate to keep recent context). + # Each generation uses prompt <= prompt_length; total prefix is at most steps * prompt_length. + max_total_prompt = self.prompt_length * max( + self.max_assistant_turns or 1, self.max_user_turns or 1 + ) + if len(prompt_ids) > max_total_prompt: + logger.warning( + f"[AndroidAgentLoop] Truncating prompt from {len(prompt_ids)} to {max_total_prompt} tokens (steps * prompt_length)" + ) + prompt_ids = prompt_ids[-max_total_prompt:] + + # Calculate final reward (sum of all turn scores) + # Return 0.0 if no turn scores collected - this prevents fallback to naive reward loop + # which expects ground_truth data that gym environments don't provide + final_reward = sum(agent_data.turn_scores) if agent_data.turn_scores else 0.0 + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=agent_data.response_mask[: self.response_length], + multi_modal_data={"image": agent_data.image_data} + if agent_data.image_data + else {}, + response_logprobs=agent_data.response_logprobs[: self.response_length] + if agent_data.response_logprobs + else None, + 
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, + metrics=agent_data.metrics, + reward_score=final_reward, + extra_fields={}, + ) + # Explicitly set reward_extra_info with FIXED keys to ensure consistency + # across all workers. This prevents meta_info['reward_extra_keys'] conflicts + # when DataProto.concat() merges outputs from different workers. + # Using the same keys as GSM8K/Math reward functions for compatibility. + output.extra_fields["reward_extra_info"] = { + "acc": None, # Placeholder - will be filtered out in metrics computation + } + # Ensure env_info exists for all samples (even if empty) for consistent DataProto.concat + output.extra_fields["env_info"] = agent_data.extra_fields.get("env_info", []) + output.extra_fields["turn_scores"] = agent_data.turn_scores + # Add any other extra fields (except the ones we already set) + for key, value in agent_data.extra_fields.items(): + if key not in output.extra_fields: + output.extra_fields[key] = value + + # Build per-turn outputs for per-turn training mode + # Each turn becomes a separate training sample with its own prompt, response, and reward. + # This avoids concatenating all turns into one long sequence that exceeds context limits. + if self.per_turn_training and agent_data.per_turn_data: + # Drop the last per_turn_data entry if it has reward=0.0 (never went through INTERACTING). + # This happens when the loop terminates in GENERATING state (e.g., max_assistant_turns exceeded). + # A reward=0.0 sample provides no gradient signal and dilutes the actual rewards. 
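The response-mask convention used throughout this loop (mask=1 for assistant tokens that enter the loss, mask=0 for environment observations kept only as context) can be illustrated with a toy interleaving; this is a standalone sketch of the bookkeeping, not the actual loop:

```python
# Toy interleaving of assistant tokens (mask=1) and environment tokens (mask=0)
response_mask = []
assistant_turn = [11, 12, 13]   # generated tokens, trained on
observation = [21, 22]          # environment observation tokens, context only
response_mask += [1] * len(assistant_turn)
response_mask += [0] * len(observation)
response_mask += [1] * len(assistant_turn)

# Same accounting the saved traces report as llm_tokens / env_tokens / llm_ratio
llm_tokens = sum(response_mask)
env_tokens = len(response_mask) - llm_tokens
assert (llm_tokens, env_tokens) == (6, 2)
assert llm_tokens / len(response_mask) == 0.75
```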
+ if ( + len(agent_data.per_turn_data) > 0 + and agent_data.per_turn_data[-1]['reward'] == 0.0 + and len(agent_data.per_turn_data) > len(agent_data.turn_scores) + ): + dropped = agent_data.per_turn_data.pop() + logger.info( + f"[AndroidAgentLoop] Dropped last per-turn sample with reward=0.0 " + f"(never reached INTERACTING state)" + ) + # Compute per-turn rewards: either raw immediate rewards or cumulative + # discounted returns (when per_turn_reward_gamma > 0). + # Cumulative returns propagate the final outcome signal (e.g., +10.0 for + # success) back to earlier turns, enabling cross-turn credit assignment + # that GAE alone cannot provide in per-turn mode. + immediate_rewards = [t['reward'] for t in agent_data.per_turn_data] + if self.per_turn_reward_gamma > 0 and len(immediate_rewards) > 1: + gamma = self.per_turn_reward_gamma + cumulative_returns = [0.0] * len(immediate_rewards) + G = 0.0 + for i in reversed(range(len(immediate_rewards))): + G = immediate_rewards[i] + gamma * G + cumulative_returns[i] = G + per_turn_rewards = cumulative_returns + else: + per_turn_rewards = immediate_rewards + + per_turn_agent_outputs = [] + for turn_data, reward in zip(agent_data.per_turn_data, per_turn_rewards): + turn_response_ids = turn_data['response_ids'][:self.response_length] + turn_logprobs = turn_data['response_logprobs'][:self.response_length] if turn_data['response_logprobs'] else None + turn_output = AgentLoopOutput( + prompt_ids=turn_data['generation_prompt_ids'], + response_ids=turn_response_ids, + response_mask=[1] * len(turn_response_ids), + response_logprobs=turn_logprobs, + multi_modal_data={"image": agent_data.image_data} if agent_data.image_data else {}, + reward_score=reward, + num_turns=1, + metrics=agent_data.metrics, + extra_fields={ + "reward_extra_info": {"acc": None}, + "env_info": [], + "turn_scores": [turn_data['reward']], # Keep original immediate reward for logging + }, + ) + per_turn_agent_outputs.append(turn_output) + 
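The discounted-return computation above, extracted into a standalone function for clarity (the reward values below are illustrative, chosen to match the per-step penalty and success bonus from the reward table in the docs):

```python
def discounted_returns(immediate_rewards, gamma):
    # Backward pass: G_i = r_i + gamma * G_{i+1}, so the final outcome
    # (e.g. a success bonus) propagates back to earlier turns
    returns = [0.0] * len(immediate_rewards)
    G = 0.0
    for i in reversed(range(len(immediate_rewards))):
        G = immediate_rewards[i] + gamma * G
        returns[i] = G
    return returns

# Two per-step penalties followed by a success bonus
rewards = [-0.01, -0.01, 10.0]
returns = discounted_returns(rewards, gamma=0.9)
assert returns[2] == 10.0
assert abs(returns[1] - (-0.01 + 0.9 * 10.0)) < 1e-9   # 8.99
assert abs(returns[0] - (-0.01 + 0.9 * 8.99)) < 1e-9   # 8.081
```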
output.extra_fields['per_turn_outputs'] = per_turn_agent_outputs + logger.info( + f"[AndroidAgentLoop] Per-turn training: {len(per_turn_agent_outputs)} turns " + f"(immediate_rewards: {[round(r, 4) for r in immediate_rewards]}, " + f"training_rewards: {[round(r, 4) for r in per_turn_rewards]})" + ) + + # Save rollout trace for verification (if enabled via ROLLOUT_TRACE_DIR env var) + if self._save_traces: + await self._save_rollout_trace(agent_data, output, request_id, step) + + return output + + async def _handle_pending_state( + self, + agent_data: AndroidAgentData, + sampling_params: dict[str, Any], + ) -> AndroidAgentState: + """Handle the pending state: tokenize the initial prompt.""" + if self.processor is not None: + raw_prompt = await self.loop.run_in_executor( + None, + lambda: self.processor.apply_chat_template( + agent_data.messages, + add_generation_prompt=True, + tokenize=False, + **self.apply_chat_template_kwargs, + ), + ) + # CRITICAL: Pass images to processor for VL models + model_inputs = self.processor( + text=[raw_prompt], + images=agent_data.image_data if agent_data.image_data else None, + return_tensors="pt", + ) + agent_data.prompt_ids = model_inputs.pop("input_ids").squeeze(0).tolist() + else: + agent_data.prompt_ids = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.apply_chat_template( + agent_data.messages, + add_generation_prompt=True, + tokenize=True, + **self.apply_chat_template_kwargs, + ), + ) + # First turn: generation sees the same prompt as full (system + initial user) + agent_data.generation_prompt_ids = list(agent_data.prompt_ids) + return AndroidAgentState.GENERATING + + async def _handle_generating_state( + self, + agent_data: AndroidAgentData, + sampling_params: dict[str, Any], + ) -> AndroidAgentState: + """Handle the generating state: generate LLM response. + + The generated tokens are marked with mask=1 (included in loss computation). 
+ """ + import time + + # CONTEXT OVERFLOW PROTECTION: Check if we have enough room for generation + # This prevents the "max_tokens must be at least 1, got -X" error from vLLM + # which occurs when prompt_len exceeds max_model_len (especially for VL models + # where image tokens can be very large) + total_context_budget = self.prompt_length + self.response_length + min_generation_tokens = 16 # Minimum tokens needed for meaningful generation + + if len(agent_data.generation_prompt_ids) + min_generation_tokens > total_context_budget: + logger.warning( + f"[AndroidAgentLoop] Context overflow detected: prompt_len={len(agent_data.generation_prompt_ids)}, " + f"total_budget={total_context_budget}. Terminating early to avoid negative max_tokens error." + ) + # Add a placeholder response if none exists yet (so we have valid output) + if not agent_data.response_ids: + # Add EOS token as minimal response + eos_token_id = self.tokenizer.eos_token_id + if eos_token_id is not None: + agent_data.response_ids = [eos_token_id] + agent_data.prompt_ids.append(eos_token_id) + agent_data.response_mask.append(1) + return AndroidAgentState.TERMINATED + + print( + f"[AndroidAgentLoop DEBUG] _handle_generating_state START: request_id={agent_data.request_id}, prompt_len={len(agent_data.prompt_ids)}" + ) + start_time = time.time() + with simple_timer("generate_sequences", agent_data.metrics): + print( + f"[AndroidAgentLoop DEBUG] Calling server_manager.generate() with image_data={agent_data.image_data is not None}..." 
+
+ )
+ # CRITICAL: Pass image_data to vLLM for VL model inference
+ try:
+ output = await self.server_manager.generate(
+ request_id=agent_data.request_id,
+ prompt_ids=agent_data.generation_prompt_ids,
+ sampling_params=sampling_params,
+ image_data=agent_data.image_data,
+ )
+ except Exception as e:
+ # Before re-raising: decode and print the full input for debugging
+ import sys
+ n_tokens = len(agent_data.generation_prompt_ids)
+ msg = (
+ f"[AndroidAgentLoop] generate() failed: {type(e).__name__}: {e}\n"
+ f"[AndroidAgentLoop] prompt_ids len={n_tokens} (full accumulated multi-turn prompt; each BEGIN/END block below = one worker's dump)\n"
+ "--- BEGIN full decoded prompt (for debug) ---\n"
+ )
+ sys.stderr.write(msg)
+ sys.stderr.flush()
+ try:
+ decoded = self.tokenizer.decode(
+ agent_data.generation_prompt_ids, skip_special_tokens=True
+ )
+ sys.stderr.write(decoded)
+ sys.stderr.write("\n--- END full decoded prompt ---\n")
+ sys.stderr.flush()
+ except Exception as decode_err:
+ sys.stderr.write(f"[AndroidAgentLoop] decode failed: {decode_err}\n")
+ sys.stderr.flush()
+ raise
+ elapsed = time.time() - start_time
+ print(
+ f"[AndroidAgentLoop DEBUG] server_manager.generate() COMPLETED in {elapsed:.2f}s, response_tokens={len(output.token_ids) if output else 0}"
+ )
+
+ agent_data.assistant_turns += 1
+ response_token_ids = output.token_ids
+ response_log_probs = output.log_probs
+
+ # Apply per-turn token limit if configured
+ if (
+ self.max_tokens_per_turn
+ and len(response_token_ids) > self.max_tokens_per_turn
+ ):
+ logger.debug(
+ f"Truncating turn response from {len(response_token_ids)} to {self.max_tokens_per_turn} tokens"
+ )
+ response_token_ids = response_token_ids[: self.max_tokens_per_turn]
+ if response_log_probs:
+ response_log_probs = response_log_probs[: self.max_tokens_per_turn]
+
+ agent_data.response_ids = response_token_ids
+ agent_data.prompt_ids += agent_data.response_ids
+ agent_data.response_mask += [1] * len(
agent_data.response_ids + ) # mask=1 for LLM tokens + + if response_log_probs: + agent_data.response_logprobs += response_log_probs + + # Collect per-turn data for per-turn training mode + if self.per_turn_training: + agent_data.per_turn_data.append({ + 'generation_prompt_ids': list(agent_data.generation_prompt_ids), + 'response_ids': list(response_token_ids), + 'response_logprobs': list(response_log_probs) if response_log_probs else [], + 'reward': 0.0, # Will be updated in INTERACTING state when env provides reward + }) + + # Check termination conditions + if len(agent_data.response_ids) >= self.response_length: + return AndroidAgentState.TERMINATED + # Use > instead of >= so that max_assistant_turns=1 allows 1 generation + 1 step + # before terminating (instead of terminating immediately after first generation) + if ( + self.max_assistant_turns + and agent_data.assistant_turns > self.max_assistant_turns + ): + return AndroidAgentState.TERMINATED + # Similarly, max_user_turns=1 means user can ask once, then terminate after next generation + if self.max_user_turns and agent_data.user_turns > self.max_user_turns: + return AndroidAgentState.TERMINATED + + # Add assistant message to conversation history + if agent_data.interaction is not None: + assistant_message = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.decode( + agent_data.response_ids, skip_special_tokens=True + ), + ) + agent_data.messages.append( + {"role": "assistant", "content": assistant_message} + ) + return AndroidAgentState.INTERACTING + else: + # No interaction configured, terminate after first generation + return AndroidAgentState.TERMINATED + + async def _handle_interacting_state( + self, + agent_data: AndroidAgentData, + ) -> AndroidAgentState: + """Handle the interacting state: get response from external environment. + + The environment observation is tokenized and marked with mask=0 + (excluded from loss computation). 
+
+ """
+ # Call the interaction to get environment response
+ (
+ should_terminate,
+ observation,
+ reward,
+ info,
+ ) = await agent_data.interaction.generate_response(
+ agent_data.request_id, agent_data.messages, **agent_data.interaction_kwargs
+ )
+ agent_data.user_turns += 1
+
+ # Record turn-level reward (will be summed for final reward)
+ if reward is not None:
+ agent_data.turn_scores.append(reward)
+
+ # Update per-turn reward for the last generation
+ if self.per_turn_training and agent_data.per_turn_data:
+ agent_data.per_turn_data[-1]['reward'] = reward
+
+ # Store environment info under a SINGLE key to ensure consistent structure
+ # across all samples (avoids DataProto.concat assertion errors when different
+ # samples return different info keys)
+ if info:
+ # Append to list instead of overwriting (for multi-turn)
+ if "env_info" not in agent_data.extra_fields:
+ agent_data.extra_fields["env_info"] = []
+ agent_data.extra_fields["env_info"].append(info)
+
+ # The observation may be "long_part" + separator + "short_part": the long part
+ # feeds the next generation prompt, the short part goes into prompt_ids/mask.
+ # NOTE: OBS_SEPARATOR is an assumed placeholder name/token; the original
+ # separator literal is not preserved here.
+ OBS_SEPARATOR = "<obs_sep>"
+ if OBS_SEPARATOR in observation:
+ observation_long, observation_short = observation.split(OBS_SEPARATOR, 1)
+ else:
+ observation_long = observation
+ observation_short = observation
+
+ # Construct user message from the short observation part (kept in message history)
+ add_messages: list[dict[str, Any]] = [{"role": "user", "content": observation_short}]
+ agent_data.messages.extend(add_messages)
+
+ # Tokenize the user message (environment observation)
+ if self.processor is not None:
+ raw_user_response = await self.loop.run_in_executor(
+ None,
+ lambda: self.processor.apply_chat_template(
+ add_messages,
+ add_generation_prompt=True,
+ tokenize=False,
+ **self.apply_chat_template_kwargs,
+ ),
+ )
+ model_inputs = self.processor(text=[raw_user_response], return_tensors="pt")
+ response_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
+ else:
+ response_ids = await self.loop.run_in_executor(
+ None,
+ lambda:
self.tokenizer.apply_chat_template( + add_messages, add_generation_prompt=True, tokenize=True + ), + ) + + # Strip the system prompt tokens (they are duplicated from the full conversation) + response_ids = response_ids[len(self.system_prompt) :] + + # Check if adding these tokens would exceed response length + # In per-turn training mode, skip this check since each turn is trained independently + # and the accumulated response_mask length is irrelevant. + if not self.per_turn_training and len(agent_data.response_mask) + len(response_ids) >= self.response_length: + return AndroidAgentState.TERMINATED + + # Update full prompt_ids and response_mask (accumulated for final output/loss) + # mask=0 for environment observation tokens (not included in loss) + agent_data.prompt_ids += response_ids + agent_data.response_mask += [0] * len(response_ids) + + if agent_data.response_logprobs: + # Pad logprobs with 0.0 for observation tokens + agent_data.response_logprobs += [0.0] * len(response_ids) + + # Next generation: model sees system + LONG part of observation (full context for generation) + system_msg = next( + (m for m in agent_data.messages if m.get("role") == "system"), + agent_data.messages[0], + ) + minimal_messages = [system_msg, {"role": "user", "content": observation_long}] + if self.processor is not None: + raw_minimal = await self.loop.run_in_executor( + None, + lambda: self.processor.apply_chat_template( + minimal_messages, + add_generation_prompt=True, + tokenize=False, + **self.apply_chat_template_kwargs, + ), + ) + model_inputs = self.processor( + text=[raw_minimal], + images=agent_data.image_data if agent_data.image_data else None, + return_tensors="pt", + ) + agent_data.generation_prompt_ids = model_inputs.pop("input_ids").squeeze(0).tolist() + else: + agent_data.generation_prompt_ids = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.apply_chat_template( + minimal_messages, + add_generation_prompt=True, + tokenize=True, + 
**self.apply_chat_template_kwargs, + ), + ) + + if should_terminate: + return AndroidAgentState.TERMINATED + else: + return AndroidAgentState.GENERATING + + async def _save_debug_images(self, image_data: list, request_id: str): + """Save debug images to disk when SAVE_DEBUG_IMAGES env var is set. + + This helps verify that images are being correctly passed to the model. + Images are saved to the ROLLOUT_TRACE_DIR or /tmp/debug_images/ folder. + """ + import os + from pathlib import Path + + output_dir = Path(os.environ.get("ROLLOUT_TRACE_DIR", "/tmp/debug_images")) + output_dir.mkdir(parents=True, exist_ok=True) + + for idx, img in enumerate(image_data): + try: + # Handle PIL Images + if hasattr(img, "save"): + img_path = output_dir / f"debug_image_{request_id[:8]}_{idx}.png" + await self.loop.run_in_executor( + None, lambda p=img_path, im=img: im.save(str(p)) + ) + logger.debug(f"Saved debug image to {img_path}") + else: + logger.debug( + f"Image {idx} is of type {type(img)}, cannot save" + ) + except Exception as e: + logger.debug(f"Failed to save image {idx}: {e}") + + @classmethod + def _initialize_interactions(cls, interaction_config_file): + """Initialize interactions from configuration. + + Returns: + dict[str, BaseInteraction]: A dictionary mapping interaction names to instances. + """ + if interaction_config_file is None: + return {} + + interaction_map = initialize_interactions_from_config(interaction_config_file) + logger.info(f"Initialized interactions: {list(interaction_map.keys())}") + return interaction_map + + async def _save_rollout_trace( + self, + agent_data: AndroidAgentData, + output: AgentLoopOutput, + request_id: str, + step: Optional[int] = None, + ): + """Save rollout trace to JSON file for algorithm verification. 
+ + Trace includes: + - Full conversation messages + - Token IDs and response mask + - Rewards and env info + - Decoded text (readable format) + - Per-turn board states (for Gomoku and similar environments) + """ + try: + # Use a sequential counter for ordered filenames + # Include request_id prefix to avoid collisions across workers + # (each worker process has its own _trace_count starting from 0) + req_short = request_id[:8] if request_id else "unknown" + with self._trace_lock: + self.__class__._trace_count += 1 + seq = self.__class__._trace_count + + # Format: _ (e.g. trace_000001_ab20ea37.json) + if step is not None: + trace_id = f"{seq:06d}_step{step:06d}_{req_short}" + else: + trace_id = f"{seq:06d}_{req_short}" + + # Decode text for readability + prompt_text = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.decode( + output.prompt_ids, skip_special_tokens=True + ), + ) + response_text = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.decode( + output.response_ids, skip_special_tokens=True + ), + ) + + # Extract per-turn board states from messages + per_turn_board_states = self._extract_per_turn_board_states( + agent_data.messages, agent_data + ) + + # Build trace data + trace_data = { + "trace_id": trace_id, + "request_id": request_id, + "timestamp": datetime.now().isoformat(), + # Conversation + "messages": agent_data.messages, + "initial_prompt": agent_data.messages[0] + if agent_data.messages + else None, + # Decoded text (readable) + "prompt_text": prompt_text, + "response_text": response_text, + # Token-level data + "prompt_ids": output.prompt_ids, + "response_ids": output.response_ids, + "response_mask": output.response_mask, + # Response mask analysis + "response_mask_analysis": { + "total_tokens": len(output.response_mask), + "llm_tokens": sum(output.response_mask), # mask=1 + "env_tokens": len(output.response_mask) + - sum(output.response_mask), # mask=0 + "llm_ratio": sum(output.response_mask) / 
len(output.response_mask) + if output.response_mask + else 0, + }, + # Per-turn board states (for Gomoku verification) + "per_turn_board_states": per_turn_board_states, + # Rewards + "reward_score": output.reward_score, + "turn_scores": agent_data.turn_scores, + "env_info": agent_data.extra_fields.get("env_info", []), + # Turn tracking + "num_user_turns": agent_data.user_turns, + "num_assistant_turns": agent_data.assistant_turns, + "total_turns": output.num_turns, + # Configuration (for verification) + # NOTE: response_length is the TOTAL response budget for entire multi-turn trajectory + # NOT per-turn! Each generation call gets max_tokens = max_model_len - current_prompt_len + "config": { + "response_length": self.response_length, # Total response budget (NOT per-turn) + "prompt_length": self.prompt_length, + "max_user_turns": self.max_user_turns, + "max_assistant_turns": self.max_assistant_turns, + "max_tokens_per_turn": self.max_tokens_per_turn, # Per-turn limit (None = no limit) + }, + # Metrics + "metrics": agent_data.metrics, + } + + # Save trace file to per-job directory + trace_file = Path(self._trace_output_dir) / f"trace_{trace_id}.json" + await self.loop.run_in_executor( + None, + lambda: trace_file.write_text( + json.dumps(trace_data, indent=2, default=str) + ), + ) + + # Also append to streaming JSONL file for easy batch processing + jsonl_file = Path(self._trace_output_dir) / "traces.jsonl" + json_line = json.dumps(trace_data, default=str) + "\n" + await self.loop.run_in_executor( + None, + lambda: jsonl_file.open("a").write(json_line), + ) + + logger.info(f"Saved trace {trace_id} to {trace_file}") + + except Exception as e: + import traceback + + logger.warning(f"Failed to save rollout trace: {e}") + print(traceback.format_exc()) + + def _extract_per_turn_board_states( + self, + messages: list[dict[str, Any]], + agent_data: AndroidAgentData, + ) -> list[dict[str, Any]]: + """Extract board states from each message for verification. 
+ + This is ONLY for Gomoku environment. Other environments will get empty board states. + + For Gomoku, this verifies that the board state in the prompt matches the actual game state. + + Combines: + - Board states from env_info (ground truth from environment) + - Visual board extracted from message content + - Initial board state if available + + Returns: + List of dicts with turn info, role, board_state, and verification status. + Empty list for non-Gomoku environments. + """ + # Only extract board states for Gomoku environment + interaction_name = agent_data.interaction_kwargs.get("name", "") + if interaction_name != "gomoku": + return [] # No board state tracking for other environments + + per_turn_states = [] + + # Get env_info list (board states from environment) + env_info_list = agent_data.extra_fields.get("env_info", []) + initial_board_state = agent_data.extra_fields.get("initial_board_state") + + # Track which env_info entry corresponds to which user message + # env_info is added after each interaction (user turn) + env_info_idx = 0 + + for i, message in enumerate(messages): + role = message.get("role", "unknown") + content = message.get("content", "") + + # Extract visual board from message content + visual_board = self._extract_board_from_content(content) + + # Get structured board state from environment + structured_board = None + if role == "system" and initial_board_state: + # System message may reference initial state + structured_board = initial_board_state + elif role == "user" and env_info_idx < len(env_info_list): + # User messages (env observations) have board states in env_info + env_info = env_info_list[env_info_idx] + structured_board = ( + env_info.get("board_state") if isinstance(env_info, dict) else None + ) + env_info_idx += 1 + + # Verification: compare visual board with structured board + verification_status = None + if visual_board and structured_board: + expected_visual = structured_board.get("board_visual", "") + # Simple 
comparison: normalize whitespace + visual_normalized = " ".join(visual_board.split()) + expected_normalized = " ".join(expected_visual.split()) + verification_status = ( + "match" if visual_normalized == expected_normalized else "mismatch" + ) + + turn_info = { + "turn": i, + "role": role, + "visual_board_extracted": visual_board, + "structured_board_state": structured_board, + "verification_status": verification_status, + "message_content_preview": content[:300] + "..." + if len(content) > 300 + else content, + } + + per_turn_states.append(turn_info) + + return per_turn_states + + def _extract_board_from_content(self, content: str) -> str | None: + """Extract ASCII board visualization from message content. + + Looks for Gomoku-style board patterns like: + 0 1 2 3 4 5 6 7 8 + 0 . . . . . . . . . + 1 . . . . X . . . . + ... + + Returns: + The extracted board string, or None if no board found. + """ + if not content: + return None + + lines = content.split("\n") + board_lines = [] + in_board = False + + for line in lines: + stripped = line.strip() + + # Detect column header line (e.g., " 0 1 2 3 4 5 6 7 8") + if re.match(r"^\s*\d(\s+\d)+\s*$", stripped): + in_board = True + board_lines.append(line) + continue + + # Detect board row lines (e.g., "0 . . . X . . . . ." or "0 . X O . . . . . 
.") + if in_board and re.match(r"^\s*\d\s+[.XO](\s+[.XO])*\s*$", stripped): + board_lines.append(line) + continue + + # End board detection on non-matching line after we've started + if in_board and stripped and not re.match(r"^\s*\d\s+[.XO]", stripped): + # Check if this is still a valid board line with row number + if not re.match(r"^\s*\d", stripped): + break + + if board_lines: + return "\n".join(board_lines) + + return None \ No newline at end of file diff --git a/opentinker/server/config/actor/actor.yaml b/opentinker/server/config/actor/actor.yaml index 43f576a..b2193b7 100755 --- a/opentinker/server/config/actor/actor.yaml +++ b/opentinker/server/config/actor/actor.yaml @@ -27,7 +27,8 @@ use_dynamic_bsz: false # Max tokens per GPU in one PPO batch; affects gradient accumulation # Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length} # oc.select: the default val for ref.log_prob_max_token_len_per_gpu -ppo_max_token_len_per_gpu: 16384 +ppo_max_token_len_per_gpu: 32768 +ppo_infer_max_token_len_per_gpu: 32768 # PPO clip ratio clip_ratio: 0.2 diff --git a/opentinker/server/generic_agent_loop.py b/opentinker/server/generic_agent_loop.py index 3930bbc..bc56696 100755 --- a/opentinker/server/generic_agent_loop.py +++ b/opentinker/server/generic_agent_loop.py @@ -221,6 +221,10 @@ def init_class(cls, config, tokenizer, processor, **kwargs): # Initialize trace saving cls._trace_output_dir = os.environ.get("ROLLOUT_TRACE_DIR", None) if cls._trace_output_dir: + # Create per-job subdirectory to isolate traces from different client tasks + job_id = os.environ.get("ROLLOUT_TRACE_JOB_ID", None) + if job_id: + cls._trace_output_dir = str(Path(cls._trace_output_dir) / f"job_{job_id}") cls._save_traces = True cls._process_id = os.getpid() # Store process ID for unique trace naming Path(cls._trace_output_dir).mkdir(parents=True, exist_ok=True) @@ -263,7 +267,7 @@ def init_class(cls, config, tokenizer, processor, **kwargs): "[GenericAgentLoop] 
WARNING: Weave not installed (pip install weave)" ) except Exception as e: - print(f"[GenericAgentLoop] WARNING: Failed to init Weave: {e}") + logger.warning(f"Failed to init Weave: {e}") @rollout_trace_op async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: @@ -423,6 +427,14 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu : len(agent_data.prompt_ids) - len(agent_data.response_mask) ] + # Truncate prompt_ids if they exceed prompt_length (left truncate to keep recent context) + # This prevents tensor size mismatches in verl's batch processing when context overflows + if len(prompt_ids) > self.prompt_length: + logger.warning( + f"[GenericAgentLoop] Truncating prompt from {len(prompt_ids)} to {self.prompt_length} tokens" + ) + prompt_ids = prompt_ids[-self.prompt_length :] + # Calculate final reward (sum of all turn scores) # Return 0.0 if no turn scores collected - this prevents fallback to naive reward loop # which expects ground_truth data that gym environments don't provide @@ -518,10 +530,6 @@ async def _handle_generating_state( f"[GenericAgentLoop] Context overflow detected: prompt_len={len(agent_data.prompt_ids)}, " f"total_budget={total_context_budget}. Terminating early to avoid negative max_tokens error." ) - print( - f"[GenericAgentLoop WARNING] Context overflow: prompt_len={len(agent_data.prompt_ids)} + " - f"min_gen={min_generation_tokens} > budget={total_context_budget}. Terminating early." 
- ) # Add a placeholder response if none exists yet (so we have valid output) if not agent_data.response_ids: # Add EOS token as minimal response @@ -705,13 +713,13 @@ async def _save_debug_images(self, image_data: list, request_id: str): await self.loop.run_in_executor( None, lambda p=img_path, im=img: im.save(str(p)) ) - print(f"[GenericAgentLoop DEBUG] Saved debug image to {img_path}") + logger.debug(f"Saved debug image to {img_path}") else: print( f"[GenericAgentLoop DEBUG] Image {idx} is of type {type(img)}, cannot save" ) except Exception as e: - print(f"[GenericAgentLoop DEBUG] Failed to save image {idx}: {e}") + logger.debug(f"Failed to save image {idx}: {e}") @classmethod def _initialize_interactions(cls, interaction_config_file): @@ -745,14 +753,11 @@ async def _save_rollout_trace( """ try: # Use step + short UUID for trace file naming - # Format: step__ for easy identification by step trace_uuid = request_id[:8] # Use first 8 chars of request_id as short UUID if step is not None: - # Include step number for grouping trace_id = f"step_{step:06d}_{trace_uuid}" else: - # Fallback: just use UUID trace_id = trace_uuid # Decode text for readability @@ -825,7 +830,7 @@ async def _save_rollout_trace( "metrics": agent_data.metrics, } - # Save to file with step and UUID for identification + # Save trace file to per-job directory trace_file = Path(self._trace_output_dir) / f"trace_{trace_id}.json" await self.loop.run_in_executor( None, @@ -836,21 +841,18 @@ async def _save_rollout_trace( # Also append to streaming JSONL file for easy batch processing jsonl_file = Path(self._trace_output_dir) / "traces.jsonl" + json_line = json.dumps(trace_data, default=str) + "\n" await self.loop.run_in_executor( None, - lambda: ( - jsonl_file.open("a").write( - json.dumps(trace_data, default=str) + "\n" - ) - ), + lambda: jsonl_file.open("a").write(json_line), ) - print(f"[GenericAgentLoop] Saved trace {trace_id} to {trace_file}") + logger.info(f"Saved trace {trace_id} to 
{trace_file}") except Exception as e: import traceback - print(f"[GenericAgentLoop] WARNING: Failed to save rollout trace: {e}") + logger.warning(f"Failed to save rollout trace: {e}") print(traceback.format_exc()) def _extract_per_turn_board_states( diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index 809be9a..8d67b10 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -938,11 +938,52 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True, ) + + # Per-turn training expansion: when agent loops expand multi-turn + # episodes into individual per-turn training samples, the gen_batch_output + # batch size is larger than the original batch. We need to expand the + # original batch to match using the expansion index. + expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) + if expansion_index is not None: + logger.info( + f"[Per-turn training] Expanding original batch from {len(batch)} to " + f"{len(gen_batch_output)} to match per-turn expanded rollout output" + ) + expansion_index = np.array(expansion_index) + # Expand tensor batch + if batch.batch is not None and len(batch.batch.keys()) > 0: + batch.batch = batch.batch[expansion_index] + elif batch.batch is not None: + # Empty TensorDict (all keys were popped) — create new one with expanded size + from tensordict import TensorDict + batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) + # Expand non-tensor batch + expanded_non_tensor = {} + for k, v in batch.non_tensor_batch.items(): + expanded_non_tensor[k] = v[expansion_index] + batch.non_tensor_batch = expanded_non_tensor + batch = batch.union(gen_batch_output) logger.info( f"DEBUG: batch keys after gen union: {list(batch.batch.keys())}" ) + # 3.1 Per-turn expansion may produce a batch size not divisible by world_size. 
+ # Trim excess samples so downstream partitioning works. + world_size = ( + self.actor_rollout_wg.world_size + if not self.async_rollout_mode + else self.config.actor_rollout_ref.rollout.agent.num_workers + ) + remainder = len(batch) % world_size + if remainder != 0: + trim_to = len(batch) - remainder + logger.info( + f"[Per-turn training] Trimming batch from {len(batch)} to {trim_to} " + f"(divisible by world_size={world_size})" + ) + batch = batch[:trim_to] + # 4. Compute response mask if not present if "response_mask" not in batch.batch.keys(): batch.batch["response_mask"] = compute_response_mask(batch) @@ -1493,6 +1534,24 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: ) # 6. Merge original batch and generated output + # Per-turn training expansion: expand batch if gen output is larger + expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) + if expansion_index is not None: + logger.info( + f"[Per-turn training] Validation: Expanding original batch from {len(batch)} to " + f"{len(gen_batch_output)} to match per-turn expanded rollout output" + ) + expansion_index = np.array(expansion_index) + if batch.batch is not None and len(batch.batch.keys()) > 0: + batch.batch = batch.batch[expansion_index] + elif batch.batch is not None: + # Empty TensorDict (all keys were popped) — create new one with expanded size + from tensordict import TensorDict + batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) + expanded_non_tensor = {} + for k, v in batch.non_tensor_batch.items(): + expanded_non_tensor[k] = v[expansion_index] + batch.non_tensor_batch = expanded_non_tensor batch = batch.union(gen_batch_output) # 7. 
Compute reward using validation reward function diff --git a/opentinker/server/launch_http_server.py b/opentinker/server/launch_http_server.py old mode 100755 new mode 100644 index 6c2561d..830af82 --- a/opentinker/server/launch_http_server.py +++ b/opentinker/server/launch_http_server.py @@ -123,7 +123,7 @@ def main(cfg): cfg.trainer.save_freq = 500 cfg.trainer.test_freq = 500 cfg.trainer.total_epochs = 15 - cfg.trainer.default_local_dir = "/workspace/verl/verl/ckpts" + cfg.trainer.default_local_dir = "./ckpt" # --------------------------------------------------------- # Agent Loop Configuration
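
The per-turn expansion and world-size trimming added to `train_step` above can be sketched in isolation. This is a toy illustration with plain NumPy arrays, not the real `DataProto`/`TensorDict` API: `expand_and_trim` and its inputs are hypothetical names, and the point is only the mechanic — fancy indexing with `per_turn_expansion_index` repeats each episode's row once per per-turn sample, and a trailing trim makes the result length divisible by `world_size`.

```python
import numpy as np

# Toy sketch (not the real DataProto API) of the two batch fix-ups from
# train_step: expand per-episode arrays with the expansion index, then
# trim so the batch size shards evenly across workers.
def expand_and_trim(batch, expansion_index, world_size):
    idx = np.asarray(expansion_index)
    # Integer-array indexing repeats each episode row once per turn sample.
    expanded = {k: v[idx] for k, v in batch.items()}
    remainder = len(idx) % world_size
    if remainder:
        # Drop trailing samples that cannot be partitioned evenly.
        keep = len(idx) - remainder
        expanded = {k: v[:keep] for k, v in expanded.items()}
    return expanded

# Three episodes expanded into 7 per-turn samples; world_size=2 trims to 6.
batch = {"task_id": np.array([10, 11, 12])}
out = expand_and_trim(batch, [0, 0, 0, 1, 1, 2, 2], world_size=2)
print(out["task_id"].tolist())  # [10, 10, 10, 11, 11, 12]
```

Note the design choice mirrored from the diff: the original batch is expanded to match the rollout output (rather than collapsing per-turn samples back to episodes), and trimming happens after the union so every downstream shard sees aligned rows.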