|
1 | 1 | import asyncio |
2 | 2 | import logging |
| 3 | +import time |
| 4 | +from concurrent.futures import ThreadPoolExecutor |
| 5 | +from pathlib import Path |
| 6 | +from typing import Optional, cast |
3 | 7 |
|
4 | 8 | import av |
| 9 | +import av.filter |
| 10 | +import av.frame |
5 | 11 | from aiortc import VideoStreamTrack |
| 12 | +from av import VideoFrame |
6 | 13 | from PIL import Image |
7 | 14 | from vision_agents.core.utils.video_queue import VideoLatestNQueue |
8 | 15 |
|
@@ -88,3 +95,118 @@ def stop(self): |
    @property
    def stopped(self) -> bool:
        """True once the track has been stopped (the ``_stopped`` flag is set)."""
        return self._stopped
| 98 | + |
| 99 | + |
class VideoFileTrack(VideoStreamTrack):
    """
    A video track reading from a local MP4 file,
    filtered to a constant FPS using FFmpeg (30 FPS by default).

    Loops automatically when the file ends.
    Use it for testing and debugging.
    """

    def __init__(self, path: str | Path, fps: int = 30):
        """
        Args:
            path: path to the local video file to read.
            fps: target constant frame rate produced by the FFmpeg ``fps`` filter.

        Raises:
            ValueError: if the stream's time_base cannot be determined.
        """
        super().__init__()
        self.fps = fps
        self.path = Path(path)

        self._stopped = False
        self._container = av.open(self.path)
        self._stream = self._container.streams.video[0]
        if self._stream.time_base is None:
            raise ValueError("Cannot determine time_base for the video stream")

        self._time_base = self._stream.time_base

        # Decoder iterator to read the frames
        self._decoder = self._container.decode(self._stream)
        # Single worker: decoding is stateful and must never run concurrently.
        self._executor = ThreadPoolExecutor(1)
        self._set_filter_graph()

    def _set_filter_graph(self) -> None:
        """(Re)build the FFmpeg filter graph: buffer -> fps -> buffersink."""
        # Safe extraction of sample_aspect_ratio
        sar = self._stream.sample_aspect_ratio
        if sar is None:
            sar_num, sar_den = 1, 1
        else:
            sar_num, sar_den = sar.numerator, sar.denominator

        # Build ffmpeg filter graph to resample video to fixed fps.
        # Keep the reference to the graph to avoid GC.
        self._graph = av.filter.Graph()

        # Buffer source with all required parameters
        self._src = self._graph.add(
            "buffer",
            f"video_size={self._stream.width}x{self._stream.height}:"
            f"pix_fmt={self._stream.pix_fmt}:"
            f"time_base={self._time_base.numerator}/{self._time_base.denominator}:"
            f"pixel_aspect={sar_num}/{sar_den}",
        )

        # Add an FPS filter
        fps_filter = self._graph.add("fps", f"fps={self.fps}")

        # Add a buffer sink
        self._sink = self._graph.add("buffersink")

        # Connect graph: buffer -> fps filter -> sink
        self._src.link_to(fps_filter)
        fps_filter.link_to(self._sink)
        self._graph.configure()

    def _next_frame(self) -> av.VideoFrame:
        """Decode, filter and return the next RGB frame.

        Blocking; always runs on the single-worker executor.
        """
        filtered_frame: Optional[av.VideoFrame] = None
        while filtered_frame is None:
            # Get the next decoded frame
            try:
                frame = next(self._decoder)
            except StopIteration:
                # Loop the video when it ends
                self._container.seek(0)
                self._decoder = self._container.decode(self._stream)
                # Reset the filter graph too: it holds state from the
                # previous pass and its internal timestamps would go backwards.
                self._set_filter_graph()
                frame = next(self._decoder)

            # Ensure frame has a time_base (required by buffer source)
            frame.time_base = self._time_base

            # Push decoded frame into the filter graph
            self._src.push(frame)

            # Pull filtered frame from buffersink
            try:
                filtered_frame = cast(av.VideoFrame, self._sink.pull())
            except (av.ExitError, av.BlockingIOError):
                # The graph needs more input before it can emit a frame —
                # loop around and push the next decoded frame (sleeping here
                # would not help).
                continue
            except Exception:
                logger.exception("Failed to read a frame from video file")
                continue

        # Convert the filtered video frame to RGB for aiortc
        return filtered_frame.to_rgb()

    async def recv(self) -> VideoFrame:
        """
        Async method to produce the next filtered video frame.
        Loops automatically at the end of the file.

        Raises:
            VideoTrackClosedError: if the track has been stopped.
        """
        if self._stopped:
            raise VideoTrackClosedError("Track stopped")
        loop = asyncio.get_running_loop()
        frame = await loop.run_in_executor(self._executor, self._next_frame)
        # Pace output at the target FPS so other coroutines can run.
        # (Sleeping float(frame.time_base) only worked because the fps
        # filter's sink time_base happens to be 1/fps — be explicit instead.)
        await asyncio.sleep(1 / self.fps)
        return frame

    def stop(self) -> None:
        """Stop the track and release decoder resources. Idempotent."""
        if self._stopped:
            return
        self._stopped = True
        # Close the container on the worker so it runs after any in-flight
        # decode finishes, then let the executor wind down without blocking.
        self._executor.submit(self._container.close)
        self._executor.shutdown(wait=False)

    def __repr__(self):
        return f'<{self.__class__.__name__} path="{self.path}" fps={self.fps}>'
0 commit comments