|
| 1 | +#!/usr/bin/python3 |
| 2 | +import sys |
| 3 | +import os |
| 4 | +import argparse |
| 5 | +import time |
| 6 | +import math |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch._dynamo |
| 11 | +torch._dynamo.config.recompile_limit = 32 |
| 12 | +import torch.nn |
| 13 | + |
| 14 | +from katago.train import modelconfigs |
| 15 | +from katago.train.model_pytorch import Model |
| 16 | +from katago.train.metrics_pytorch import Metrics |
| 17 | +from katago.train import data_processing_pytorch |
| 18 | + |
| 19 | + |
def main():
    """Entry point: parse arguments, build a fresh model, and run the benchmarks.

    Reports parameter counts (total, optionally per tensor, and per reg group),
    then times: (1) forward only, (2) forward/backward/optimizer-step with
    syncs between phases for attribution, and (3) the full step with no
    intermediate syncs for true throughput.
    """
    description = """
    Benchmark a fresh randomly-initialized model. Reports parameter counts by tensor,
    then measures forward-pass and forward+backward+optimizer-step timing.
    """

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('-model-kind', help='Model config name, e.g. b8c192nbt-fson-mish-rvglr-bnh', required=True)
    parser.add_argument('-optimizer', help='Optimizer to use', choices=['sgd', 'adam', 'muon'], default='sgd')
    parser.add_argument('-batch-size', help='Batch size', type=int, required=True)
    parser.add_argument('-data', help='Path to npz data file (e.g. ../python/testdata/benchmark_data_1024.npz)', required=True)
    parser.add_argument('-gpu', help='GPU device index', type=int, default=0)
    parser.add_argument('-pos-len', help='Board size', type=int, default=19)
    parser.add_argument('-num-iters', help='Number of benchmark iterations', type=int, default=20)
    parser.add_argument('-warmup-iters', help='Number of warmup iterations', type=int, default=5)
    parser.add_argument('-print-per-tensor-counts', help='Print parameter counts per tensor', action='store_true')
    parser.add_argument('-no-compile', help='Do not torch.compile', action='store_true')
    parser.add_argument('-use-tf32-matmul', help='Reduce float32 precision for speed on some gpus', action='store_true')
    args = vars(parser.parse_args())

    model_kind = args["model_kind"]
    optimizer_kind = args["optimizer"]
    batch_size = args["batch_size"]
    data_path = args["data"]
    gpu_idx = args["gpu"]
    pos_len = args["pos_len"]
    num_iters = args["num_iters"]
    warmup_iters = args["warmup_iters"]
    print_per_tensor = args["print_per_tensor_counts"]
    no_compile = args["no_compile"]
    use_tf32_matmul = args["use_tf32_matmul"]

    # Fail fast on inputs that would otherwise crash much later with obscure
    # errors (empty timing lists divide by zero; set_device fails without CUDA).
    if batch_size < 1:
        parser.error("-batch-size must be >= 1")
    if num_iters < 1:
        parser.error("-num-iters must be >= 1")
    if warmup_iters < 0:
        parser.error("-warmup-iters must be >= 0")
    if not torch.cuda.is_available():
        parser.error("CUDA is not available; this benchmark requires a CUDA-capable GPU")

    device = torch.device(f"cuda:{gpu_idx}")

    if use_tf32_matmul:
        # Allows TF32 on Ampere+ GPUs: faster matmuls at reduced fp32 precision.
        torch.set_float32_matmul_precision('high')
        print("float32 matmul precision: high (TF32)")
    else:
        print("float32 matmul precision: default")
    torch.cuda.set_device(device)

    # Load model config and create model
    assert model_kind in modelconfigs.config_of_name, f"Unknown model kind: {model_kind}, available: {list(modelconfigs.config_of_name.keys())}"
    model_config = modelconfigs.config_of_name[model_kind]
    print(f"Model kind: {model_kind}")
    print(f"Optimizer: {optimizer_kind}")
    print(f"Batch size: {batch_size}")
    print(f"Device: {device}")
    print()

    raw_model = Model(model_config, pos_len)
    raw_model.initialize()
    raw_model.to(device)

    # Keep a handle on the uncompiled model: parameter iteration, reg groups,
    # and postprocessing are done on raw_model even when `model` is compiled.
    if no_compile:
        print("torch.compile: disabled (-no-compile)")
        model = raw_model
    else:
        print("torch.compile: enabled (mode=default)")
        model = torch.compile(raw_model, mode="default")
    print()

    # Report parameter counts
    print("=" * 80)
    print("PARAMETER COUNTS")
    print("=" * 80)
    total_params = 0
    for name, param in raw_model.named_parameters():
        n = param.numel()
        total_params += n
        if print_per_tensor:
            print(f"  {n:>12,}  {str(list(param.shape)):>30s}  {name}")
    if print_per_tensor:
        print()
    print(f"  Total: {total_params:,} parameters")
    print()

    # Also report by reg group
    reg_dict = {}
    raw_model.add_reg_dict(reg_dict)
    print("Parameters by group:")
    for group_name in reg_dict:
        group_params = sum(p.numel() for p in reg_dict[group_name])
        if group_params > 0:
            print(f"  {group_name:>20s}: {group_params:>12,}")
    print()

    # Set up optimizer: one param group per reg group, so Muon can be applied
    # only to the groups suitable for it (hidden-layer weight matrices).
    param_groups = []
    for group_name in reg_dict:
        if len(reg_dict[group_name]) > 0:
            is_muon_suitable = group_name in ("normal", "normal_attn", "normal_gab", "gab_mlp")
            param_groups.append({
                "params": reg_dict[group_name],
                "group_name": group_name,
                "lr": 1e-5,
                "weight_decay": 0.01,
                "use_muon": is_muon_suitable,
            })

    if optimizer_kind == "adam":
        optimizer = torch.optim.AdamW(param_groups, lr=1e-5)
    elif optimizer_kind == "muon":
        # Imported lazily so sgd/adam runs don't require the muon package.
        from muon.muon import SingleDeviceMuonWithAuxAdam
        optimizer = SingleDeviceMuonWithAuxAdam(param_groups, adjust_lr_fn="match_rms_adamw")
    else:
        optimizer = torch.optim.SGD(param_groups, lr=1e-5, momentum=0.9)

    metrics_obj = Metrics(batch_size, 1, raw_model)

    # Load data
    print(f"Loading data from {data_path} ...")
    batch = load_batch(data_path, batch_size, pos_len, model_config, device)
    print(f"Data loaded, batch size = {batch_size}")
    print()

    # Set model to training mode
    raw_model.train()

    # Benchmark forward only
    print("=" * 80)
    print("FORWARD PASS BENCHMARK")
    print("=" * 80)
    forward_times = benchmark_forward(model, batch, num_iters, warmup_iters)
    print_timing_stats("Forward", forward_times)
    print()

    # Benchmark forward + backward + optimizer step with attribution
    print("=" * 80)
    print("FORWARD + BACKWARD + OPTIMIZER STEP BENCHMARK")
    print("=" * 80)
    fwd_times, bwd_times, opt_times = benchmark_full_step(
        model, raw_model, optimizer, metrics_obj, batch, model_config, num_iters, warmup_iters,
    )
    print_timing_stats("Forward ", fwd_times)
    print_timing_stats("Backward", bwd_times)
    print_timing_stats("Opt step", opt_times)
    total_times = [f + b + o for f, b, o in zip(fwd_times, bwd_times, opt_times)]
    print_timing_stats("Total   ", total_times)
    print()

    # Print proportions
    mean_fwd = sum(fwd_times) / len(fwd_times)
    mean_bwd = sum(bwd_times) / len(bwd_times)
    mean_opt = sum(opt_times) / len(opt_times)
    mean_total = mean_fwd + mean_bwd + mean_opt
    print("  Time attribution (with sync between phases):")
    print(f"    Forward:  {mean_fwd*1000:8.2f} ms ({100*mean_fwd/mean_total:5.1f}%)")
    print(f"    Backward: {mean_bwd*1000:8.2f} ms ({100*mean_bwd/mean_total:5.1f}%)")
    print(f"    Opt step: {mean_opt*1000:8.2f} ms ({100*mean_opt/mean_total:5.1f}%)")
    print(f"    Total:    {mean_total*1000:8.2f} ms")
    print()

    # Benchmark true throughput without intermediate syncs
    print("=" * 80)
    print("FULL STEP THROUGHPUT (no intermediate sync)")
    print("=" * 80)
    throughput_times = benchmark_full_step_throughput(
        model, raw_model, optimizer, metrics_obj, batch, model_config, num_iters, warmup_iters,
    )
    print_timing_stats("Total   ", throughput_times)
    print()
| 183 | + |
def load_batch(data_path, batch_size, pos_len, model_config, device):
    """Load one fixed batch of training data from an npz file onto the device.

    Unpacks the bit-packed binary input planes, applies the model's history
    matrices, and returns a dict of device tensors keyed as the metrics code
    expects. Only the first `batch_size` rows of each array are used.
    """
    expected_bin_channels = modelconfigs.get_num_bin_input_features(model_config)
    expected_global_channels = modelconfigs.get_num_global_input_features(model_config)
    wants_qvalues = model_config["version"] >= 16

    def to_device(arr):
        # numpy array -> torch tensor on the benchmark device
        return torch.from_numpy(arr).to(device)

    with np.load(data_path) as npz:
        packed_bin = npz["binaryInputNCHWPacked"][:batch_size]
        global_input = npz["globalInputNC"][:batch_size]
        policy_targets = npz["policyTargetsNCMove"][:batch_size].astype(np.float32)
        global_targets = npz["globalTargetsNC"][:batch_size]
        score_distr = npz["scoreDistrN"][:batch_size].astype(np.float32)
        value_targets = npz["valueTargetsNCHW"][:batch_size].astype(np.float32)
        qvalue_targets = None
        if wants_qvalues and "qValueTargetsNCMove" in npz:
            qvalue_targets = npz["qValueTargetsNCMove"][:batch_size].astype(np.float32)

    # Binary planes are bit-packed along the spatial axis: unpack, drop the
    # padding bits beyond pos_len*pos_len, then reshape to NCHW float32.
    bin_input = np.unpackbits(packed_bin, axis=2)
    assert bin_input.shape[2] == ((pos_len * pos_len + 7) // 8) * 8
    bin_input = bin_input[:, :, :pos_len * pos_len]
    num_rows, num_channels = bin_input.shape[0], bin_input.shape[1]
    bin_input = np.reshape(bin_input, (num_rows, num_channels, pos_len, pos_len)).astype(np.float32)

    assert bin_input.shape[1] == expected_bin_channels
    assert global_input.shape[1] == expected_global_channels

    (h_base, h_builder) = data_processing_pytorch.build_history_matrices(model_config, device)

    bin_input_t = to_device(bin_input)
    global_input_t = to_device(global_input)
    global_targets_t = to_device(global_targets)

    (bin_input_t, global_input_t) = data_processing_pytorch.apply_history_matrices(
        model_config, bin_input_t, global_input_t, global_targets_t, h_base, h_builder
    )

    batch = {
        "binaryInputNCHW": bin_input_t.contiguous(),
        "globalInputNC": global_input_t,
        "policyTargetsNCMove": to_device(policy_targets),
        "globalTargetsNC": global_targets_t,
        "scoreDistrN": to_device(score_distr),
        "valueTargetsNCHW": to_device(value_targets),
    }
    if qvalue_targets is not None:
        batch["qValueTargetsNCMove"] = to_device(qvalue_targets)
    return batch
| 233 | + |
| 234 | + |
def benchmark_forward(model, batch, num_iters, warmup_iters):
    """Time inference-only forward passes; return per-iteration seconds.

    Runs warmup_iters untimed iterations first (to absorb compilation and
    autotuning cost), then num_iters timed ones. GPU is synchronized around
    each iteration so wall-clock times reflect actual device work.
    """
    bin_input = batch["binaryInputNCHW"]
    global_input = batch["globalInputNC"]
    samples = []
    for step in range(warmup_iters + num_iters):
        torch.cuda.synchronize()
        start = time.perf_counter()

        with torch.no_grad():
            model(bin_input, global_input)

        torch.cuda.synchronize()
        if step >= warmup_iters:
            samples.append(time.perf_counter() - start)
    return samples
| 254 | + |
| 255 | + |
def benchmark_full_step(model, raw_model, optimizer, metrics_obj, batch, model_config, num_iters, warmup_iters):
    """Benchmark forward + backward + optimizer step, returning separate timings.

    Synchronizes the GPU between phases so each phase's wall time can be
    attributed separately; the syncs add overhead, so these totals overstate
    the cost of a real (unsynchronized) training step — see
    benchmark_full_step_throughput for the sync-free number. Returns three
    parallel lists of per-iteration seconds (forward, backward, optimizer),
    covering only the non-warmup iterations. `model_config` is accepted for
    signature parity with the throughput variant but is not read here.
    """
    fwd_times = []
    bwd_times = []
    opt_times = []

    for i in range(warmup_iters + num_iters):
        optimizer.zero_grad(set_to_none=True)

        # Forward (includes loss/metrics computation, which is part of the graph)
        torch.cuda.synchronize()
        t_fwd_start = time.perf_counter()

        model_outputs = model(
            batch["binaryInputNCHW"],
            batch["globalInputNC"],
        )
        # Postprocess on the uncompiled model; loss-scale hyperparameters are
        # hard-coded for benchmarking purposes only.
        postprocessed = raw_model.postprocess_output(model_outputs)
        metrics = metrics_obj.metrics_dict_batchwise(
            raw_model,
            postprocessed,
            extra_outputs=None,
            batch=batch,
            is_training=True,
            soft_policy_weight_scale=1.0,
            disable_optimistic_policy=False,
            meta_kata_only_soft_policy=False,
            value_loss_scale=1.0,
            td_value_loss_scales=[0.4, 1.0, 1.0],
            seki_loss_scale=0.35,
            variance_time_loss_scale=0.5,
            main_loss_scale=1.0,
            intermediate_loss_scale=0.5 if raw_model.get_has_intermediate_head() else None,
        )
        loss = metrics["loss_sum"]

        torch.cuda.synchronize()
        t_bwd_start = time.perf_counter()

        # Backward
        loss.backward()

        torch.cuda.synchronize()
        t_opt_start = time.perf_counter()

        # Optimizer step
        optimizer.step()

        torch.cuda.synchronize()
        t_opt_end = time.perf_counter()

        # Discard warmup iterations (compilation, autotuning, cache warm-up).
        if i >= warmup_iters:
            fwd_times.append(t_bwd_start - t_fwd_start)
            bwd_times.append(t_opt_start - t_bwd_start)
            opt_times.append(t_opt_end - t_opt_start)

    return fwd_times, bwd_times, opt_times
| 313 | + |
| 314 | + |
def benchmark_full_step_throughput(model, raw_model, optimizer, metrics_obj, batch, model_config, num_iters, warmup_iters):
    """Time complete training steps end-to-end, syncing only at step boundaries.

    Unlike benchmark_full_step, no synchronization is inserted between the
    forward, backward, and optimizer phases, so the GPU is free to overlap
    work — this is the truest steady-state throughput number. Returns per-step
    seconds for the non-warmup iterations. `model_config` is accepted for
    signature parity with benchmark_full_step but is not read here.
    """
    bin_input = batch["binaryInputNCHW"]
    global_input = batch["globalInputNC"]
    step_times = []

    for step in range(warmup_iters + num_iters):
        torch.cuda.synchronize()
        start = time.perf_counter()

        optimizer.zero_grad(set_to_none=True)
        outputs = model(bin_input, global_input)
        postprocessed = raw_model.postprocess_output(outputs)
        # Loss-scale hyperparameters are hard-coded for benchmarking only.
        metrics = metrics_obj.metrics_dict_batchwise(
            raw_model,
            postprocessed,
            extra_outputs=None,
            batch=batch,
            is_training=True,
            soft_policy_weight_scale=1.0,
            disable_optimistic_policy=False,
            meta_kata_only_soft_policy=False,
            value_loss_scale=1.0,
            td_value_loss_scales=[0.4, 1.0, 1.0],
            seki_loss_scale=0.35,
            variance_time_loss_scale=0.5,
            main_loss_scale=1.0,
            intermediate_loss_scale=0.5 if raw_model.get_has_intermediate_head() else None,
        )
        metrics["loss_sum"].backward()
        optimizer.step()

        torch.cuda.synchronize()
        if step >= warmup_iters:
            step_times.append(time.perf_counter() - start)

    return step_times
| 356 | + |
| 357 | + |
def print_timing_stats(label, times):
    """Print mean/std/min/max of a list of per-iteration times (in seconds).

    Uses the population standard deviation (divides by N, not N-1), which is
    appropriate here since the iterations ARE the whole population measured.
    Prints a notice and returns early when `times` is empty instead of
    raising ZeroDivisionError (e.g. if a caller passes zero iterations).
    """
    if not times:
        print(f"  {label}: no samples")
        return
    mean = sum(times) / len(times)
    std = math.sqrt(sum((t - mean) ** 2 for t in times) / len(times))
    lo = min(times)
    hi = max(times)
    print(f"  {label}: {mean*1000:8.2f} ms (std {std*1000:6.2f} ms, min {lo*1000:8.2f} ms, max {hi*1000:8.2f} ms)")
| 364 | + |
| 365 | + |
# Script entry point: run the benchmark only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
0 commit comments