Skip to content

Commit 55b6156

Browse files
authored
Update video data loading example (#859)
1 parent f639854 commit 55b6156

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

examples/video_dataloading.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,13 @@
5151
CPU decoding with higher concurrency often yields higher throughput.
5252
"""
5353

54-
# pyre-ignore-all-errors
54+
# pyre-strict
5555

56+
import argparse
5657
import logging
5758
import signal
5859
import time
60+
from argparse import Namespace
5961
from collections.abc import Callable, Iterable
6062
from dataclasses import dataclass
6163
from pathlib import Path
@@ -67,7 +69,7 @@
6769
from spdl.pipeline import Pipeline, PipelineBuilder
6870
from torch import Tensor
6971

70-
_LG = logging.getLogger(__name__)
72+
_LG: logging.Logger = logging.getLogger(__name__)
7173

7274
__all__ = [
7375
"entrypoint",
@@ -81,9 +83,7 @@
8183
]
8284

8385

84-
def _parse_args(args):
85-
import argparse
86-
86+
def _parse_args(args: list[str]) -> Namespace:
8787
parser = argparse.ArgumentParser(
8888
description=__doc__,
8989
)
@@ -97,10 +97,10 @@ def _parse_args(args):
9797
parser.add_argument("--worker-id", type=int, required=True)
9898
parser.add_argument("--num-workers", type=int, required=True)
9999
parser.add_argument("--nvdec", action="store_true")
100-
args = parser.parse_args(args)
101-
if args.trace:
102-
args.max_samples = 320
103-
return args
100+
ns = parser.parse_args(args)
101+
if ns.trace:
102+
ns.max_samples = 320
103+
return ns
104104

105105

106106
def source(
@@ -181,7 +181,7 @@ def decode_video_nvdec(
181181
device_index: int,
182182
width: int,
183183
height: int,
184-
):
184+
) -> Tensor:
185185
"""Decode video using NVDEC.
186186
187187
Args:
@@ -211,15 +211,17 @@ def decode_video_nvdec(
211211
return spdl.io.to_torch(buffer)[..., :3].permute(0, 2, 3, 1)
212212

213213

214-
def _get_decode_fn(device_index, use_nvdec, width=222, height=222):
214+
def _get_decode_fn(
215+
device_index: int, use_nvdec: bool, width: int = 222, height: int = 222
216+
) -> Callable[[str], Tensor]:
215217
if use_nvdec:
216218

217-
def _decode_func(src):
219+
def _decode_func(src: str) -> Tensor:
218220
return decode_video_nvdec(src, device_index, width, height)
219221

220222
else:
221223

222-
def _decode_func(src):
224+
def _decode_func(src: str) -> Tensor:
223225
return decode_video(src, width, height, device_index)
224226

225227
return _decode_func
@@ -250,7 +252,7 @@ def get_pipeline(
250252
)
251253

252254

253-
def _get_pipeline(args):
255+
def _get_pipeline(args: Namespace) -> Pipeline:
254256
src = source(
255257
input_flist=args.input_flist,
256258
prefix=args.prefix,
@@ -317,13 +319,13 @@ def benchmark(
317319
return PerfResult(elapsed, num_batches, num_frames)
318320

319321

320-
def worker_entrypoint(args: list[str]) -> PerfResult:
322+
def worker_entrypoint(args_: list[str]) -> PerfResult:
321323
"""Entrypoint for worker process. Load images to a GPU and measure its performance.
322324
323325
It builds a Pipeline object using :py:func:`get_pipeline` function and run it with
324326
:py:func:`benchmark` function.
325327
"""
326-
args = _parse_args(args)
328+
args = _parse_args(args_)
327329
_init(args.debug, args.worker_id)
328330

329331
_LG.info(args)
@@ -332,9 +334,9 @@ def worker_entrypoint(args: list[str]) -> PerfResult:
332334

333335
device = torch.device(f"cuda:{args.worker_id}")
334336

335-
ev = Event()
337+
ev: Event = Event()
336338

337-
def handler_stop_signals(_signum, _frame):
339+
def handler_stop_signals(_signum, _frame) -> None:
338340
ev.set()
339341

340342
signal.signal(signal.SIGTERM, handler_stop_signals)
@@ -350,29 +352,27 @@ def handler_stop_signals(_signum, _frame):
350352
return benchmark(pipeline.get_iterator(), ev)
351353

352354

353-
def _init_logging(debug=False, worker_id=None):
355+
def _init_logging(debug: bool = False, worker_id: int | None = None) -> None:
354356
fmt = "%(asctime)s [%(levelname)s] %(message)s"
355357
if worker_id is not None:
356358
fmt = f"[{worker_id}:%(thread)d] {fmt}"
357359
level = logging.DEBUG if debug else logging.INFO
358360
logging.basicConfig(format=fmt, level=level)
359361

360362

361-
def _init(debug, worker_id):
363+
def _init(debug: bool, worker_id: int) -> None:
362364
_init_logging(debug, worker_id)
363365

364366

365-
def _parse_process_args(args):
366-
import argparse
367-
367+
def _parse_process_args(args: list[str] | None) -> tuple[Namespace, list[str]]:
368368
parser = argparse.ArgumentParser(
369369
description=__doc__,
370370
)
371371
parser.add_argument("--num-workers", type=int, default=8)
372372
return parser.parse_known_args(args)
373373

374374

375-
def entrypoint(args: list[str] | None = None):
375+
def entrypoint(args: list[str] | None = None) -> None:
376376
"""CLI entrypoint. Launch the worker processes, each of which load videos and send them to GPU."""
377377
ns, args = _parse_process_args(args)
378378

0 commit comments

Comments
 (0)