diff --git a/.gitignore b/.gitignore index 89b1a0a..86a6697 100755 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,4 @@ logs log outputs .history +**/traces/ diff --git a/opentinker/backend_patch/verl/trainer/ppo/ray_trainer.py b/opentinker/backend_patch/verl/trainer/ppo/ray_trainer.py index 9f355ee..a8e917a 100755 --- a/opentinker/backend_patch/verl/trainer/ppo/ray_trainer.py +++ b/opentinker/backend_patch/verl/trainer/ppo/ray_trainer.py @@ -1216,6 +1216,274 @@ def _balance_batch( ) metrics.update(global_balance_stats) + def _compute_wm_beta_token( + self, batch: DataProto, wm_config: dict + ) -> tuple[DataProto, dict]: + """Compute per-token entropy coefficient based on WM uncertainty. + + This computes beta_token at the trainer level (not actor level) to ensure: + 1. Stable weights across PPO epochs (using lagged θ^- from old_log_probs) + 2. Proper coordination with observation_mask and turn_ids + + Formula: + For each turn t: + u_t = mean(-log_prob on obs tokens of turn t) # WM uncertainty + z_t = (u_t - mean(u)) / std(u) # z-score + β_t = β_0 + β_1 * sigmoid(γ * z_t) # dynamic entropy coeff + + beta_token[action_tokens_of_turn_t] = β_t + + Args: + batch: DataProto with old_log_probs, observation_mask, turn_ids, response_mask + wm_config: WM dynamic entropy configuration dict + + Returns: + Updated batch with beta_token field, and metrics dict + """ + import torch + import numpy as np + + metrics = {} + + # Get config - handle both dict and config object access patterns + base_entropy_coeff = self.config.actor_rollout_ref.actor.entropy_coeff + + # Helper function to get value from config (dict or object) + def get_config_value(cfg, key, default): + if isinstance(cfg, dict): + return cfg.get(key, default) + elif hasattr(cfg, key): + return getattr(cfg, key, default) + elif hasattr(cfg, "get"): + return cfg.get(key, default) + return default + + # New per-turn budget design: + # beta_t = per_turn_budget * (1 + fluctuation * tanh(gamma * (u - baseline) / 
baseline)) + # This gives each turn a base budget, adjusted up/down by WM uncertainty + per_turn_budget = get_config_value(wm_config, "per_turn_budget", 0.002) + baseline = get_config_value(wm_config, "baseline", 2.0) # Expected "normal" uncertainty + gamma = get_config_value(wm_config, "gamma", 1.0) # Sensitivity + fluctuation = get_config_value(wm_config, "fluctuation", 0.5) # ±50% range + + # Legacy config support (fallback to old beta_0/beta_1 if per_turn_budget not set) + beta_0 = get_config_value(wm_config, "beta_0", None) + beta_1 = get_config_value(wm_config, "beta_1", None) + use_legacy_mode = beta_0 is not None and beta_1 is not None and per_turn_budget == 0.002 + + if beta_0 is None: + beta_0 = base_entropy_coeff + if beta_1 is None: + beta_1 = base_entropy_coeff + + # Get data from batch + old_log_probs = batch.batch["old_log_probs"] # (B, L) - lagged policy log_probs + response_mask = batch.batch["response_mask"] # (B, L) + batch_size, response_length = response_mask.shape + device = response_mask.device + + # Get observation_mask and turn_ids from non_tensor_batch + observation_mask_raw = batch.non_tensor_batch.get("observation_mask", None) + turn_ids_raw = batch.non_tensor_batch.get("turn_ids", None) + + # DEBUG: Print available keys in non_tensor_batch + print( + f"[WM Dynamic Entropy] DEBUG: non_tensor_batch keys = {list(batch.non_tensor_batch.keys())}" + ) + print( + f"[WM Dynamic Entropy] DEBUG: observation_mask_raw is None = {observation_mask_raw is None}" + ) + print( + f"[WM Dynamic Entropy] DEBUG: turn_ids_raw is None = {turn_ids_raw is None}" + ) + + if observation_mask_raw is None or turn_ids_raw is None: + print( + "[WM Dynamic Entropy] WARNING: observation_mask or turn_ids not found, using per_turn_budget as neutral" + ) + # Create uniform beta_token with per_turn_budget (neutral value) + beta_token = ( + torch.ones(batch_size, response_length, device=device) + * per_turn_budget + ) + batch.batch["beta_token"] = beta_token + 
metrics["wm_entropy/beta_token_mean"] = per_turn_budget + metrics["wm_entropy/per_turn_budget"] = per_turn_budget + return batch, metrics + + # Build tensors from raw data + obs_mask_tensor = torch.zeros( + batch_size, response_length, device=device, dtype=torch.bool + ) + turn_ids_tensor = torch.zeros( + batch_size, response_length, device=device, dtype=torch.long + ) + + for b in range(batch_size): + # observation_mask + obs_mask = observation_mask_raw[b] + if obs_mask is not None and len(obs_mask) > 0: + # Verify format: should be 0/1 vector, not index list + obs_arr = np.array(obs_mask[:response_length]) + if len(obs_arr) > 0 and obs_arr.max() > 1: + print( + f"[WM Dynamic Entropy] WARNING: observation_mask looks like index list (max={obs_arr.max()}), skipping sample {b}" + ) + continue + mask_tensor = torch.tensor(obs_arr, device=device, dtype=torch.float32) + if len(mask_tensor) < response_length: + mask_tensor = torch.cat( + [ + mask_tensor, + torch.zeros( + response_length - len(mask_tensor), device=device + ), + ] + ) + obs_mask_tensor[b] = mask_tensor.bool() + + # turn_ids + t_ids = turn_ids_raw[b] + if t_ids is not None and len(t_ids) > 0: + ids_tensor = torch.tensor( + t_ids[:response_length], device=device, dtype=torch.long + ) + if len(ids_tensor) < response_length: + ids_tensor = torch.cat( + [ + ids_tensor, + torch.zeros( + response_length - len(ids_tensor), + device=device, + dtype=torch.long, + ), + ] + ) + turn_ids_tensor[b] = ids_tensor + + # Initialize beta_token with per_turn_budget (neutral value) + # This is used for tokens in turns that have no obs (e.g., first turn) + # where we can't compute WM uncertainty + default_beta = per_turn_budget # Use per_turn_budget as neutral + beta_token = ( + torch.ones(batch_size, response_length, device=device) * default_beta + ) + + # DEBUG: Print sample info + if observation_mask_raw is not None and len(observation_mask_raw) > 0: + print( + f"[WM Dynamic Entropy] DEBUG: First sample observation_mask 
length = {len(observation_mask_raw[0]) if observation_mask_raw[0] is not None else 'None'}" + ) + print( + f"[WM Dynamic Entropy] DEBUG: First sample turn_ids length = {len(turn_ids_raw[0]) if turn_ids_raw[0] is not None else 'None'}" + ) + + # Track statistics + all_uncertainties = [] + all_betas = [] + valid_samples = 0 + + for b in range(batch_size): + sample_obs_mask = obs_mask_tensor[b] + sample_turn_ids = turn_ids_tensor[b] + sample_log_prob = old_log_probs[b] + + # Compute per-turn WM uncertainty + max_turns = ( + sample_turn_ids.max().item() + 1 if sample_turn_ids.numel() > 0 else 1 + ) + turn_uncertainties = [] + + for t in range(max_turns): + # Find observation tokens for this turn + turn_obs_mask = (sample_turn_ids == t) & sample_obs_mask + if turn_obs_mask.sum() > 0: + # WM uncertainty = mean negative log_prob on obs tokens + # Note: We use old_log_probs which has proper values for obs tokens + # because the full sequence is passed through the model + obs_log_probs = sample_log_prob[turn_obs_mask] + # Filter out zero/invalid log_probs (shouldn't happen with proper data) + valid_log_probs = obs_log_probs[obs_log_probs != 0] + if len(valid_log_probs) > 0: + uncertainty = -valid_log_probs.mean().item() + turn_uncertainties.append((t, uncertainty)) + all_uncertainties.append(uncertainty) + + if len(turn_uncertainties) >= 1: + valid_samples += 1 + + # Compute DYNAMIC baseline = mean of all turn uncertainties in this sample + sample_uncertainties = [u for _, u in turn_uncertainties] + sample_baseline = np.mean(sample_uncertainties) + + for t, u in turn_uncertainties: + # Per-turn budget with dynamic baseline adjustment + # beta_t = per_turn_budget * (1 + fluctuation * tanh(gamma * (u - sample_baseline) / sample_baseline)) + # + # When u = sample_baseline: multiplier = 1, beta_t = per_turn_budget + # When u > sample_baseline: multiplier > 1, beta_t > per_turn_budget (more exploration) + # When u < sample_baseline: multiplier < 1, beta_t < per_turn_budget (less 
exploration) + + if sample_baseline > 1e-6: + normalized_diff = (u - sample_baseline) / sample_baseline + else: + normalized_diff = 0.0 + + # Use tanh to smoothly bound the multiplier to [1-fluctuation, 1+fluctuation] + multiplier = 1.0 + fluctuation * np.tanh(gamma * normalized_diff) + beta_t = per_turn_budget * multiplier + + # Ensure beta_t is non-negative + beta_t = max(0.0, beta_t) + + all_betas.append(beta_t) + + # Apply to action tokens of this turn (not obs tokens) + turn_action_mask = ( + (sample_turn_ids == t) + & (~sample_obs_mask) + & response_mask[b].bool() + ) + beta_token[b, turn_action_mask] = beta_t + + # Add beta_token to batch + batch.batch["beta_token"] = beta_token + + # Compute metrics + metrics["wm_entropy/per_turn_budget"] = per_turn_budget + metrics["wm_entropy/fluctuation"] = fluctuation + + if all_uncertainties: + # uncertainty_mean is the dynamic baseline (average across all turns in batch) + batch_mean_uncertainty = np.mean(all_uncertainties) + metrics["wm_entropy/uncertainty_mean"] = batch_mean_uncertainty + metrics["wm_entropy/uncertainty_std"] = ( + np.std(all_uncertainties) if len(all_uncertainties) > 1 else 0.0 + ) + metrics["wm_entropy/uncertainty_min"] = np.min(all_uncertainties) + metrics["wm_entropy/uncertainty_max"] = np.max(all_uncertainties) + if all_betas: + metrics["wm_entropy/beta_mean"] = np.mean(all_betas) + metrics["wm_entropy/beta_std"] = ( + np.std(all_betas) if len(all_betas) > 1 else 0.0 + ) + metrics["wm_entropy/beta_min"] = np.min(all_betas) + metrics["wm_entropy/beta_max"] = np.max(all_betas) + # Beta relative to per_turn_budget + metrics["wm_entropy/beta_mean_ratio"] = np.mean(all_betas) / per_turn_budget if per_turn_budget > 0 else 1.0 + metrics["wm_entropy/valid_samples"] = valid_samples + metrics["wm_entropy/total_samples"] = batch_size + + # Log overall beta_token stats + action_mask = response_mask.bool() & (~obs_mask_tensor) + if action_mask.sum() > 0: + metrics["wm_entropy/beta_token_mean"] = ( + 
beta_token[action_mask].mean().item() + ) + metrics["wm_entropy/beta_token_std"] = beta_token[action_mask].std().item() + + return batch, metrics + def compute_rollout_importance_weights_and_add_to_batch( self, batch: DataProto ) -> tuple[DataProto, dict]: @@ -1618,6 +1886,52 @@ def fit(self): config=self.config.algorithm, ) + # ===================================================================== + # WM-Guided Dynamic Entropy: Compute beta_token at trainer level + # This ensures stable weights across PPO epochs (using lagged θ^-) + # ===================================================================== + actor_config = self.config.actor_rollout_ref.actor + # Handle both dict and dataclass/OmegaConf access patterns + if hasattr(actor_config, "get"): + wm_dynamic_entropy_config = actor_config.get( + "wm_dynamic_entropy", {} + ) + elif hasattr(actor_config, "wm_dynamic_entropy"): + wm_dynamic_entropy_config = actor_config.wm_dynamic_entropy + # Convert to dict if it's a config object + if hasattr(wm_dynamic_entropy_config, "__dict__"): + wm_dynamic_entropy_config = dict( + wm_dynamic_entropy_config + ) + elif hasattr(wm_dynamic_entropy_config, "items"): + wm_dynamic_entropy_config = dict( + wm_dynamic_entropy_config + ) + else: + wm_dynamic_entropy_config = {} + + # Check if enabled + is_enabled = False + if isinstance(wm_dynamic_entropy_config, dict): + is_enabled = wm_dynamic_entropy_config.get("enabled", False) + elif hasattr(wm_dynamic_entropy_config, "enabled"): + is_enabled = wm_dynamic_entropy_config.enabled + elif hasattr(wm_dynamic_entropy_config, "get"): + is_enabled = wm_dynamic_entropy_config.get("enabled", False) + + print( + f"[WM Dynamic Entropy] DEBUG: is_enabled = {is_enabled}, config type = {type(wm_dynamic_entropy_config)}" + ) + + if is_enabled: + with marked_timer( + "wm_beta_token", timing_raw, color="magenta" + ): + batch, wm_metrics = self._compute_wm_beta_token( + batch, wm_dynamic_entropy_config + ) + metrics.update(wm_metrics) + # update 
critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): diff --git a/opentinker/client/alfworld_inference.py b/opentinker/client/alfworld_inference.py new file mode 100644 index 0000000..4b87092 --- /dev/null +++ b/opentinker/client/alfworld_inference.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +"""ALFWorld Inference Script. + +This script runs inference/evaluation on trained ALFWorld models. + +Usage: + # Start ALFWorld environment server first (in another terminal): + python -m opentinker.environment.alfworld.alfworld_server --port 8091 --split eval_in_distribution + + # Run inference with scheduler: + python alfworld_inference.py \ + model_path=/path/to/checkpoint \ + scheduler_url=http://localhost:8089 \ + data_path=/path/to/eval_data.jsonl +""" + +import hydra + +from utils.http_training_client import InferenceSchedulerClient +from utils.scheduler_client_lifecycle import get_lifecycle_manager +from opentinker.environment.inference_pipeline import run_inference +from opentinker.environment.alfworld import ALFWorldGame +from opentinker.environment.game_stats_client import GameStatsClient + + +@hydra.main( + config_path="client_config", + config_name="alfworld_inference_config.yaml", + version_base=None, +) +def main(args): + """Run ALFWorld inference with scheduler-managed vLLM server.""" + lifecycle = get_lifecycle_manager() + + print("=" * 60) + print("ALFWorld Inference with Scheduler") + print("=" * 60) + + if not args.model_path: + raise ValueError("model_path is required") + + # 1. 
Submit inference job to scheduler + scheduler_client = InferenceSchedulerClient( + scheduler_url=args.get("scheduler_url", "http://localhost:8089"), + api_key=args.get("scheduler_api_key"), + ) + + print(f"\nModel: {args.model_path}") + print(f"Scheduler: {args.scheduler_url}") + print(f"Environment: {args.env_endpoint}") + print(f"Split: {args.split}") + + print("\nSubmitting inference job to scheduler...") + job_result = scheduler_client.submit_inference_job( + model_path=args.model_path, + tokenizer_path=args.get("tokenizer_path"), + tensor_parallel_size=args.get("tensor_parallel_size", 1), + num_gpus=args.get("num_gpus"), + gpu_memory_utilization=args.get("gpu_memory_utilization", 0.9), + max_model_len=args.get("max_model_len"), + trust_remote_code=args.get("trust_remote_code", True), + ) + + job_id = job_result["job_id"] + vllm_server_url = job_result["vllm_server_url"] + + # Register job for lifecycle cleanup + lifecycle.register_job(scheduler_client, job_id) + + print(f"✓ Inference job {job_id} started at {vllm_server_url}") + + # 2. Setup GameStatsClient for per-step metrics (with job_id isolation) + game_stats = GameStatsClient(args.env_endpoint, job_id=job_id) + if game_stats.health_check(): + print(f"✓ Connected to ALFWorld server at {args.env_endpoint}") + game_stats.reset_all() # Reset stats for this job before inference + else: + print( + f"⚠ ALFWorld server not available at {args.env_endpoint}, continuing without stats" + ) + game_stats = None + + # 3. 
Run inference using the remote vLLM server + data_path = args.get("data_path") + if data_path: + print(f"Running inference on {data_path}...") + else: + print(f"Running inference on ALFWorld {args.split} split...") + + results = run_inference( + model_path=None, # Not needed when using vllm_server_url + vllm_server_url=vllm_server_url, + tokenizer_path=args.get("tokenizer_path") or args.model_path, + data_path=data_path, + game_class=ALFWorldGame, + env_endpoint=args.env_endpoint, + job_id=job_id, # Pass job_id for stats isolation + output_path=args.get("output_path"), + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.max_new_tokens, + max_samples=args.get("max_samples"), + max_user_turns=args.multi_turn.max_user_turns, + max_assistant_turns=args.multi_turn.max_assistant_turns, + ) + + # 4. Log game stats after inference + print("\n" + "=" * 60) + print("Inference Results") + print("=" * 60) + + if game_stats: + stats = game_stats.get_all_stats() + print(f"\nALFWorld Evaluation Stats (job_id={job_id}):") + print(f" Total episodes: {stats.get('total_games', 0)}") + print(f" Successes: {stats.get('total_wins', 0)}") + print(f" Failures: {stats.get('total_losses', 0)}") + success_rate = stats.get("cumulative_win_rate", 0) + print(f" Success rate: {success_rate:.1%}") + print(f" Mean reward: {stats.get('mean_final_reward', 0):.4f}") + print(f" Mean steps: {stats.get('mean_steps', 0):.2f}") + + if results: + print(f"\nProcessed {len(results)} samples") + + if args.get("output_path"): + print(f"Results saved to: {args.output_path}") + + print(f"\n{'='*60}") + print("Inference completed! 
vLLM server will be automatically cleaned up.") + print(f"{'='*60}") + + +if __name__ == "__main__": + main() diff --git a/opentinker/client/client_config/alfworld_inference_config.yaml b/opentinker/client/client_config/alfworld_inference_config.yaml new file mode 100644 index 0000000..64c39cf --- /dev/null +++ b/opentinker/client/client_config/alfworld_inference_config.yaml @@ -0,0 +1,37 @@ +# ALFWorld Inference Configuration +# Use with: python alfworld_inference.py + +# Model settings +model_path: null # Path to trained checkpoint (HuggingFace format) - REQUIRED +tokenizer_path: null # Tokenizer path (defaults to model_path if null) + +# GPU settings +tensor_parallel_size: 1 # Number of GPUs for tensor parallelism +num_gpus: 1 # Number of GPUs to request from scheduler +gpu_memory_utilization: 0.9 +max_model_len: null # Max model context length (null = auto) +trust_remote_code: true + +# Generation parameters (greedy by default for inference) +temperature: 0.0 # 0.0 = greedy decoding for deterministic evaluation +top_p: 1.0 +max_new_tokens: 4096 # Max tokens for full multi-turn trajectory + +# Data settings +data_path: null # Input data file (parquet/jsonl), null = use ALFWorld split +output_path: null # Output results file (jsonl) +max_samples: null # Limit samples (null = all) + +# Environment settings +env_endpoint: http://0.0.0.0:8091 +split: eval_in_distribution # train, eval_in_distribution, eval_out_of_distribution + +# Multi-turn settings for ALFWorld +multi_turn: + max_user_turns: 50 # Max environment interactions + max_assistant_turns: 50 + max_tokens_per_turn: 256 # Per-turn response limit + +# Scheduler settings +scheduler_url: http://0.0.0.0:8089 +scheduler_api_key: null diff --git a/opentinker/client/client_config/alfworld_param.yaml b/opentinker/client/client_config/alfworld_param.yaml index 822f999..74d2416 100644 --- a/opentinker/client/client_config/alfworld_param.yaml +++ b/opentinker/client/client_config/alfworld_param.yaml @@ -8,8 +8,8 @@ 
experiment_name: alfworld_training # Logging logger_backends: ["console", "wandb"] -# Tracing (optional) -enable_tracing: true +# Tracing (optional) - DISABLED to prevent disk space issues +enable_tracing: false weave_project: null # WandB (optional) @@ -24,8 +24,8 @@ num_workers: 4 # Training duration - set ONE of these (num_steps takes precedence if both set) num_epochs: null # Number of epochs (null = use num_steps) num_steps: 1000 # Total training steps (null = use num_epochs) -save_freq: 20000 -test_freq: 10 # Validation frequency (every N steps) +save_freq: 10000 +test_freq: 100 # Validation frequency (every N steps) # Validation parameters val_batch_size: 50 # Total validation samples (null = 50) @@ -81,3 +81,27 @@ scheduler_api_key: otk_98b8db24ccd64c92e1fdd9a232e209fa # GPU settings num_gpus: 4 + +# Actor settings (passed to server) +actor: + # World model loss: predict environment observations as auxiliary task + # 训练模型预测环境观察,提供 WM 不确定性信号 + use_world_model_loss: true + world_model_loss_coef: 0.01 # 用小系数避免干扰 policy + + + # Turn-wise Dynamic Entropy Coefficient (WM-guided) + # 根据每个 turn 的 WM uncertainty 调整 entropy bonus + # 高 uncertainty turn -> β > per_turn_budget -> 更多探索 + # 低 uncertainty turn -> β < per_turn_budget -> 更稳定执行 + wm_dynamic_entropy: + enabled: true + # Per-turn budget design: + # β_t = per_turn_budget * (1 + fluctuation * tanh(γ * (u - baseline) / baseline)) + # - per_turn_budget: 每个 turn 的基础 entropy budget + # - baseline: 动态计算 = 该 sample 内所有 turn 的 uncertainty 均值 + # - fluctuation: 浮动范围 (0.5 = ±50%) + # - gamma: 敏感度 (越大对 uncertainty 差异越敏感) + per_turn_budget: 0.002 # 每个 turn 的基础 budget + fluctuation: 0.5 # ±50% 浮动 (beta 范围: [0.001, 0.003]) + gamma: 1.0 # 敏感度 diff --git a/opentinker/client/client_config/llm_user_param.yaml b/opentinker/client/client_config/llm_user_param.yaml new file mode 100644 index 0000000..b05a993 --- /dev/null +++ b/opentinker/client/client_config/llm_user_param.yaml @@ -0,0 +1,57 @@ +# LLM User Simulator Training 
Configuration +# Train a conversational agent with LLM-based user simulation + +# Project settings +project_name: opentinker +experiment_name: llm_user_training + +# Logging +logger_backends: ["console", "wandb"] +enable_tracing: false +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_steps: 1000 +save_freq: 500 +test_freq: 10 +val_batch_size: 20 + +# Generation parameters +temperature: 0.8 +top_p: 0.95 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm +algorithm: "agent_loop" +adv_estimator: "grpo" +rollout_n: 4 + +# Interaction configuration +interaction: + name: llm_user_simulator + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8100 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 8 + max_steps: 10 + observation_template: "{observation}" + +multi_turn: + max_user_turns: 10 + max_assistant_turns: 10 + max_tokens_per_turn: 512 + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 diff --git a/opentinker/client/llm_user_rl.py b/opentinker/client/llm_user_rl.py new file mode 100644 index 0000000..7bf1800 --- /dev/null +++ b/opentinker/client/llm_user_rl.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +"""LLM User Simulator RL Training Client. + +Train a conversational agent with LLM-based user simulation. 
+ +Usage: + # Start the LLM user simulator server first: + python -m opentinker.environment.llm_user_simulator.llm_user_server --port 8100 --shards 8 + + # Run training: + python llm_user_rl.py scheduler_url=http://localhost:8780 num_gpus=4 +""" + +from omegaconf import OmegaConf +import hydra + +from utils.http_training_client import ServiceClient, SchedulerClient +from opentinker.environment.base_game_environment import GameEnvironment +from opentinker.environment.llm_user_simulator import LLMUserGame +from opentinker.environment.game_stats_client import GameStatsClient +from utils.utils import resolve_paths_in_config +from utils.scheduler_client_lifecycle import get_lifecycle_manager + + +@hydra.main(config_path="client_config", config_name="llm_user_param.yaml") +def main(args): + args = resolve_paths_in_config(args) + lifecycle = get_lifecycle_manager() + + print("=" * 60) + print("Training with LLM User Simulator") + print("=" * 60) + + # Connect to scheduler + scheduler_url = args.get("scheduler_url", "http://localhost:8780") + scheduler_api_key = args.get("scheduler_api_key", None) + + print(f"\nConnecting to scheduler at {scheduler_url}") + scheduler_client = SchedulerClient( + scheduler_url=scheduler_url, api_key=scheduler_api_key + ) + + # Submit job + print("\nSubmitting training job...") + job_result = scheduler_client.submit_job( + config=OmegaConf.to_container(args, resolve=True), + enable_agent_loop=True, + wandb_key=args.get("wandb_key"), + num_gpus=args.get("num_gpus"), + ) + + job_id = job_result["job_id"] + server_url = job_result["server_url"] + lifecycle.register_job(scheduler_client, job_id) + + print(f"\n✓ Job {job_id} allocated!") + print(f" Server URL: {server_url}") + print("=" * 60) + + # Setup GameEnvironment + interaction_config = args.interaction.config + game_kwargs = { + "max_turns": interaction_config.get("max_steps", 10), + } + + env = GameEnvironment( + game_class=LLMUserGame, + config=args, + game_kwargs=game_kwargs, + 
job_id=job_id, + ) + + # Setup stats client + env_endpoint = interaction_config.env_endpoint + game_stats = GameStatsClient(env_endpoint, job_id=env.job_id) + if game_stats.health_check(): + print(f"✓ Connected to LLM user simulator at {env_endpoint}") + game_stats.reset_all() + else: + print(f"⚠ Server at {env_endpoint} not responding") + game_stats = None + + # Connect to training server + print(f"\nConnecting to server at {server_url}") + client = ServiceClient( + server_url=server_url, + project_name=args.project_name, + experiment_name=args.experiment_name, + logger_backends=args.logger_backends, + ) + + client.set_config(args, env) + + # Train + num_steps = args.get("num_steps", 1000) + print(f"\nStarting training for {num_steps} steps...") + print("=" * 60) + + try: + final_metrics = client.fit( + env=env, + num_steps=num_steps, + save_freq=args.save_freq, + test_freq=args.test_freq, + verbose=True, + validate_before_training=True, + game_stats_client=game_stats, + ) + + print("\n" + "=" * 60) + print("Training completed!") + print(f"Final metrics: {final_metrics}") + print("=" * 60) + + finally: + env.cleanup() + + +if __name__ == "__main__": + main() diff --git a/opentinker/environment/alfworld/alfworld_server.py b/opentinker/environment/alfworld/alfworld_server.py index 69d59f1..2d2197a 100644 --- a/opentinker/environment/alfworld/alfworld_server.py +++ b/opentinker/environment/alfworld/alfworld_server.py @@ -49,7 +49,7 @@ def main(): parser.add_argument( "--split", type=str, - default="train", + default="eval_in_distribution", choices=["train", "eval_in_distribution", "eval_out_of_distribution"], help="Dataset split to use", ) diff --git a/opentinker/environment/llm_user_simulator/__init__.py b/opentinker/environment/llm_user_simulator/__init__.py new file mode 100644 index 0000000..0e07e66 --- /dev/null +++ b/opentinker/environment/llm_user_simulator/__init__.py @@ -0,0 +1,9 @@ +"""LLM User Simulator Environment. 
+ +This module provides an environment where an LLM simulates a user, +enabling training of conversational agents. +""" + +from opentinker.environment.llm_user_simulator.llm_user_game import LLMUserGame + +__all__ = ["LLMUserGame"] diff --git a/opentinker/environment/llm_user_simulator/llm_user_game.py b/opentinker/environment/llm_user_simulator/llm_user_game.py new file mode 100644 index 0000000..00f0b15 --- /dev/null +++ b/opentinker/environment/llm_user_simulator/llm_user_game.py @@ -0,0 +1,590 @@ +#!/usr/bin/env python3 +"""LLM User Simulator Game Implementation. + +This module provides an environment where an LLM acts as a user simulator, +enabling training of conversational agents through self-play or cross-play. + +Example: + from llm_user_game import LLMUserGame + + game = LLMUserGame( + simulator_model="gpt-4o-mini", + task_prompt="You are a customer trying to book a flight.", + ) + obs = game.reset() + result = game.step("Hello! How can I help you today?") +""" + +import os +import random +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from opentinker.environment.base_game import AbstractGame, StepResult + +# Try to import LLM clients +try: + from openai import OpenAI + + OPENAI_AVAILABLE = True +except ImportError: + OPENAI_AVAILABLE = False + +try: + import anthropic + + ANTHROPIC_AVAILABLE = True +except ImportError: + ANTHROPIC_AVAILABLE = False + + +@dataclass +class ConversationTurn: + """A single turn in the conversation.""" + + role: str # "agent" or "user" + content: str + + +class LLMUserGame(AbstractGame): + """LLM-based user simulator environment. + + The agent (being trained) plays the role of an assistant/agent, + while an LLM simulates the user providing requests and feedback. 
+ + Attributes: + simulator_model: Model name for the user simulator (e.g., "gpt-4o-mini") + task_prompt: System prompt defining the user's persona and task + max_turns: Maximum conversation turns before episode ends + success_keywords: Keywords that indicate task success + """ + + # Reward constants + REWARD_SUCCESS = 10.0 + REWARD_FAILURE = -1.0 + REWARD_STEP = -0.01 + REWARD_USER_SATISFIED = 5.0 + + # Default max turns + DEFAULT_MAX_TURNS = 10 + + # LLM Judge prompt + JUDGE_PROMPT = """You are an expert evaluator assessing conversational AI quality. + +Evaluate the following conversation between an AI assistant and a user. + +## Evaluation Criteria (score 1-10 each): + +1. **Helpfulness**: Did the assistant understand and address the user's needs? +2. **Clarity**: Were the responses clear and easy to understand? +3. **Problem Resolution**: Was the user's issue/request successfully resolved? +4. **Professionalism**: Was the tone appropriate and professional? +5. **Efficiency**: Did the assistant resolve the issue without unnecessary back-and-forth? + +## Conversation: +{conversation} + +## Your Evaluation: +Provide scores and a brief explanation in this exact JSON format: +```json +{ + "helpfulness": <1-10>, + "clarity": <1-10>, + "problem_resolution": <1-10>, + "professionalism": <1-10>, + "efficiency": <1-10>, + "overall_score": <1-10>, + "success": , + "explanation": "" +} +``` +""" + + # Default task prompts for various scenarios + TASK_PROMPTS = { + "customer_service": """You are a customer contacting customer service. +You have a specific problem that needs to be resolved. +Be realistic - ask clarifying questions, express frustration if not helped properly. +When your issue is resolved satisfactorily, say "Thank you, that resolves my issue." +If the agent is unhelpful after several attempts, say "This is not helpful, goodbye." +""", + "booking_assistant": """You are a user trying to book a reservation. 
+You have specific preferences (date, time, number of people, etc.). +Ask questions about availability and options. +When you successfully complete a booking, say "Great, the booking is confirmed." +If unable to book, say "I'll try elsewhere, thanks." +""", + "tech_support": """You are a user with a technical problem. +Describe your issue and provide details when asked. +If the solution works, say "That fixed it, thank you!" +If multiple attempts fail, say "This still doesn't work." +""", + "information_seeking": """You are a user looking for specific information. +Ask questions to get the information you need. +When you get satisfactory answers, say "That's exactly what I needed, thanks!" +If answers are unclear or wrong, say "That's not what I was looking for." +""", + } + + def __init__( + self, + simulator_model: str = "gpt-4o-mini", + simulator_api_key: Optional[str] = None, + simulator_base_url: Optional[str] = None, + task_prompt: Optional[str] = None, + task_type: str = "customer_service", + max_turns: int = DEFAULT_MAX_TURNS, + success_keywords: Optional[List[str]] = None, + failure_keywords: Optional[List[str]] = None, + temperature: float = 0.7, + seed: Optional[int] = None, + use_llm_judge: bool = True, + judge_model: Optional[str] = None, + ): + """Initialize LLM User Simulator. 
+ + Args: + simulator_model: Model name for user simulation + simulator_api_key: API key (defaults to env var) + simulator_base_url: Custom API base URL (for local models) + task_prompt: Custom system prompt for user persona + task_type: Predefined task type if task_prompt not provided + max_turns: Maximum conversation turns + success_keywords: Phrases indicating success (fallback if LLM judge disabled) + failure_keywords: Phrases indicating failure (fallback if LLM judge disabled) + temperature: Sampling temperature for user LLM + seed: Random seed for reproducibility + use_llm_judge: Use LLM-as-a-Judge for evaluation (recommended) + judge_model: Model for judging (defaults to simulator_model) + """ + self.simulator_model = simulator_model + self.max_turns = max_turns + self.temperature = temperature + + # Set API key + self.api_key = simulator_api_key or os.environ.get("OPENAI_API_KEY") + self.base_url = simulator_base_url + + # Initialize LLM client + self._init_llm_client() + + # Set task prompt + if task_prompt: + self.task_prompt = task_prompt + else: + self.task_prompt = self.TASK_PROMPTS.get( + task_type, self.TASK_PROMPTS["customer_service"] + ) + + # LLM-as-a-Judge settings + self.use_llm_judge = use_llm_judge + self.judge_model = judge_model or simulator_model + + # Success/failure detection (fallback when LLM judge is disabled) + self.success_keywords = success_keywords or [ + "thank you", + "that resolves", + "that fixed it", + "exactly what I needed", + "booking is confirmed", + "issue is resolved", + "problem solved", + ] + self.failure_keywords = failure_keywords or [ + "not helpful", + "goodbye", + "doesn't work", + "not what I was looking for", + "try elsewhere", + "give up", + "frustrated", + ] + + # Judge evaluation result (populated at end of episode) + self._judge_result: Optional[Dict[str, Any]] = None + + # Game state + self._conversation: List[ConversationTurn] = [] + self._turn_count = 0 + self._done = False + self._success = False + 
self._current_task = "" + + if seed is not None: + random.seed(seed) + + def _init_llm_client(self): + """Initialize the LLM client for user simulation.""" + if not OPENAI_AVAILABLE: + raise ImportError( + "openai package not installed. Install with: pip install openai" + ) + + client_kwargs = {"api_key": self.api_key} + if self.base_url: + client_kwargs["base_url"] = self.base_url + + self._client = OpenAI(**client_kwargs) + + def _generate_user_response(self, agent_message: str) -> str: + """Generate user response using the simulator LLM.""" + # Build conversation history for the user LLM + messages = [ + { + "role": "system", + "content": self.task_prompt + "\n\n" + self._current_task, + } + ] + + # Add conversation history + for turn in self._conversation: + if turn.role == "agent": + # Agent messages appear as "assistant" to the user simulator + messages.append({"role": "user", "content": turn.content}) + else: + # User's own previous messages + messages.append({"role": "assistant", "content": turn.content}) + + # Add the latest agent message + messages.append({"role": "user", "content": agent_message}) + + # Generate user response + response = self._client.chat.completions.create( + model=self.simulator_model, + messages=messages, + temperature=self.temperature, + max_tokens=500, + ) + + return response.choices[0].message.content + + def _generate_initial_user_message(self) -> str: + """Generate the initial user message to start conversation.""" + messages = [ + { + "role": "system", + "content": self.task_prompt + "\n\n" + self._current_task, + }, + { + "role": "user", + "content": "Start the conversation by stating your request or problem.", + }, + ] + + response = self._client.chat.completions.create( + model=self.simulator_model, + messages=messages, + temperature=self.temperature, + max_tokens=300, + ) + + return response.choices[0].message.content + + def _check_success(self, text: str) -> bool: + """Check if conversation indicates success.""" + 
text_lower = text.lower() + return any(kw.lower() in text_lower for kw in self.success_keywords) + + def _check_failure(self, text: str) -> bool: + """Check if conversation indicates failure.""" + text_lower = text.lower() + return any(kw.lower() in text_lower for kw in self.failure_keywords) + + def _evaluate_with_llm_judge(self) -> Dict[str, Any]: + """Evaluate the conversation using LLM-as-a-Judge. + + Returns: + Dictionary with scores and evaluation details. + """ + import json + + # Format conversation for judge + conv_text = "" + for turn in self._conversation: + role_label = "Assistant" if turn.role == "agent" else "User" + conv_text += f"{role_label}: {turn.content}\n\n" + + prompt = self.JUDGE_PROMPT.format(conversation=conv_text) + + try: + response = self._client.chat.completions.create( + model=self.judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1, # Low temperature for consistent evaluation + max_tokens=500, + ) + + result_text = response.choices[0].message.content + + # Extract JSON from response + json_match = re.search(r"```json\s*(.*?)\s*```", result_text, re.DOTALL) + if json_match: + result = json.loads(json_match.group(1)) + else: + # Try to parse the whole response as JSON + result = json.loads(result_text) + + # Validate and normalize scores + for key in [ + "helpfulness", + "clarity", + "problem_resolution", + "professionalism", + "efficiency", + "overall_score", + ]: + if key in result: + result[key] = max(1, min(10, int(result[key]))) + + return result + + except Exception as e: + # Fallback to keyword-based evaluation + print(f"[LLM Judge] Evaluation failed: {e}, using fallback") + last_user_msg = self._conversation[-1].content if self._conversation else "" + success = self._check_success(last_user_msg) + return { + "helpfulness": 7 if success else 3, + "clarity": 5, + "problem_resolution": 8 if success else 2, + "professionalism": 5, + "efficiency": 5, + "overall_score": 7 if success else 3, + "success": 
success, + "explanation": "Fallback evaluation (LLM judge failed)", + } + + def _calculate_reward_from_judge(self, judge_result: Dict[str, Any]) -> float: + """Calculate reward from LLM judge evaluation. + + Maps the overall_score (1-10) to reward range. + """ + overall = judge_result.get("overall_score", 5) + success = judge_result.get("success", False) + + if success: + # Success: reward based on quality (5-10 range) + # Map overall_score 1-10 to reward 5-10 + return 5.0 + (overall / 10.0) * 5.0 + else: + # Failure: negative reward based on how bad + # Map overall_score 1-10 to reward -5 to 0 + return (overall / 10.0) * 5.0 - 5.0 + + def reset( + self, task_prompt: Optional[str] = None, seed: Optional[int] = None, **kwargs + ) -> str: + """Reset the game to start a new conversation. + + Args: + task_prompt: Override task prompt for this episode + seed: Random seed + **kwargs: Additional arguments + + Returns: + Initial observation (user's first message) + """ + if seed is not None: + random.seed(seed) + + # Update task prompt if provided + if task_prompt: + self._current_task = task_prompt + else: + # Generate a random specific task + self._current_task = self._generate_random_task() + + # Reset state + self._conversation = [] + self._turn_count = 0 + self._done = False + self._success = False + self._judge_result = None + + # Generate initial user message + initial_message = self._generate_initial_user_message() + self._conversation.append( + ConversationTurn(role="user", content=initial_message) + ) + + return self._format_observation(initial_message) + + def _generate_random_task(self) -> str: + """Generate a random specific task for variety.""" + tasks = [ + "Your flight was cancelled and you need to rebook.", + "You received a defective product and want a refund.", + "You need to change your hotel reservation dates.", + "Your internet connection is not working.", + "You want to upgrade your subscription plan.", + "You're looking for recommendations for a 
restaurant.", + "You need help resetting your password.", + "You want to cancel your order.", + ] + return random.choice(tasks) + + def _format_observation(self, message: str) -> str: + """Format observation for the agent.""" + obs = f"=== User Message ===\n{message}\n" + obs += f"\n=== Conversation Turn: {self._turn_count + 1}/{self.max_turns} ===" + return obs + + def step(self, action: str) -> StepResult: + """Execute agent action and get user response. + + Args: + action: Agent's response to the user + + Returns: + StepResult with user's response, reward, done flag, and info + """ + if self._done: + return StepResult( + observation="Conversation has ended.", + reward=0.0, + done=True, + info={"error": "conversation_ended"}, + ) + + self._turn_count += 1 + + # Parse agent's action + parsed_action = self._parse_action(action) + + # Add agent message to conversation + self._conversation.append(ConversationTurn(role="agent", content=parsed_action)) + + # Generate user response + user_response = self._generate_user_response(parsed_action) + self._conversation.append(ConversationTurn(role="user", content=user_response)) + + # Check for episode end conditions + episode_ended = False + end_reason = "" + + if self._check_success(user_response): + episode_ended = True + end_reason = "success_keyword" + elif self._check_failure(user_response): + episode_ended = True + end_reason = "failure_keyword" + elif self._turn_count >= self.max_turns: + episode_ended = True + end_reason = "timeout" + + # Calculate reward + if episode_ended: + self._done = True + + if self.use_llm_judge: + # Use LLM-as-a-Judge for evaluation + self._judge_result = self._evaluate_with_llm_judge() + reward = self._calculate_reward_from_judge(self._judge_result) + self._success = self._judge_result.get("success", False) + + # Add evaluation summary to response + judge_summary = ( + f"\n\n=== LLM Judge Evaluation ===\n" + f"Overall Score: {self._judge_result.get('overall_score', 'N/A')}/10\n" + 
f"Success: {self._success}\n"
+                    f"Explanation: {self._judge_result.get('explanation', 'N/A')}"
+                )
+                user_response = f"{user_response}{judge_summary}"
+            else:
+                # Fallback to keyword-based evaluation
+                if end_reason == "success_keyword":
+                    self._success = True
+                    reward = self.REWARD_SUCCESS
+                else:
+                    self._success = False
+                    reward = self.REWARD_FAILURE
+
+                # Add end reason prefix
+                if end_reason == "timeout":
+                    user_response = f"TIMEOUT: Maximum turns reached.\n\n{user_response}"
+                elif self._success:
+                    user_response = f"SUCCESS: {user_response}"
+                else:
+                    user_response = f"FAILURE: {user_response}"
+        else:
+            reward = self.REWARD_STEP
+
+        # Build info dict
+        info = {
+            "turn": self._turn_count,
+            "success": self._success,
+            "agent_message": parsed_action,
+            "user_message": user_response,
+        }
+
+        # Add judge evaluation if available
+        if self._judge_result:
+            info["judge_evaluation"] = self._judge_result
+
+        return StepResult(
+            observation=self._format_observation(user_response),
+            reward=reward,
+            done=self._done,
+            info=info,
+        )
+
+    def _parse_action(self, raw_action: str) -> str:
+        """Parse action from LLM output."""
+        # Try to extract from <response> tags
+        match = re.search(
+            r"<response>\s*(.*?)\s*</response>", raw_action, re.IGNORECASE | re.DOTALL
+        )
+        if match:
+            return match.group(1).strip()
+
+        # Otherwise use the whole output
+        return raw_action.strip()
+
+    def get_system_prompt(self) -> str:
+        """Return the system prompt for the agent."""
+        return (
+            "You are a helpful assistant engaging in a conversation with a user.\n"
+            "Your goal is to understand the user's needs and help them effectively.\n\n"
+            "IMPORTANT: Respond naturally and helpfully. Be concise but thorough.\n"
+            "If you need more information, ask clarifying questions.\n"
+            "If you can help, provide clear solutions or information.\n\n"
+            "Wrap your response in <response> tags.\n\n"
+            "Example:\n"
+            "<response>I understand you're having trouble with your order. "
+            "Could you please provide your order number so I can look into this?</response>"
+ ) + + def get_initial_user_message(self) -> str: + """Return context for the agent.""" + return "You are helping a user. Respond to their message." + + def get_state(self) -> Dict[str, Any]: + """Return current game state.""" + state = { + "turn_count": self._turn_count, + "max_turns": self.max_turns, + "done": self._done, + "success": self._success, + "conversation_length": len(self._conversation), + "use_llm_judge": self.use_llm_judge, + } + if self._judge_result: + state["judge_evaluation"] = self._judge_result + return state + + def generate_initial_state(self) -> Dict[str, Any]: + """Generate random initial state for training.""" + return { + "seed": random.randint(0, 1000000), + } + + def get_user_message_with_state(self, **kwargs) -> str: + """Generate user message with state for prompt.""" + self.reset(**kwargs) + initial_obs = self._format_observation(self._conversation[0].content) + return f"{initial_obs}\n\nRespond to the user." + + def get_interaction_name(self) -> str: + """Return interaction name.""" + return "llm_user_simulator" diff --git a/opentinker/environment/llm_user_simulator/llm_user_server.py b/opentinker/environment/llm_user_simulator/llm_user_server.py new file mode 100644 index 0000000..0fdb687 --- /dev/null +++ b/opentinker/environment/llm_user_simulator/llm_user_server.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +"""LLM User Simulator Server. + +This script starts an LLM user simulator server. 
+ +Usage: + python llm_user_server.py --port 8100 --shards 8 + + # With custom model: + python llm_user_server.py --port 8100 --simulator_model gpt-4o-mini +""" + +import argparse +import os +import subprocess +import sys +import time + + +def main(): + parser = argparse.ArgumentParser(description="LLM User Simulator Server") + parser.add_argument("--host", default="0.0.0.0", help="Server host") + parser.add_argument("--port", type=int, default=8100, help="Server port") + parser.add_argument( + "--shards", + type=int, + default=4, + help="Number of independent server processes on consecutive ports.", + ) + parser.add_argument( + "--simulator_model", + type=str, + default="gpt-4o-mini", + help="Model name for user simulation (e.g., gpt-4o-mini, gpt-4o)", + ) + parser.add_argument( + "--simulator_base_url", + type=str, + default=None, + help="Custom API base URL (for local models like vLLM)", + ) + parser.add_argument( + "--task_type", + type=str, + default="customer_service", + choices=[ + "customer_service", + "booking_assistant", + "tech_support", + "information_seeking", + ], + help="Type of user simulation task", + ) + parser.add_argument( + "--max_turns", + type=int, + default=10, + help="Maximum conversation turns per episode", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature for user simulator", + ) + parser.add_argument( + "--use_llm_judge", + action="store_true", + default=True, + help="Use LLM-as-a-Judge for evaluation (default: True)", + ) + parser.add_argument( + "--no_llm_judge", + action="store_true", + help="Disable LLM-as-a-Judge, use keyword matching instead", + ) + parser.add_argument( + "--judge_model", + type=str, + default=None, + help="Model for LLM judge (defaults to simulator_model)", + ) + args = parser.parse_args() + + # Handle --no_llm_judge flag + if args.no_llm_judge: + args.use_llm_judge = False + + from opentinker.environment.llm_user_simulator.llm_user_game import LLMUserGame + + 
print("\nLLM User Simulator Configuration:") + print(f" Simulator model: {args.simulator_model}") + print(f" Base URL: {args.simulator_base_url or 'default (OpenAI)'}") + print(f" Task type: {args.task_type}") + print(f" Max turns: {args.max_turns}") + print(f" Temperature: {args.temperature}") + print(f" Shards: {args.shards}") + print(f" LLM-as-a-Judge: {'enabled' if args.use_llm_judge else 'disabled'}") + if args.use_llm_judge: + print(f" Judge model: {args.judge_model or args.simulator_model}") + print("\nReward structure:") + if args.use_llm_judge: + print(" Using LLM Judge scoring (1-10 scale mapped to rewards)") + else: + print(f" Success: +{LLMUserGame.REWARD_SUCCESS}") + print(f" Failure: {LLMUserGame.REWARD_FAILURE}") + print(f" Step penalty: {LLMUserGame.REWARD_STEP}") + + # Sharded mode + if args.shards and args.shards > 1: + print( + f"\nStarting sharded mode: {args.shards} shards on ports {args.port}..{args.port + args.shards - 1}" + ) + + children: list[subprocess.Popen] = [] + try: + for i in range(args.shards): + port_i = args.port + i + cmd = [ + sys.executable, + os.path.abspath(__file__), + "--host", + args.host, + "--port", + str(port_i), + "--shards", + "1", + "--simulator_model", + args.simulator_model, + "--task_type", + args.task_type, + "--max_turns", + str(args.max_turns), + "--temperature", + str(args.temperature), + ] + if args.simulator_base_url: + cmd.extend(["--simulator_base_url", args.simulator_base_url]) + if not args.use_llm_judge: + cmd.append("--no_llm_judge") + if args.judge_model: + cmd.extend(["--judge_model", args.judge_model]) + + children.append(subprocess.Popen(cmd)) + time.sleep(0.1) + + print("Shards started. 
Press Ctrl+C to stop all shards.") + while True: + for p in children: + rc = p.poll() + if rc is not None: + raise RuntimeError( + f"Shard exited early: pid={p.pid}, code={rc}" + ) + time.sleep(1.0) + except KeyboardInterrupt: + pass + finally: + for p in children: + try: + p.terminate() + except Exception: + pass + for p in children: + try: + p.wait(timeout=5) + except Exception: + try: + p.kill() + except Exception: + pass + return + + # Single shard mode + from opentinker.environment.base_game_server import run_game_server + + run_game_server( + game_class=LLMUserGame, + host=args.host, + port=args.port, + stats_class=None, + simulator_model=args.simulator_model, + simulator_base_url=args.simulator_base_url, + task_type=args.task_type, + max_turns=args.max_turns, + temperature=args.temperature, + use_llm_judge=args.use_llm_judge, + judge_model=args.judge_model, + ) + + +if __name__ == "__main__": + main() diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 3680053..7bcbe46 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -1111,6 +1111,40 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: f"Job {job.job_id}: ✓ LoRA enabled: rank={lora_rank}, alpha={lora_alpha}, target_modules={target_modules}" ) + # Forward world model loss settings if specified + actor_config = job.config.get("actor", {}) + if actor_config.get("use_world_model_loss"): + cmd.append("actor_rollout_ref.actor.use_world_model_loss=true") + wm_coef = actor_config.get("world_model_loss_coef", 0.1) + cmd.append(f"actor_rollout_ref.actor.world_model_loss_coef={wm_coef}") + logger.info(f"Job {job.job_id}: ✓ World Model Loss enabled: coef={wm_coef}") + + # Forward WM active sampling settings if specified + if actor_config.get("wm_active_sampling"): + cmd.append("actor_rollout_ref.actor.wm_active_sampling=true") + wm_active_coef = actor_config.get("wm_active_sampling_coef", 0.5) + cmd.append( + 
f"actor_rollout_ref.actor.wm_active_sampling_coef={wm_active_coef}" + ) + logger.info( + f"Job {job.job_id}: ✓ WM Active Sampling enabled: coef={wm_active_coef}" + ) + + # Forward WM dynamic entropy settings if specified + # Use + prefix to add new config keys that may not exist in the base schema + wm_dynamic_entropy = actor_config.get("wm_dynamic_entropy", {}) + if wm_dynamic_entropy.get("enabled"): + cmd.append("+actor_rollout_ref.actor.wm_dynamic_entropy.enabled=true") + beta_0 = wm_dynamic_entropy.get("beta_0", 0.001) + beta_1 = wm_dynamic_entropy.get("beta_1", 0.01) + gamma = wm_dynamic_entropy.get("gamma", 1.0) + cmd.append(f"+actor_rollout_ref.actor.wm_dynamic_entropy.beta_0={beta_0}") + cmd.append(f"+actor_rollout_ref.actor.wm_dynamic_entropy.beta_1={beta_1}") + cmd.append(f"+actor_rollout_ref.actor.wm_dynamic_entropy.gamma={gamma}") + logger.info( + f"Job {job.job_id}: ✓ WM Dynamic Entropy enabled: beta_0={beta_0}, beta_1={beta_1}, gamma={gamma}" + ) + logger.info(f"Job {job.job_id}: Launching server with command: {' '.join(cmd)}") # Create log files for stdout and stderr with human-readable timestamp diff --git a/opentinker/scripts/launch_scheduler.sh b/opentinker/scripts/launch_scheduler.sh index 581e7cf..5dff4df 100755 --- a/opentinker/scripts/launch_scheduler.sh +++ b/opentinker/scripts/launch_scheduler.sh @@ -6,11 +6,19 @@ export CUDA_HOME=$HOME/local/cuda-12.8 export PATH=$CUDA_HOME/bin:$PATH export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH -export ROLLOUT_TRACE_DIR="/home/haofeiy2/OpenTinker/traces" +# DISABLED - causes disk space issues +# export ROLLOUT_TRACE_DIR="/home/haofeiy2/OpenTinker/traces" export NVCC_EXECUTABLE=$CUDA_HOME/bin/nvcc export TORCH_CUDA_ARCH_LIST="9.0" export FLASHINFER_HOMOGENEOUS_MS=1 +# Disable sleep mode to avoid cumem allocator CUDA errors (V1 required for async engine) +export VLLM_DISABLE_SLEEP_MODE=1 + +# Limit Ray object store to prevent disk space issues +# Default 200GB is too large and causes spilling to 
disk
+export RAY_object_store_memory=30000000000 # 30GB max
+
 # Default configuration
 AVAILABLE_GPUS="[0,1,2,3,4,5,6,7,8,9]"
 PORT_RANGE="null" # Set to null for auto-detection
diff --git a/opentinker/scripts/run_alfworld.sh b/opentinker/scripts/run_alfworld.sh
new file mode 100755
index 0000000..7f4168c
--- /dev/null
+++ b/opentinker/scripts/run_alfworld.sh
@@ -0,0 +1,131 @@
+#!/bin/bash
+# ALFWorld Training & Inference Script
+#
+# This script runs ALFWorld RL training or inference with OpenTinker.
+# You need to run these steps in SEPARATE terminals.
+#
+# For Training (3 terminals):
+#   Terminal 1: bash run_alfworld.sh scheduler
+#   Terminal 2: bash run_alfworld.sh env
+#   Terminal 3: bash run_alfworld.sh client
+#
+# For Inference/Evaluation (3 terminals):
+#   Terminal 1: bash run_alfworld.sh scheduler
+#   Terminal 2: bash run_alfworld.sh env-eval
+#   Terminal 3: bash run_alfworld.sh inference model_path=/path/to/checkpoint
+
+# =============================================================================
+# Configuration
+# =============================================================================
+SCHEDULER_PORT=8089
+ENV_PORT=8091
+GPUS='[0,1,2,3,4,5,6,7,8,9]'
+NUM_GPUS=4
+
+# Fix vLLM v1 cumem allocator issue (V1 is required for async engine)
+# Disable sleep mode to avoid cumem allocator CUDA errors
+export VLLM_DISABLE_SLEEP_MODE=1
+
+# Activate conda environment
+source ~/anaconda3/etc/profile.d/conda.sh
+conda activate opentinker
+
+# Change to OpenTinker directory
+cd /home/haofeiy2/OpenTinker
+
+# =============================================================================
+# Step Selection
+# =============================================================================
+case "$1" in
+    scheduler|1)
+        echo "========================================"
+        echo "Step 1: Starting Scheduler on port $SCHEDULER_PORT"
+        echo "========================================"
+        bash opentinker/scripts/launch_scheduler.sh \
+            --scheduler-port $SCHEDULER_PORT \
+            --gpus
"$GPUS" + ;; + + env|2) + echo "========================================" + echo "Step 2: Starting ALFWorld Environment Server on ports $ENV_PORT-$((ENV_PORT+7)) (8 shards)" + echo "========================================" + python opentinker/environment/alfworld/alfworld_server.py \ + --port $ENV_PORT \ + --shards 8 \ + --split train \ + --max_steps 50 + ;; + + env-eval) + echo "========================================" + echo "Step 2 (Eval): Starting ALFWorld Environment Server for Evaluation" + echo "========================================" + python opentinker/environment/alfworld/alfworld_server.py \ + --port $ENV_PORT \ + --shards 1 \ + --split eval_in_distribution \ + --max_steps 50 + ;; + + client|3) + echo "========================================" + echo "Step 3: Running ALFWorld RL Client" + echo "========================================" + python opentinker/client/alfworld_rl.py \ + tokenizer_path=Qwen/Qwen2.5-3B-Instruct \ + batch_size=16 \ + val_batch_size=32 \ + num_epochs=5 \ + save_freq=1000 \ + test_freq=100 \ + num_gpus=$NUM_GPUS \ + scheduler_url=http://0.0.0.0:$SCHEDULER_PORT \ + interaction.config.env_port=$ENV_PORT \ + interaction.config.env_host=0.0.0.0 \ + interaction.config.env_shards=8 + ;; + + inference|4) + echo "========================================" + echo "Step 3 (Inference): Running ALFWorld Evaluation" + echo "========================================" + # Pass remaining arguments (e.g., model_path=/path/to/checkpoint) + shift # Remove 'inference' from args + python opentinker/client/alfworld_inference.py \ + num_gpus=$NUM_GPUS \ + scheduler_url=http://0.0.0.0:$SCHEDULER_PORT \ + env_endpoint=http://0.0.0.0:$ENV_PORT \ + split=eval_in_distribution \ + "$@" + ;; + + *) + echo "ALFWorld Training & Inference Script" + echo "" + echo "Usage: $0 {scheduler|env|env-eval|client|inference}" + echo " $0 {1|2|3|4}" + echo "" + echo "=== For Training (3 terminals) ===" + echo " Terminal 1: $0 scheduler # Start scheduler (port 
$SCHEDULER_PORT)" + echo " Terminal 2: $0 env # Start environment server (train split)" + echo " Terminal 3: $0 client # Start RL training client" + echo "" + echo "=== For Inference/Evaluation (3 terminals) ===" + echo " Terminal 1: $0 scheduler # Start scheduler (port $SCHEDULER_PORT)" + echo " Terminal 2: $0 env-eval # Start environment server (eval split)" + echo " Terminal 3: $0 inference model_path=/path/to/checkpoint" + echo "" + echo "Inference options:" + echo " model_path=... # Path to trained checkpoint (REQUIRED)" + echo " max_samples=N # Limit evaluation samples" + echo " output_path=... # Save results to file" + echo " split=... # eval_in_distribution (default) or eval_out_of_distribution" + echo "" + echo "Configuration:" + echo " SCHEDULER_PORT=$SCHEDULER_PORT" + echo " ENV_PORT=$ENV_PORT" + echo " GPUS=$GPUS" + echo " NUM_GPUS=$NUM_GPUS" + ;; +esac diff --git a/opentinker/server/config/actor/actor.yaml b/opentinker/server/config/actor/actor.yaml index 43f576a..9f92a1b 100755 --- a/opentinker/server/config/actor/actor.yaml +++ b/opentinker/server/config/actor/actor.yaml @@ -92,6 +92,35 @@ ppo_epochs: 1 # Shuffle training data across PPO epochs shuffle: false +# World model loss: auxiliary SFT loss for predicting environment observations +# This helps the model learn a world model of the environment in multi-turn agentic tasks +use_world_model_loss: false + +# Coefficient for world model loss +world_model_loss_coef: 0.1 + +# WM Active Sampling: use WM uncertainty to weight policy gradient +# High uncertainty samples get higher advantage -> agent learns more from uncertain states +# Low uncertainty (redundant) samples get lower weight -> less gradient +# Uses OLD log_prob (lagged theta^-) for stable weights across PPO epochs +# Weights computed at MINI-BATCH level for consistency across gradient accumulation +# Auto-detects observation_mask format (bool mask vs index list) +wm_active_sampling: false + +# Alpha coefficient for WM active sampling: 
weight = exp(alpha * z_score) +# Higher alpha = stronger emphasis on uncertainty differences +wm_active_sampling_coef: 0.5 + +# Minimum weight for WM active sampling (allows down-weighting redundant samples) +wm_active_wmin: 0.5 + +# Maximum weight for WM active sampling (caps up-weighting uncertain samples) +wm_active_wmax: 2.0 + +# Only apply WM weights to positive advantages (avoid penalizing exploration) +# Useful when reward is sparse and early exploration often fails +wm_active_positive_only: false + # checkpoint configs checkpoint: # Target dataclass for this configuration diff --git a/opentinker/server/config/ppo_trainer.yaml b/opentinker/server/config/ppo_trainer.yaml index 53d1c53..ec65106 100755 --- a/opentinker/server/config/ppo_trainer.yaml +++ b/opentinker/server/config/ppo_trainer.yaml @@ -221,7 +221,7 @@ trainer: del_local_ckpt_after_load: False # Default local directory for saving checkpoints - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + default_local_dir: /mnt/disk3_from_server2/haofeiy2/opentinker_checkpoints/${trainer.project_name}/${trainer.experiment_name} # Maximum number of actor checkpoints to keep max_actor_ckpt_to_keep: null diff --git a/opentinker/server/generic_agent_loop.py b/opentinker/server/generic_agent_loop.py index 3930bbc..c5c64e0 100755 --- a/opentinker/server/generic_agent_loop.py +++ b/opentinker/server/generic_agent_loop.py @@ -117,6 +117,16 @@ def __init__( self.response_mask: list[int] = [] self.response_logprobs: list[float] = [] + # Observation mask for world model loss + # observation_mask=1 for environment observation tokens (used for world model SFT loss) + # observation_mask=0 for LLM-generated action tokens + self.observation_mask: list[int] = [] + + # Turn index for each token (used for turn-wise dynamic entropy coefficient) + # turn_ids[i] = which turn token i belongs to (0-indexed) + # This allows computing per-turn WM uncertainty and applying different entropy weights + 
self.turn_ids: list[int] = [] + # Turn tracking self.user_turns = 0 self.assistant_turns = 0 @@ -453,6 +463,13 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu # Ensure env_info exists for all samples (even if empty) for consistent DataProto.concat output.extra_fields["env_info"] = agent_data.extra_fields.get("env_info", []) output.extra_fields["turn_scores"] = agent_data.turn_scores + # Add observation_mask for world model loss (marks environment feedback tokens) + output.extra_fields["observation_mask"] = agent_data.observation_mask[ + : self.response_length + ] + # Add turn_ids for turn-wise dynamic entropy coefficient + # turn_ids[i] = which turn token i belongs to (0-indexed) + output.extra_fields["turn_ids"] = agent_data.turn_ids[: self.response_length] # Add any other extra fields (except the ones we already set) for key, value in agent_data.extra_fields.items(): if key not in output.extra_fields: @@ -503,6 +520,7 @@ async def _handle_generating_state( """Handle the generating state: generate LLM response. The generated tokens are marked with mask=1 (included in loss computation). + Turn IDs are recorded for turn-wise dynamic entropy coefficient. 
""" import time @@ -530,10 +548,17 @@ async def _handle_generating_state( agent_data.response_ids = [eos_token_id] agent_data.prompt_ids.append(eos_token_id) agent_data.response_mask.append(1) + agent_data.observation_mask.append(0) # EOS is LLM-generated + agent_data.turn_ids.append( + agent_data.assistant_turns + ) # Current turn return GenericAgentState.TERMINATED + # Current turn index (0-indexed, based on assistant turns) + current_turn = agent_data.assistant_turns + print( - f"[GenericAgentLoop DEBUG] _handle_generating_state START: request_id={agent_data.request_id}, prompt_len={len(agent_data.prompt_ids)}" + f"[GenericAgentLoop DEBUG] _handle_generating_state START: request_id={agent_data.request_id}, prompt_len={len(agent_data.prompt_ids)}, turn={current_turn}" ) start_time = time.time() with simple_timer("generate_sequences", agent_data.metrics): @@ -573,6 +598,13 @@ async def _handle_generating_state( agent_data.response_mask += [1] * len( agent_data.response_ids ) # mask=1 for LLM tokens + agent_data.observation_mask += [0] * len( + agent_data.response_ids + ) # observation_mask=0 for LLM-generated actions + + # Record turn ID for each token (used for turn-wise dynamic entropy coefficient) + # current_turn was captured BEFORE incrementing assistant_turns + agent_data.turn_ids += [current_turn] * len(agent_data.response_ids) if response_log_probs: agent_data.response_logprobs += response_log_probs @@ -672,9 +704,15 @@ async def _handle_interacting_state( return GenericAgentState.TERMINATED # Update prompt_ids and response_mask - # mask=0 for environment observation tokens (not included in loss) + # mask=0 for environment observation tokens (not included in policy loss) agent_data.prompt_ids += response_ids agent_data.response_mask += [0] * len(response_ids) + # observation_mask=1 for environment observation tokens (used for world model SFT loss) + agent_data.observation_mask += [1] * len(response_ids) + # turn_ids: observation belongs to the previous 
turn (action that caused this observation) + # Use (assistant_turns - 1) since assistant_turns was already incremented in _handle_generating_state + obs_turn = max(0, agent_data.assistant_turns - 1) + agent_data.turn_ids += [obs_turn] * len(response_ids) if agent_data.response_logprobs: # Pad logprobs with 0.0 for observation tokens diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index 809be9a..210d219 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -1120,6 +1120,41 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: config=self.config.algorithm, ) + # ===================================================================== + # WM-Guided Dynamic Entropy: Compute beta_token at trainer level + # This ensures stable weights across PPO epochs (using lagged θ^-) + # ===================================================================== + actor_config = self.config.actor_rollout_ref.actor + # Handle both dict and dataclass/OmegaConf access patterns + if hasattr(actor_config, "get"): + wm_dynamic_entropy_config = actor_config.get( + "wm_dynamic_entropy", {} + ) + elif hasattr(actor_config, "wm_dynamic_entropy"): + wm_dynamic_entropy_config = actor_config.wm_dynamic_entropy + else: + wm_dynamic_entropy_config = {} + + # Check if enabled + is_enabled = False + if isinstance(wm_dynamic_entropy_config, dict): + is_enabled = wm_dynamic_entropy_config.get("enabled", False) + elif hasattr(wm_dynamic_entropy_config, "enabled"): + is_enabled = wm_dynamic_entropy_config.enabled + elif hasattr(wm_dynamic_entropy_config, "get"): + is_enabled = wm_dynamic_entropy_config.get("enabled", False) + + print( + f"[WM Dynamic Entropy] DEBUG: is_enabled = {is_enabled}, config type = {type(wm_dynamic_entropy_config)}" + ) + + if is_enabled: + with marked_timer("wm_beta_token", timing_raw, color="magenta"): + batch, wm_metrics = self.trainer._compute_wm_beta_token( + batch, 
wm_dynamic_entropy_config + ) + metrics.update(wm_metrics) + # 10. Update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): diff --git a/opentinker/server/launch_http_server.py b/opentinker/server/launch_http_server.py index 6c2561d..928e767 100755 --- a/opentinker/server/launch_http_server.py +++ b/opentinker/server/launch_http_server.py @@ -22,6 +22,8 @@ def main(cfg): os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["HYDRA_FULL_ERROR"] = "1" + # Disable sleep mode to avoid cumem allocator issues (CUDA Error: invalid argument) + os.environ["VLLM_DISABLE_SLEEP_MODE"] = "1" from omegaconf import open_dict import logging @@ -123,7 +125,9 @@ def main(cfg): cfg.trainer.save_freq = 500 cfg.trainer.test_freq = 500 cfg.trainer.total_epochs = 15 - cfg.trainer.default_local_dir = "/workspace/verl/verl/ckpts" + cfg.trainer.default_local_dir = os.path.expanduser( + "/mnt/disk1_from_server2/haofeiy2/opentinker_checkpoints" + ) # --------------------------------------------------------- # Agent Loop Configuration @@ -138,8 +142,11 @@ def main(cfg): logger.info("Agent Loop Mode Enabled") logger.info("=" * 60) + # Async engine requires V1, so force it. VLLM_DISABLE_SLEEP_MODE=1 handles cumem issues. os.environ["VLLM_USE_V1"] = "1" - logger.info("Set VLLM_USE_V1=1 for async rollout") + logger.info( + "VLLM_USE_V1=1 for async rollout (sleep mode disabled to avoid cumem issues)" + ) # Increase Ray's memory threshold to avoid premature OOM kills # Default is 0.95 (95%), we increase to 0.98 (98%)