@@ -97,36 +97,31 @@ def package_directory(dest_folder, classes, imagery, ml_type, seed=False, split_
9797 elif ml_type == 'segmentation' :
9898 y_vals .append (labels [tile ][..., np .newaxis ]) # Add grayscale channel
9999
100- # convert lists to numpy arrays
100+ # Convert lists to numpy arrays
101101 x_vals = np .array (x_vals , dtype = np .uint8 )
102+ print (x_vals .shape )
102103 y_vals = np .array (y_vals , dtype = np .uint8 )
104+ print (y_vals ).shape
103105
104- x_vals_split_lst = np .split (x_vals ,
105- [int (split_vals [0 ] * len (x_vals )), int ((split_vals [0 ] + split_vals [1 ]) * len (x_vals ))])
106+ # Get number of data samples per split from the float proportions
107+ split_n_samps = np .rint ([len (x_vals ) * val for val in split_vals ])
108+ print (split_n_samps )
106109
107- if len ( x_vals_split_lst [ - 1 ]) == 0 :
108- x_vals_split_lst = x_vals_split_lst [: - 1 ]
110+ if np . any ( split_n_samps == 0 ) :
111+ raise ValueError
109112
110- y_vals_split_lst = np . split ( y_vals ,
111- [ int ( split_vals [ 0 ] * len ( y_vals )), int (( split_vals [ 0 ] + split_vals [ 1 ]) * len ( y_vals ))] )
113+ # Convert into a cumulative sum to get indices
114+ split_inds = np . cumsum ( split_n_samps ). astype ( np . integer )
112115
113- if len (y_vals_split_lst [- 1 ]) == 0 :
114- y_vals_split_lst = y_vals_split_lst [:- 1 ]
116+ # Exclude the last index: `np.split` takes only the interior cut points, and the final piece runs to the end of the array
117+ split_arrs_x = np .split (x_vals , split_inds [:- 1 ])
118+ split_arrs_y = np .split (y_vals , split_inds [:- 1 ])
115119
116- print ( 'Saving packaged file to {}' . format ( op . join ( dest_folder , 'data.npz' )))
120+ save_dict = {}
117121
118- if len (split_vals ) == 2 :
119- np .savez (op .join (dest_folder , 'data.npz' ),
120- x_train = x_vals_split_lst [0 ],
121- y_train = y_vals_split_lst [0 ],
122- x_test = x_vals_split_lst [1 ],
123- y_test = y_vals_split_lst [1 ])
122+ for si , split_name in enumerate (split_names ):
123+ save_dict [f'x_{ split_name } ' ] = split_arrs_x [si ]
124+ save_dict [f'y_{ split_name } ' ] = split_arrs_y [si ]
124125
125- if len (split_vals ) == 3 :
126- np .savez (op .join (dest_folder , 'data.npz' ),
127- x_train = x_vals_split_lst [0 ],
128- y_train = y_vals_split_lst [0 ],
129- x_test = x_vals_split_lst [1 ],
130- y_test = y_vals_split_lst [1 ],
131- x_val = x_vals_split_lst [2 ],
132- y_val = y_vals_split_lst [2 ])
126+ np .savez (op .join (dest_folder , 'data.npz' ), ** save_dict )
127+ print ('Saving packaged file to {}' .format (op .join (dest_folder , 'data.npz' )))
0 commit comments