-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
69 lines (62 loc) · 2.51 KB
/
main.py
File metadata and controls
69 lines (62 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from utils_fn import initialize
initialize(enable_ray=True)
import json
import os
import shutil
from recorder import Recorder
from ero import run_ero
import argparse
def main():
# parse
parser = argparse.ArgumentParser()
parser.add_argument("--size", "-s", type=int, default=1000, help="Size of the population")
parser.add_argument("--epochs", "-e", type=int, default=12, help="Number of evolutions")
parser.add_argument("--scale", type=float, default=0.15, help="To scale std")
args = parser.parse_args()
# configs
num_models = args.size # Size of the population
num_epochs = args.epochs # Number of evolutions
std_scale = args.scale # To scale std
ori_model_path = "models/base_model" # Path to local model
mean_model_path = "models/mean_models" # Folder used to storage the mean model when running es
best_model_path = "models/best_models" # Folder used to save the best model when running es
data_path = "data" # Folder used to storage tasks
output_path = "results" # Folder used to save results
std_path = "results/std.json" # Path to std json
clamp_range_path = "results/clamp_range.json" # Path to clamp range json
es_list = [
["351d6448", ["351d6448"]], # ["NAME FOR THIS RUN", ["LIST OF ARC TASKS"]]
["414297c0", ["414297c0"]],
["e6de6e8f", ["e6de6e8f"]],
["e7a25a18", ["e7a25a18"]],
["505fff84", ["505fff84"]],
["b1fc8b8e", ["b1fc8b8e"]],
["1a6449f1", ["1a6449f1"]],
["3194b014", ["3194b014"]],
["9b4c17c4", ["9b4c17c4"]],
["0a1d4ef5", ["0a1d4ef5"]],
["9a4bb226", ["9a4bb226"]],
["12422b43", ["12422b43"]],
["1c02dbbe", ["1c02dbbe"]],
["477d2879", ["477d2879"]],
["67b4a34d", ["67b4a34d"]],
]
with open(std_path, 'r') as f:
std = json.load(f)
for name in std.keys():
std[name] = std[name] * std_scale
with open(clamp_range_path, 'r') as f:
clamp_range = json.load(f)
recorder = Recorder(output_path, data_path)
# run ero
for es_name, es_tasks in es_list:
# reset mean model
for file in os.listdir(ori_model_path):
full_file = os.path.join(ori_model_path, file)
if os.path.isfile(full_file):
shutil.copy(full_file, mean_model_path)
# run
run_ero(es_name, ori_model_path, mean_model_path, num_models, std, clamp_range, es_tasks, num_epochs, recorder, output_path, best_model_path)
print()
if __name__ == "__main__":
main()