Skip to content

Commit 42de6fe

Browse files
committed
Adjust muon+adam params, add transformers, add gab, benchmark script, torch compile, autocast and new torch version fixes
1 parent 8350b4e commit 42de6fe

21 files changed

Lines changed: 1613 additions & 58 deletions

LICENSE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ and/or files, see the individual readmes and/or license files for each one withi
55
subdirectories within cpp/external. Additionally, cpp/core/sha2.cpp derives from another piece of
66
external code and embeds its own license within that file.
77

8+
Some parts of python/katago/model_pytorch.py and a few other files, where noted, are modifications
9+
of code from other open source authors.
10+
811
Aside from the above, the license for all OTHER content in this repo is as follows:
912

1013
----------------------------------------

python/benchmark_fresh_model.py

Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
#!/usr/bin/python3
2+
import sys
3+
import os
4+
import argparse
5+
import time
6+
import math
7+
import numpy as np
8+
9+
import torch
10+
import torch._dynamo
11+
torch._dynamo.config.recompile_limit = 32
12+
import torch.nn
13+
14+
from katago.train import modelconfigs
15+
from katago.train.model_pytorch import Model
16+
from katago.train.metrics_pytorch import Metrics
17+
from katago.train import data_processing_pytorch
18+
19+
20+
def main():
    """Entry point: build a fresh model from a named config, report parameter
    counts, then benchmark forward-only and full training-step timing on GPU.

    All configuration comes from the command line; see the argparse options
    below. Requires a CUDA device.
    """
    description = """
    Benchmark a fresh randomly-initialized model. Reports parameter counts by tensor,
    then measures forward-pass and forward+backward+optimizer-step timing.
    """

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('-model-kind', help='Model config name, e.g. b8c192nbt-fson-mish-rvglr-bnh', required=True)
    parser.add_argument('-optimizer', help='Optimizer to use', choices=['sgd', 'adam', 'muon'], default='sgd')
    parser.add_argument('-batch-size', help='Batch size', type=int, required=True)
    parser.add_argument('-data', help='Path to npz data file (e.g. ../python/testdata/benchmark_data_1024.npz)', required=True)
    parser.add_argument('-gpu', help='GPU device index', type=int, default=0)
    parser.add_argument('-pos-len', help='Board size', type=int, default=19)
    parser.add_argument('-num-iters', help='Number of benchmark iterations', type=int, default=20)
    parser.add_argument('-warmup-iters', help='Number of warmup iterations', type=int, default=5)
    parser.add_argument('-print-per-tensor-counts', help='Print parameter counts per tensor', action='store_true')
    parser.add_argument('-no-compile', help='Do not torch.compile', action='store_true')
    parser.add_argument('-use-tf32-matmul', help='Reduce float32 precision for speed on some gpus', action='store_true')
    args = vars(parser.parse_args())

    model_kind = args["model_kind"]
    optimizer_kind = args["optimizer"]
    batch_size = args["batch_size"]
    data_path = args["data"]
    gpu_idx = args["gpu"]
    pos_len = args["pos_len"]
    num_iters = args["num_iters"]
    warmup_iters = args["warmup_iters"]
    print_per_tensor = args["print_per_tensor_counts"]
    no_compile = args["no_compile"]
    use_tf32_matmul = args["use_tf32_matmul"]

    # This script is CUDA-only; no CPU fallback is attempted.
    device = torch.device(f"cuda:{gpu_idx}")

    if use_tf32_matmul:
        # Allow TF32 for float32 matmuls (faster on Ampere+ at reduced precision).
        torch.set_float32_matmul_precision('high')
        print("float32 matmul precision: high (TF32)")
    else:
        print("float32 matmul precision: default")
    torch.cuda.set_device(device)

    # Load model config and create model
    assert model_kind in modelconfigs.config_of_name, f"Unknown model kind: {model_kind}, available: {list(modelconfigs.config_of_name.keys())}"
    model_config = modelconfigs.config_of_name[model_kind]
    print(f"Model kind: {model_kind}")
    print(f"Optimizer: {optimizer_kind}")
    print(f"Batch size: {batch_size}")
    print(f"Device: {device}")
    print()

    # Fresh random initialization; no checkpoint is loaded.
    raw_model = Model(model_config, pos_len)
    raw_model.initialize()
    raw_model.to(device)

    if no_compile:
        print("torch.compile: disabled (-no-compile)")
        model = raw_model
    else:
        # Keep a handle to raw_model too: postprocess_output and the reg dict
        # are accessed on the uncompiled module below.
        print("torch.compile: enabled (mode=default)")
        model = torch.compile(raw_model, mode="default")
    print()

    # Report parameter counts
    print("=" * 80)
    print("PARAMETER COUNTS")
    print("=" * 80)
    total_params = 0
    for name, param in raw_model.named_parameters():
        n = param.numel()
        total_params += n
        if print_per_tensor:
            print(f"  {n:>12,}  {str(list(param.shape)):>30s}  {name}")
    if print_per_tensor:
        print()
    print(f"  Total: {total_params:,} parameters")
    print()

    # Also report by reg group
    # NOTE(review): add_reg_dict presumably partitions parameters into named
    # regularization groups ("normal", "output", ...) -- confirm in model_pytorch.
    reg_dict = {}
    raw_model.add_reg_dict(reg_dict)
    print("Parameters by group:")
    for group_name in reg_dict:
        group_params = sum(p.numel() for p in reg_dict[group_name])
        if group_params > 0:
            print(f"  {group_name:>20s}: {group_params:>12,}")
    print()

    # Set up optimizer. Extra keys like "group_name" and "use_muon" are
    # carried in the param group dicts; torch optimizers preserve unknown
    # keys, and the muon optimizer reads "use_muon" to route params.
    param_groups = []
    for group_name in reg_dict:
        if len(reg_dict[group_name]) > 0:
            is_muon_suitable = group_name in ("normal", "normal_attn", "normal_gab", "gab_mlp")
            param_groups.append({
                "params": reg_dict[group_name],
                "group_name": group_name,
                "lr": 1e-5,
                "weight_decay": 0.01,
                "use_muon": is_muon_suitable,
            })

    if optimizer_kind == "adam":
        optimizer = torch.optim.AdamW(param_groups, lr=1e-5)
    elif optimizer_kind == "muon":
        # Imported lazily so the muon package is only required when requested.
        from muon.muon import SingleDeviceMuonWithAuxAdam
        optimizer = SingleDeviceMuonWithAuxAdam(param_groups, adjust_lr_fn="match_rms_adamw")
    else:
        optimizer = torch.optim.SGD(param_groups, lr=1e-5, momentum=0.9)

    # World size 1: benchmarking is single-device only.
    metrics_obj = Metrics(batch_size, 1, raw_model)

    # Load data: one fixed batch, reused for every iteration.
    print(f"Loading data from {data_path} ...")
    batch = load_batch(data_path, batch_size, pos_len, model_config, device)
    print(f"Data loaded, batch size = {batch_size}")
    print()

    # Set model to training mode
    raw_model.train()

    # Benchmark forward only
    print("=" * 80)
    print("FORWARD PASS BENCHMARK")
    print("=" * 80)
    forward_times = benchmark_forward(model, batch, num_iters, warmup_iters)
    print_timing_stats("Forward", forward_times)
    print()

    # Benchmark forward + backward + optimizer step with attribution
    print("=" * 80)
    print("FORWARD + BACKWARD + OPTIMIZER STEP BENCHMARK")
    print("=" * 80)
    fwd_times, bwd_times, opt_times = benchmark_full_step(
        model, raw_model, optimizer, metrics_obj, batch, model_config, num_iters, warmup_iters,
    )
    print_timing_stats("Forward ", fwd_times)
    print_timing_stats("Backward", bwd_times)
    print_timing_stats("Opt step", opt_times)
    total_times = [f + b + o for f, b, o in zip(fwd_times, bwd_times, opt_times)]
    print_timing_stats("Total   ", total_times)
    print()

    # Print proportions
    mean_fwd = sum(fwd_times) / len(fwd_times)
    mean_bwd = sum(bwd_times) / len(bwd_times)
    mean_opt = sum(opt_times) / len(opt_times)
    mean_total = mean_fwd + mean_bwd + mean_opt
    print(f"  Time attribution (with sync between phases):")
    print(f"    Forward:  {mean_fwd*1000:8.2f} ms ({100*mean_fwd/mean_total:5.1f}%)")
    print(f"    Backward: {mean_bwd*1000:8.2f} ms ({100*mean_bwd/mean_total:5.1f}%)")
    print(f"    Opt step: {mean_opt*1000:8.2f} ms ({100*mean_opt/mean_total:5.1f}%)")
    print(f"    Total:    {mean_total*1000:8.2f} ms")
    print()

    # Benchmark true throughput without intermediate syncs (the per-phase
    # syncs above inflate total time; this run measures the realistic rate).
    print("=" * 80)
    print("FULL STEP THROUGHPUT (no intermediate sync)")
    print("=" * 80)
    throughput_times = benchmark_full_step_throughput(
        model, raw_model, optimizer, metrics_obj, batch, model_config, num_iters, warmup_iters,
    )
    print_timing_stats("Total   ", throughput_times)
    print()
def load_batch(data_path, batch_size, pos_len, model_config, device):
    """Load a single batch from an npz file.

    Reads the first `batch_size` rows of each array in the npz, unpacks the
    bit-packed binary board inputs to float32 NCHW, applies the model's
    history-matrix input transform, and returns a dict of device tensors
    keyed by the names the metrics code expects.
    """
    num_bin_features = modelconfigs.get_num_bin_input_features(model_config)
    num_global_features = modelconfigs.get_num_global_input_features(model_config)
    # Q-value targets only exist for newer data/model versions.
    include_qvalues = model_config["version"] >= 16

    with np.load(data_path) as npz:
        binaryInputNCHWPacked = npz["binaryInputNCHWPacked"][:batch_size]
        globalInputNC = npz["globalInputNC"][:batch_size]
        policyTargetsNCMove = npz["policyTargetsNCMove"][:batch_size].astype(np.float32)
        globalTargetsNC = npz["globalTargetsNC"][:batch_size]
        scoreDistrN = npz["scoreDistrN"][:batch_size].astype(np.float32)
        valueTargetsNCHW = npz["valueTargetsNCHW"][:batch_size].astype(np.float32)
        if include_qvalues and "qValueTargetsNCMove" in npz:
            qValueTargetsNCMove = npz["qValueTargetsNCMove"][:batch_size].astype(np.float32)
        else:
            qValueTargetsNCMove = None

    # The board plane is stored bit-packed along axis 2, padded to a whole
    # number of bytes; unpack and trim back to pos_len*pos_len cells.
    binaryInputNCHW = np.unpackbits(binaryInputNCHWPacked, axis=2)
    assert binaryInputNCHW.shape[2] == ((pos_len * pos_len + 7) // 8) * 8
    binaryInputNCHW = binaryInputNCHW[:, :, :pos_len * pos_len]
    binaryInputNCHW = np.reshape(binaryInputNCHW, (
        binaryInputNCHW.shape[0], binaryInputNCHW.shape[1], pos_len, pos_len
    )).astype(np.float32)

    # Sanity-check feature counts against the model config.
    assert binaryInputNCHW.shape[1] == num_bin_features
    assert globalInputNC.shape[1] == num_global_features

    # NOTE(review): the history matrices presumably randomize/mask history
    # input channels the same way training does -- confirm in
    # data_processing_pytorch.
    (h_base, h_builder) = data_processing_pytorch.build_history_matrices(model_config, device)

    batch_binaryInputNCHW = torch.from_numpy(binaryInputNCHW).to(device)
    batch_globalInputNC = torch.from_numpy(globalInputNC).to(device)
    batch_globalTargetsNC = torch.from_numpy(globalTargetsNC).to(device)

    (batch_binaryInputNCHW, batch_globalInputNC) = data_processing_pytorch.apply_history_matrices(
        model_config, batch_binaryInputNCHW, batch_globalInputNC, batch_globalTargetsNC, h_base, h_builder
    )

    batch = dict(
        binaryInputNCHW=batch_binaryInputNCHW.contiguous(),
        globalInputNC=batch_globalInputNC,
        policyTargetsNCMove=torch.from_numpy(policyTargetsNCMove).to(device),
        globalTargetsNC=batch_globalTargetsNC,
        scoreDistrN=torch.from_numpy(scoreDistrN).to(device),
        valueTargetsNCHW=torch.from_numpy(valueTargetsNCHW).to(device),
    )
    if qValueTargetsNCMove is not None:
        batch["qValueTargetsNCMove"] = torch.from_numpy(qValueTargetsNCMove).to(device)
    return batch
def benchmark_forward(model, batch, num_iters, warmup_iters):
    """Time the forward pass alone.

    Runs `warmup_iters` untimed iterations followed by `num_iters` timed ones,
    synchronizing the GPU before and after each forward so the wall-clock
    interval reflects actual device work. Returns the list of per-iteration
    times in seconds for the timed iterations only.
    """
    board_input = batch["binaryInputNCHW"]
    global_input = batch["globalInputNC"]
    samples = []
    for iter_idx in range(warmup_iters + num_iters):
        torch.cuda.synchronize()
        started = time.perf_counter()

        # Inference only: no autograd graph is needed for this benchmark.
        with torch.no_grad():
            model(board_input, global_input)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - started

        # Warmup iterations (compile/caching effects) are discarded.
        if iter_idx >= warmup_iters:
            samples.append(elapsed)
    return samples
def benchmark_full_step(model, raw_model, optimizer, metrics_obj, batch, model_config, num_iters, warmup_iters):
    """Benchmark forward + backward + optimizer step, returning separate timings.

    Returns (fwd_times, bwd_times, opt_times): three lists of per-iteration
    wall-clock seconds, one entry per timed (non-warmup) iteration.
    torch.cuda.synchronize() is called between phases so each phase's cost is
    attributed correctly; this adds sync overhead compared to free-running
    training (see benchmark_full_step_throughput for the unsynced variant).
    """
    fwd_times = []
    bwd_times = []
    opt_times = []

    for i in range(warmup_iters + num_iters):
        optimizer.zero_grad(set_to_none=True)

        # Forward
        torch.cuda.synchronize()
        t_fwd_start = time.perf_counter()

        model_outputs = model(
            batch["binaryInputNCHW"],
            batch["globalInputNC"],
        )
        # Loss is computed on the uncompiled module's postprocessed outputs.
        # NOTE(review): the loss-weight constants below presumably mirror the
        # real training loop's settings -- confirm against the train script.
        postprocessed = raw_model.postprocess_output(model_outputs)
        metrics = metrics_obj.metrics_dict_batchwise(
            raw_model,
            postprocessed,
            extra_outputs=None,
            batch=batch,
            is_training=True,
            soft_policy_weight_scale=1.0,
            disable_optimistic_policy=False,
            meta_kata_only_soft_policy=False,
            value_loss_scale=1.0,
            td_value_loss_scales=[0.4, 1.0, 1.0],
            seki_loss_scale=0.35,
            variance_time_loss_scale=0.5,
            main_loss_scale=1.0,
            intermediate_loss_scale=0.5 if raw_model.get_has_intermediate_head() else None,
        )
        loss = metrics["loss_sum"]

        torch.cuda.synchronize()
        t_bwd_start = time.perf_counter()

        # Backward
        loss.backward()

        torch.cuda.synchronize()
        t_opt_start = time.perf_counter()

        # Optimizer step
        optimizer.step()

        torch.cuda.synchronize()
        t_opt_end = time.perf_counter()

        # Only record after warmup so compile/caching effects are excluded.
        if i >= warmup_iters:
            fwd_times.append(t_bwd_start - t_fwd_start)
            bwd_times.append(t_opt_start - t_bwd_start)
            opt_times.append(t_opt_end - t_opt_start)

    return fwd_times, bwd_times, opt_times
def benchmark_full_step_throughput(model, raw_model, optimizer, metrics_obj, batch, model_config, num_iters, warmup_iters):
    """Benchmark full training step without intermediate syncs, for true throughput measurement.

    Same work per iteration as benchmark_full_step (zero_grad, forward, loss,
    backward, optimizer step) but synchronizes only at the iteration
    boundaries, so GPU work can overlap freely within a step. Returns the
    list of per-iteration seconds for the timed (non-warmup) iterations.
    """
    times = []

    for i in range(warmup_iters + num_iters):
        torch.cuda.synchronize()
        t0 = time.perf_counter()

        optimizer.zero_grad(set_to_none=True)
        model_outputs = model(
            batch["binaryInputNCHW"],
            batch["globalInputNC"],
        )
        postprocessed = raw_model.postprocess_output(model_outputs)
        # NOTE(review): loss-weight constants kept identical to
        # benchmark_full_step so the two benchmarks measure the same workload.
        metrics = metrics_obj.metrics_dict_batchwise(
            raw_model,
            postprocessed,
            extra_outputs=None,
            batch=batch,
            is_training=True,
            soft_policy_weight_scale=1.0,
            disable_optimistic_policy=False,
            meta_kata_only_soft_policy=False,
            value_loss_scale=1.0,
            td_value_loss_scales=[0.4, 1.0, 1.0],
            seki_loss_scale=0.35,
            variance_time_loss_scale=0.5,
            main_loss_scale=1.0,
            intermediate_loss_scale=0.5 if raw_model.get_has_intermediate_head() else None,
        )
        loss = metrics["loss_sum"]
        loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        t1 = time.perf_counter()

        # Discard warmup iterations (compile/caching effects).
        if i >= warmup_iters:
            times.append(t1 - t0)

    return times
def print_timing_stats(label, times):
    """Print mean, population std, min, and max of a list of timings.

    `times` is a non-empty list of durations in seconds; output is in ms.
    """
    count = len(times)
    mean = sum(times) / count
    variance = sum((sample - mean) ** 2 for sample in times) / count
    std = math.sqrt(variance)
    lo, hi = min(times), max(times)
    print(f"  {label}: {mean*1000:8.2f} ms (std {std*1000:6.2f} ms, min {lo*1000:8.2f} ms, max {hi*1000:8.2f} ms)")
# Script entry point: only run the benchmark when executed directly.
if __name__ == "__main__":
    main()

python/clean_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
checkpoint_path = args["checkpoint"]
2121
output_path = args["output"]
2222

23-
data = torch.load(checkpoint_path,map_location="cpu")
23+
data = katago.train.load_model.load_checkpoint(checkpoint_path)
2424

2525
if "optimizer" in data:
2626
del data["optimizer"]

python/edit_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
output_json_to = args["output_json_to"]
2323
overwrite_checkpoint_from_json = args["overwrite_checkpoint_from_json"]
2424

25-
data = torch.load(checkpoint_path,map_location="cpu")
25+
data = katago.train.load_model.load_checkpoint(checkpoint_path)
2626

2727
if output_json_to is not None:
2828
assert output_json_to.endswith(".json")

0 commit comments

Comments
 (0)