Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/Detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ TESTPATH
predict.py [-h] -m MODELPATH -c TESTPATH

```
where MODELPATH is the path to the saved model weights after training and TESTPATH is the directory containing your pre-processed orca images. These spectrogram images are first **renamed from 0 to N-1** to help find the start and end time of orca calls. The images are then fed to the model and the predicted orca samples are saved in a new folder `pos_orca` within the same directory.
where MODELPATH is the file path to the saved model weights or a directory path to multiple saved model weights after training and TESTPATH is the directory containing your pre-processed orca images. These spectrogram images are first **renamed from 0 to N-1** to help find the start and end time of orca calls. The images are then fed to the model and the predicted orca samples are saved in a new folder `pos_orca` within the same directory. A csv file named `predictions.csv` labelling which examples were predicted with high confidence and low confidence by the model(s) will also be saved to the `pos_orca` folder.

- Determine start and end time of orca in the audio sample:

Expand Down
56 changes: 41 additions & 15 deletions app/Detection/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import logging
import argparse
import csv
from itertools import zip_longest
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
Expand All @@ -14,19 +16,11 @@
logger = logging.getLogger(__name__)


def predict(model_path, test_path):
folder_path = test_path
model_path = model_path

img_width, img_height = 200, 300

model = load_model(model_path)

# images = []

def generate_predictions(model_path, folder_path):
'''
Stacking images eats up RAM much faster: https://hjweide.github.io/efficient-image-loading
'''
img_width, img_height = 200, 300
N = sum(len(files) for _, _, files in os.walk(folder_path))
data = np.empty((N, img_width, img_height, 3), dtype=np.uint8)

Expand All @@ -39,16 +33,38 @@ def predict(model_path, test_path):
# images.append(img)
data[i, ...] = img

if os.path.isfile(model_path):
model = load_model(model_path)
classes = (model.predict(data) > 0.5).astype("int32")
elif os.path.isdir(model_path):
scores = np.zeros((N, 1))
for model_name in os.listdir(model_path):
model = load_model(os.path.join(model_path, model_name))
scores += (model.predict(data) > 0.5).astype("int32")
classes = scores / len(os.listdir(model_path))
classes[classes < 0.5] = 0
classes[classes >= 0.5] = 1
classes = classes.astype("int32")
else:
raise ValueError(f"Path provided, {model_path}, is not a valid path!")

return classes


def predict(model_path, test_path):
folder_path = test_path
model_path = model_path

logger.info("Starting Prediction")

classes = generate_predictions(model_path, folder_path)

# stack up images list to pass for prediction
# images = []
# images = np.vstack(images)

# classes = model.predict_classes(data)

classes = (model.predict(data) > 0.5).astype("int32")

f = []

f, pos, negs = [], [], []
for i in os.listdir(folder_path):
f.append(i)

Expand All @@ -57,7 +73,17 @@ def predict(model_path, test_path):

os.makedirs("pos_orca", exist_ok=True)
if classes[i][0] == 1:
pos.append(f[i])
shutil.copy(f_n, 'pos_orca')
else:
negs.append(f[i])

with open(os.path.join("pos_orca", "predictions.csv"), "w+") as f:
writer = csv.writer(f)
writer.writerow(["high_confidence", "low_confidence"])
for values in zip_longest(*[pos, negs]):
writer.writerow(values)

logger.info(
f"Detected {sum(len(files) for _, _, files in os.walk('pos_orca'))} orca calls")

Expand Down