Skip to content

Commit d47386c

Browse files
committed
allow for more than 3 way split
1 parent 584e6e0 commit d47386c

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

label_maker/package.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -97,36 +97,31 @@ def package_directory(dest_folder, classes, imagery, ml_type, seed=False, split_
9797
elif ml_type == 'segmentation':
9898
y_vals.append(labels[tile][..., np.newaxis]) # Add grayscale channel
9999

100-
# convert lists to numpy arrays
100+
# Convert lists to numpy arrays
101101
x_vals = np.array(x_vals, dtype=np.uint8)
102+
print(x_vals.shape)
102103
y_vals = np.array(y_vals, dtype=np.uint8)
104+
print(y_vals).shape
103105

104-
x_vals_split_lst = np.split(x_vals,
105-
[int(split_vals[0] * len(x_vals)), int((split_vals[0] + split_vals[1]) * len(x_vals))])
106+
# Get number of data samples per split from the float proportions
107+
split_n_samps = np.rint([len(x_vals) * val for val in split_vals])
108+
print(split_n_samps)
106109

107-
if len(x_vals_split_lst[-1]) == 0:
108-
x_vals_split_lst = x_vals_split_lst[:-1]
110+
if np.any(split_n_samps == 0):
111+
raise ValueError
109112

110-
y_vals_split_lst = np.split(y_vals,
111-
[int(split_vals[0] * len(y_vals)), int((split_vals[0] + split_vals[1]) * len(y_vals))])
113+
# Convert into a cumulative sum to get indices
114+
split_inds = np.cumsum(split_n_samps).astype(np.integer)
112115

113-
if len(y_vals_split_lst[-1]) == 0:
114-
y_vals_split_lst = y_vals_split_lst[:-1]
116+
# Exclude last index as `np.split` handles splitting without that value
117+
split_arrs_x = np.split(x_vals, split_inds[:-1])
118+
split_arrs_y = np.split(y_vals, split_inds[:-1])
115119

116-
print('Saving packaged file to {}'.format(op.join(dest_folder, 'data.npz')))
120+
save_dict = {}
117121

118-
if len(split_vals) == 2:
119-
np.savez(op.join(dest_folder, 'data.npz'),
120-
x_train=x_vals_split_lst[0],
121-
y_train=y_vals_split_lst[0],
122-
x_test=x_vals_split_lst[1],
123-
y_test=y_vals_split_lst[1])
122+
for si, split_name in enumerate(split_names):
123+
save_dict[f'x_{split_name}'] = split_arrs_x[si]
124+
save_dict[f'y_{split_name}'] = split_arrs_y[si]
124125

125-
if len(split_vals) == 3:
126-
np.savez(op.join(dest_folder, 'data.npz'),
127-
x_train=x_vals_split_lst[0],
128-
y_train=y_vals_split_lst[0],
129-
x_test=x_vals_split_lst[1],
130-
y_test=y_vals_split_lst[1],
131-
x_val=x_vals_split_lst[2],
132-
y_val=y_vals_split_lst[2])
126+
np.savez(op.join(dest_folder, 'data.npz'), **save_dict)
127+
print('Saving packaged file to {}'.format(op.join(dest_folder, 'data.npz')))

0 commit comments

Comments
 (0)