Skip to content

Commit ac814fa

Browse files
committed
feat: add parallel dptb submission support
1 parent 71af276 commit ac814fa

1 file changed

Lines changed: 168 additions & 0 deletions

File tree

src/dprep/dptb_dpdispatcher.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import json
2+
import os
3+
import shutil
4+
5+
from dpdispatcher import Task, Submission, Machine, Resources
6+
import copy
7+
import itertools
8+
9+
10+
def merge_parameters(private_para, public_para):
    """Yield every combination of the list-valued entries of *private_para*
    merged into a deep copy of *public_para*.

    Each (possibly nested) list value in ``private_para`` is treated as a set
    of alternatives; the cartesian product of all such lists is enumerated.
    Non-list values of ``private_para`` are merged on top of ``public_para``
    in every combination.

    :param private_para: dict (possibly nested) whose list values enumerate
        the alternatives to sweep over; never mutated.
    :param public_para: dict of parameters shared by all combinations; never
        mutated (each combination starts from a deep copy of it).
    :return: generator of ``(merged_dict, name)`` tuples, where ``name`` is
        the combination's values joined with ``'_'`` (empty string when
        ``private_para`` contains no lists).
    """

    def find_list_paths(d, current_path=None):
        # Recursively collect (key_path, list_value) pairs for every list in d.
        # current_path defaults to None (not []) to avoid the shared
        # mutable-default-argument pitfall.
        if current_path is None:
            current_path = []
        paths = []
        for k, v in d.items():
            new_path = current_path + [k]
            if isinstance(v, dict):
                paths.extend(find_list_paths(v, new_path))
            elif isinstance(v, list):
                paths.append((new_path, v))
        return paths

    def merge_remaining(private_sub, merged_sub):
        # Copy the non-list entries of private_para into the merged dict;
        # list values are skipped because the current combination's value has
        # already been written along their paths.
        for k, v in private_sub.items():
            if isinstance(v, dict):
                if k not in merged_sub or not isinstance(merged_sub[k], dict):
                    merged_sub[k] = {}
                merge_remaining(v, merged_sub[k])
            elif not isinstance(v, list):
                merged_sub[k] = v

    # All paths in private_para whose leaf is a list of alternatives.
    list_paths = find_list_paths(private_para)
    keys = [path for path, _values in list_paths]
    values_lists = [values for _path, values in list_paths]

    # Iterate the cartesian product lazily instead of materializing a list.
    for combo in itertools.product(*values_lists):
        # Start every combination from an untouched copy of the shared params.
        merged = copy.deepcopy(public_para)
        # Write each chosen value at its original nested path.
        for path, value in zip(keys, combo):
            current_level = merged
            for key in path[:-1]:
                if key not in current_level or not isinstance(current_level[key], dict):
                    current_level[key] = {}
                current_level = current_level[key]
            current_level[path[-1]] = value
        # Bring over the scalar/dict (non-list) private parameters.
        merge_remaining(private_para, merged)
        a_private_name = '_'.join(str(x) for x in combo)
        yield merged, a_private_name
65+
66+
67+
# Keep patience * nsamples / batch_size > 2000 (here nsamples = 2500):
# using patience == batch_size gives patience * 2500 / batch_size = 2500.
def maintain_patience(merged_para_dict):
    """Set the LR scheduler's ``patience`` equal to the training batch size.

    :param merged_para_dict: merged dptb config with ``train_options`` holding
        ``batch_size`` and an ``lr_scheduler`` dict.
    :return: the same dict, mutated in place.
    """
    train_opts = merged_para_dict["train_options"]
    train_opts["lr_scheduler"]["patience"] = train_opts["batch_size"]
    return merged_para_dict
73+
74+
75+
class DPTBDpdispatcher:
    """Prepare, submit, and post-process a batch of dptb training jobs
    through dpdispatcher.

    Workflow (``run_with_dpdispatcher``):
      1. ``prepare_workbase`` — one job directory with an ``input.json`` per
         parameter combination from ``merge_parameters``, under ``raw/``.
      2. ``run_a_batch`` — submit all jobs as a single dpdispatcher
         ``Submission``.
      3. ``post_process`` — collect best checkpoints and the largest
         tensorboard event file of each job under ``cooked/``.

    NOTE(review): the methods navigate with ``os.chdir`` and restore
    ``self.workbase`` at the end of each phase; they are not safe to run
    concurrently from one process.
    """

    def __init__(self,
                 private_para_dict: dict,
                 public_para_dict: dict,
                 machine_info: dict,
                 resrc_info: dict,
                 cmd_line: str=fr'dptb train input.json -o ./output 2>&1 ',
                 old_ckpt_path: str=None):
        """Store job-generation inputs and remember the launch directory.

        :param private_para_dict: swept parameters (list values enumerate
            alternatives), passed to ``merge_parameters``.
        :param public_para_dict: parameters shared by every job.
        :param machine_info: dict for ``Machine.load_from_dict``.
        :param resrc_info: dict for ``Resources.load_from_dict``.
        :param cmd_line: shell command each task runs in its job directory.
        :param old_ckpt_path: optional checkpoint copied into every job
            directory (e.g. to restart training from it).
        """
        self.private_para_dict = private_para_dict
        self.public_para_dict = public_para_dict
        self.machine_info = machine_info
        self.resrc_info = resrc_info
        # Directory the batch was launched from; restored after each phase.
        self.workbase = os.getcwd()
        self.cmd_line = cmd_line
        if old_ckpt_path:
            self.old_ckpt_name = os.path.basename(old_ckpt_path)
            self.old_ckpt_path = os.path.abspath(old_ckpt_path)
        else:
            # NOTE: old_ckpt_name is deliberately left unset here; it is only
            # read when self.old_ckpt_path is truthy.
            self.old_ckpt_path = None

    def find_largest_event_file(self, directory='.'):
        """Return the name of the largest ``events.out.tfevents*`` file in
        *directory*, or None if there is none.

        The largest event file is taken as the main tensorboard log of a run.
        """
        largest_file = None
        largest_size = -1
        for file in os.listdir(directory):
            if file.startswith('events.out.tfevents'):
                file_path = os.path.join(directory, file)
                file_size = os.path.getsize(file_path)
                if file_size > largest_size:
                    largest_file = file
                    largest_size = file_size
        return largest_file

    def prepare_workbase(self):
        """Create one job directory per parameter combination under ``raw/``
        and build the matching dpdispatcher ``Task`` list.

        Side effects: creates ``raw/<job_name>/input.json`` (plus an optional
        checkpoint copy) for every combination, and populates
        ``self.task_list`` and ``self.path_raw_job_paths``.
        """
        self.task_list = []
        self.path_raw = os.path.abspath('raw')
        self.path_raw_job_paths = []
        os.makedirs(exist_ok=True, name=self.path_raw)

        for a_merged_para, job_name in merge_parameters(private_para=self.private_para_dict, public_para=self.public_para_dict):
            os.chdir(self.path_raw)
            os.makedirs(job_name, exist_ok=True)
            os.chdir(job_name)
            self.path_raw_job_paths.append(os.getcwd())
            # Disabled for now: would force lr_scheduler patience = batch_size.
            # a_merged_para = maintain_patience(a_merged_para)
            with open(r"input.json", 'w') as f:
                json.dump(a_merged_para, f, indent=4)
            if self.old_ckpt_path:
                # Copy the restart checkpoint next to input.json so the task
                # can reference it by bare filename.
                shutil.copy(src=self.old_ckpt_path, dst=self.old_ckpt_name)
            a_task = Task(command=self.cmd_line,
                          task_work_path=f'{job_name}/',
                          forward_files=[f'{self.path_raw}/{job_name}/*'],
                          backward_files=['tensorboard_logs/*', 'output/*'])
            self.task_list.append(a_task)

        os.chdir(self.workbase)

    def run_a_batch(self):
        """Submit all prepared tasks as one dpdispatcher Submission and block
        until completion (polling every 60 s, cleaning up remote files)."""
        machine = Machine.load_from_dict(machine_dict=self.machine_info)
        resources = Resources.load_from_dict(resources_dict=self.resrc_info)
        submission = Submission(work_base=f'{self.path_raw}',
                                machine=machine,
                                resources=resources,
                                task_list=self.task_list,
                                forward_common_files=[],
                                backward_common_files=[]
                                )
        submission.run_submission(check_interval=60, clean=True)

    def post_process(self):
        """Collect results from every finished job into ``cooked/``.

        Layout produced:
          - ``cooked/ckpt/<job_name>.pth`` — the job's best checkpoint
            (``output/checkpoint/nnenv.best.pth``).
          - ``cooked/events/<job_name>/<event_file>`` — the job's largest
            tensorboard event file.

        Any pre-existing ``cooked/`` directory is removed first.
        """
        self.path_cooked = os.path.abspath('cooked')
        if os.path.exists(self.path_cooked):
            shutil.rmtree(self.path_cooked)
        os.makedirs(name=self.path_cooked)
        os.chdir(self.path_cooked)
        os.makedirs('ckpt')
        self.ckpt_path = os.path.abspath('ckpt')
        os.makedirs('events')
        self.events_path = os.path.abspath('events')
        for a_job_path in self.path_raw_job_paths:
            a_job_name = os.path.split(a_job_path)[-1]
            os.chdir(a_job_path)
            os.chdir('tensorboard_logs')
            largest_file = self.find_largest_event_file()
            new_event_folder_path = os.path.join(self.events_path, a_job_name)
            os.makedirs(new_event_folder_path)
            # NOTE(review): assumes at least one events.out.tfevents* file
            # exists per job; shutil.copy would fail on largest_file=None.
            shutil.copy(src=largest_file, dst=os.path.join(new_event_folder_path, largest_file))
            shutil.copy(src=os.path.join(a_job_path, 'output', 'checkpoint', 'nnenv.best.pth'),
                        dst=os.path.join(self.ckpt_path, f'{a_job_name}.pth'))
        os.chdir(self.workbase)

    def run_with_dpdispatcher(self):
        """Run the full pipeline: prepare job dirs, submit, collect results."""
        self.prepare_workbase()
        self.run_a_batch()
        self.post_process()

0 commit comments

Comments
 (0)