Skip to content

Commit ce2d7bf

Browse files
committed
updates
1 parent d47386c commit ce2d7bf

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

label_maker/package.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,14 @@ def package_directory(dest_folder, classes, imagery, ml_type, seed=False, split_
9999

100100
# Convert lists to numpy arrays
101101
x_vals = np.array(x_vals, dtype=np.uint8)
102-
print(x_vals.shape)
103102
y_vals = np.array(y_vals, dtype=np.uint8)
104-
print(y_vals).shape
105103

106104
# Get number of data samples per split from the float proportions
107105
split_n_samps = np.rint([len(x_vals) * val for val in split_vals])
108-
print(split_n_samps)
106+
#print(split_n_samps)
109107

110108
if np.any(split_n_samps == 0):
111-
raise ValueError
109+
raise ValueError('split must not generate zero samples per partition, change ratio of values in config file.')
112110

113111
# Convert into a cumulative sum to get indices
114112
split_inds = np.cumsum(split_n_samps).astype(np.integer)

test/fixtures/integration/config_3way.integration.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@
1919
"ml_type": "classification",
2020
"seed": 19,
2121
"split_names": ["train", "test", "val"],
22-
"split_vals": [0.7, 0.2, 0.1]
22+
"split_vals": [0.6, 0.2, 0.2]
2323
}

test/integration/test_classification_package.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
import numpy as np
99

10+
1011
class TestClassificationPackage(unittest.TestCase):
1112
"""Tests for classification package creation"""
13+
1214
@classmethod
1315
def setUpClass(cls):
14-
1516
makedirs('integration-cl')
1617
copyfile('test/fixtures/integration/labels-cl.npz', 'integration-cl/labels.npz')
1718
copytree('test/fixtures/integration/tiles', 'integration-cl/tiles')
@@ -58,7 +59,8 @@ def test_cli(self):
5859
def test_cli_3way_split(self):
5960
"""Verify data.npz produced by CLI when split into train/test/val"""
6061

61-
cmd = 'label-maker package --dest integration-cl-split --config test/fixtures/integration/config_3way.integration.json'
62+
cmd = 'label-maker package --dest integration-cl-split --config ' \
63+
'test/fixtures/integration/config_3way.integration.json '
6264
cmd = cmd.split(' ')
6365
subprocess.run(cmd, universal_newlines=True)
6466

@@ -72,4 +74,4 @@ def test_cli_3way_split(self):
7274
# validate label data with shapes
7375
self.assertEqual(data['y_train'].shape, (5, 7))
7476
self.assertEqual(data['y_test'].shape, (2, 7))
75-
self.assertEqual(data['y_val'].shape, (1, 7))
77+
self.assertEqual(data['y_val'].shape, (1, 7))

0 commit comments

Comments
 (0)