99
1010
1111class BasicDataset (Dataset ):
12- def __init__ (self , imgs_dir , masks_dir , scale = 1 ):
12+ def __init__ (self , imgs_dir , masks_dir , scale = 1 , mask_suffix = '' ):
1313 self .imgs_dir = imgs_dir
1414 self .masks_dir = masks_dir
1515 self .scale = scale
16+ self .mask_suffix = mask_suffix
1617 assert 0 < scale <= 1 , 'Scale must be between 0 and 1'
1718
1819 self .ids = [splitext (file )[0 ] for file in listdir (imgs_dir )
@@ -43,7 +44,7 @@ def preprocess(cls, pil_img, scale):
4344
4445 def __getitem__ (self , i ):
4546 idx = self .ids [i ]
46- mask_file = glob (self .masks_dir + idx + '.*' )
47+ mask_file = glob (self .masks_dir + idx + self . mask_suffix + '.*' )
4748 img_file = glob (self .imgs_dir + idx + '.*' )
4849
4950 assert len (mask_file ) == 1 , \
@@ -63,3 +64,8 @@ def __getitem__(self, i):
6364 'image' : torch .from_numpy (img ).type (torch .FloatTensor ),
6465 'mask' : torch .from_numpy (mask ).type (torch .FloatTensor )
6566 }
67+
68+
69+ class CarvanaDataset (BasicDataset ):
70+ def __init__ (self , imgs_dir , masks_dir , scale = 1 ):
71+ super ().__init__ (imgs_dir , masks_dir , scale , mask_suffix = '_mask' )
0 commit comments