diff --git a/pyroengine/engine.py b/pyroengine/engine.py
index a0b009a4..015aa55d 100644
--- a/pyroengine/engine.py
+++ b/pyroengine/engine.py
@@ -22,7 +22,7 @@
 from requests.exceptions import ConnectionError
 from requests.models import Response
 
-from pyroengine.utils import box_iou, nms
+from pyroengine.utils import box_iou, multi_resolution_frame, nms
 
 from .vision import Classifier
 
@@ -90,7 +90,7 @@ def __init__(
         cache_size: int = 100,
         cache_folder: str = "data/",
         backup_size: int = 30,
-        jpeg_quality: int = 80,
+        avif_quality: int = 50,
         day_time_strategy: Optional[str] = None,
         save_captured_frames: Optional[bool] = False,
         send_last_image_period: int = 3600,  # 1H
@@ -118,7 +118,7 @@ def __init__(
         self.frame_saving_period = frame_saving_period
         self.nb_consecutive_frames = nb_consecutive_frames
         self.frame_size = frame_size
-        self.jpeg_quality = jpeg_quality
+        self.avif_quality = avif_quality
         self.cache_backup_period = cache_backup_period
         self.day_time_strategy = day_time_strategy
         self.save_captured_frames = save_captured_frames
@@ -168,7 +168,20 @@ def heartbeat(self, cam_id: str) -> Response:
         ip = cam_id.split("_")[0]
         return self.api_client[ip].heartbeat()
 
-    def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> float:
+    def _update_states(
+        self, input_frame: Image.Image, frame: Image.Image, preds: np.ndarray, cam_key: str
+    ) -> tuple[float, Image.Image]:
+        """Update the state of camera `cam_key` and return the confidence with the frame to forward.
+
+        Args:
+            input_frame: frame at its original resolution, as received by `predict`
+            frame: frame after the optional resize done in `predict`
+            preds: predictions for this frame
+            cam_key: key of the camera in `self._states`
+
+        Returns:
+            the confidence and the frame, replaced by a multi-resolution frame when boxes are output
+        """
         prev_ongoing = self._states[cam_key]["ongoing"]
 
         conf_th = self.conf_thresh * self.nb_consecutive_frames
@@ -233,6 +246,7 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) ->
 
         if output_predictions.size > 0:
             output_predictions = np.atleast_2d(output_predictions)
+            frame = multi_resolution_frame(input_frame, frame, output_predictions.tolist())
         self._states[cam_key]["last_predictions"].append((
             frame,
             preds,
@@ -252,7 +266,7 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) ->
             self._states[cam_key]["miss_count"] = 0
 
         self._states[cam_key]["ongoing"] = new_ongoing
-        return conf
+        return conf, frame
 
     def predict(
         self, frame: Image.Image, cam_id: Optional[str] = None, fake_pred: Optional[np.ndarray] = None
@@ -272,27 +286,11 @@ def predict(
             the predicted confidence
         """
         cam_key = cam_id or "-1"
+        input_frame = frame.copy()
         # Reduce image size to save bandwidth
         if isinstance(self.frame_size, tuple):
             frame = frame.resize(self.frame_size[::-1], Image.BILINEAR)  # type: ignore[attr-defined]
 
-        # Heartbeat
-        if len(self.api_client) > 0 and isinstance(cam_id, str):
-            heartbeat_with_timeout(self, cam_id, timeout=1)
-            if (
-                self._states[cam_key]["last_image_sent"] is None
-                or time.time() - self._states[cam_key]["last_image_sent"] > self.send_last_image_period
-            ):
-                # send image periodically
-                logging.info(f"Uploading periodical image for cam {cam_id}")
-                self._states[cam_key]["last_image_sent"] = time.time()
-                ip = cam_id.split("_")[0]
-                if ip in self.api_client.keys():
-                    stream = io.BytesIO()
-                    frame.save(stream, format="JPEG", quality=self.jpeg_quality)
-                    response = self.api_client[ip].update_last_image(stream.getvalue())
-                    logging.info(response.text)
-
         # Update occlusion masks
         if (
             self._states[cam_key]["last_bbox_mask_fetch"] is None
@@ -326,11 +324,27 @@ def predict(
             preds = np.reshape(preds, (-1, 5))
 
         logging.info(f"pred for {cam_key} : {preds}")
-        conf = self._update_states(frame, preds, cam_key)
-
+        conf, frame = self._update_states(input_frame, frame, preds, cam_key)
         if self.save_captured_frames:
             self._local_backup(frame, cam_id, is_alert=False)
 
+        # Heartbeat
+        if len(self.api_client) > 0 and isinstance(cam_id, str):
+            heartbeat_with_timeout(self, cam_id, timeout=1)
+            if (
+                self._states[cam_key]["last_image_sent"] is None
+                or time.time() - self._states[cam_key]["last_image_sent"] > self.send_last_image_period
+            ):
+                # send image periodically
+                logging.info(f"Uploading periodical image for cam {cam_id}")
+                self._states[cam_key]["last_image_sent"] = time.time()
+                ip = cam_id.split("_")[0]
+                if ip in self.api_client.keys():
+                    stream = io.BytesIO()
+                    frame.save(stream, format="avif", quality=self.avif_quality)
+                    response = self.api_client[ip].update_last_image(stream.getvalue())
+                    logging.info(response.text)
+
         # Log analysis result
         device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
         pred_str = "Wildfire detected" if conf > self.conf_thresh else "No wildfire"
@@ -397,7 +411,7 @@ def _process_alerts(self) -> None:
                 try:
                     # Detection creation
                     stream = io.BytesIO()
-                    frame_info["frame"].save(stream, format="JPEG", quality=self.jpeg_quality)
+                    frame_info["frame"].save(stream, format="avif", quality=self.avif_quality)
                     bboxes = self._alerts[0]["bboxes"]
                     bboxes = [tuple(bboxe) for bboxe in bboxes]
                     _, cam_azimuth, _ = self.cam_creds[cam_id]
diff --git a/pyroengine/utils.py b/pyroengine/utils.py
index b1cf6ae3..d2564bac 100644
--- a/pyroengine/utils.py
+++ b/pyroengine/utils.py
@@ -7,6 +7,7 @@
 import cv2
 import numpy as np
 from tqdm import tqdm
+from PIL import Image
 
 __all__ = ["DownloadProgressBar", "letterbox", "nms", "xywh2xyxy"]
 
@@ -108,6 +109,38 @@ def nms(boxes: np.ndarray, overlapThresh: int = 0):
 
     return boxes[indices]
 
+
+def multi_resolution_frame(
+    high_resolution_frame: Image.Image, low_resolution_frame: Image.Image, bboxes: list
+) -> Image.Image:
+    """Build a light frame that keeps the full resolution only inside the bounding boxes.
+
+    The low resolution frame is resized to the size of the high resolution frame, then each
+    bounding box area is overwritten with the matching crop of the high resolution frame.
+
+    Args:
+        high_resolution_frame (Image.Image): The initial image that has high resolution.
+        low_resolution_frame (Image.Image): The already resized image, lower resolution but lighter.
+        bboxes (list): Boxes given as relative coordinates in order xmin, ymin, xmax, ymax, conf.
+
+    Returns:
+        (Image.Image): The multi resolution image, with the size of high_resolution_frame.
+    """
+    width, height = high_resolution_frame.size
+    result_frame = low_resolution_frame.resize((width, height), Image.BILINEAR)
+
+    for bbox in bboxes:
+        # Clamp to the image so an out-of-range prediction never reaches crop() / paste()
+        xmin = min(max(round(bbox[0] * width), 0), width)
+        ymin = min(max(round(bbox[1] * height), 0), height)
+        xmax = min(max(round(bbox[2] * width), 0), width)
+        ymax = min(max(round(bbox[3] * height), 0), height)
+        if xmax <= xmin or ymax <= ymin:
+            # Empty or inverted box, nothing to paste
+            continue
+        result_frame.paste(high_resolution_frame.crop((xmin, ymin, xmax, ymax)), (xmin, ymin))
+    return result_frame
+
 
 class DownloadProgressBar(tqdm):
     def update_to(self, b=1, bsize=1, tsize=None):
diff --git a/requirements.txt b/requirements.txt
index f665c286..6fa7db8c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,7 @@ numpy==2.2.6 ; python_version >= "3.11" and python_version < "4.0"
 onnxruntime==1.22.1 ; python_version >= "3.11" and python_version < "4.0"
 opencv-python==4.12.0.88 ; python_version >= "3.11" and python_version < "4.0"
 packaging==25.0 ; python_version >= "3.11" and python_version < "4.0"
-pillow==11.0.0 ; python_version >= "3.11" and python_version < "4.0"
+pillow==12.1.0 ; python_version >= "3.11" and python_version < "4.0"
 portalocker==3.2.0 ; python_version >= "3.11" and python_version < "4.0"
 protobuf==6.33.1 ; python_version >= "3.11" and python_version < "4.0"
 pyreadline3==3.5.4 ; python_version >= "3.11" and python_version < "4.0" and sys_platform == "win32"
diff --git a/src/run.py b/src/run.py
index f7dacf16..1eb65458 100644
--- a/src/run.py
+++ b/src/run.py
@@ -62,7 +62,7 @@ def main(args):
         frame_size=args.frame_size,
         cache_backup_period=args.cache_backup_period,
         cache_size=args.cache_size,
-        jpeg_quality=args.jpeg_quality,
+        avif_quality=args.avif_quality,
         day_time_strategy=args.day_time_strategy,
         save_captured_frames=args.save_captured_frames,
     )
@@ -93,7 +93,7 @@ def main(args):
         default=(720, 1280),
         help="Resize frame to frame_size before sending it to the api in order to save bandwidth (H, W)",
     )
-    parser.add_argument("--jpeg_quality", type=int, default=80, help="Jpeg compression")
+    parser.add_argument("--avif_quality", type=int, default=50, help="avif compression")
     parser.add_argument(
         "--cache-size",
         type=int,