Skip to content
This repository was archived by the owner on Mar 14, 2025. It is now read-only.

Commit 2c49b84

Browse files
committed
add SimpleCalibrator which should support dynamic shape int8 calibration
1 parent 8b98524 commit 2c49b84

File tree

2 files changed

+113
-15
lines changed

2 files changed

+113
-15
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import os
2+
import logging
3+
import numpy as np
4+
import tensorrt as trt
5+
import pycuda.driver as cuda
6+
import pycuda.autoinit
7+
8+
# Root logging is configured at import time so the calibrator's messages are
# emitted with a timestamped, level-tagged format.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    level=logging.DEBUG,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
)

# Module-scoped logger used by everything defined in this file.
logger = logging.getLogger(__name__)
12+
13+
class SimpleCalibrator(trt.IInt8EntropyCalibrator2):
    """INT8 entropy calibrator that feeds randomly generated float32 batches.

    Input shapes are taken from the builder config's calibration profile when
    one is set, and otherwise from the network inputs with every dynamic
    (negative) dimension replaced by 1. No real dataset is needed, which makes
    this calibrator usable for quickly building INT8 engines for networks with
    dynamic shapes.
    """

    def __init__(self, network, config, num_calibration_samples=1000,
                 cache_file="simple_calibration.cache"):
        """Store calibration settings; device memory is allocated lazily.

        Args:
            network: Network definition whose inputs will be calibrated.
            config: Builder config; its calibration profile (if any) is read
                via ``config.get_calibration_profile()``.
            num_calibration_samples: Number of random batches to serve before
                ``get_batch`` signals the end of calibration.
            cache_file: Path of the calibration cache to read and write.
        """
        super().__init__()

        # TODO: Not sure of difference between get_batch_size and what's returned in get_batch ?
        # Notes (observed by the original author):
        #     get_batch_size() is required to return non-null value
        #     get_batch_size() can return 0 or -1 with seemingly no consequence,
        #       with or without a calibration cache
        #     get_batch() seems to do the work, as long as get_batch_size doesn't throw an error
        # -1 is therefore used as the "not yet known" sentinel.
        self.batch_size = -1
        self.shapes = []
        self.device_inputs = None
        # An iterator (rather than a counter compared against the input names)
        # avoids having to pass the input names to the constructor.
        self.iterator = iter(range(num_calibration_samples))
        self.cache_file = cache_file
        self.network = network
        self.calib_profile = config.get_calibration_profile()

    def get_batch(self, input_names, p_str=None):
        """Upload one random batch per input and return the device pointers.

        Args:
            input_names: Names of the network inputs, as passed by TensorRT.
            p_str: Unused; accepted for interface compatibility.

        Returns:
            A list of integer device addresses (one per input), or ``None``
            once the configured number of calibration samples has been served.
        """
        # Only the sample iterator may end calibration; keeping the try body
        # minimal prevents an unrelated StopIteration from being swallowed.
        try:
            next(self.iterator)
        except StopIteration:
            return None

        if not self.shapes:
            self.set_shapes(input_names)

        if self.device_inputs is None:
            # Compute the byte size arithmetically instead of allocating a
            # throwaway host array just to read its ``nbytes``.
            itemsize = np.dtype(np.float32).itemsize
            self.device_inputs = [
                cuda.mem_alloc(int(np.prod(shape)) * itemsize)
                for shape in self.shapes
            ]

        if self.batch_size < 0 and self.shapes and len(self.shapes[0]) > 0:
            # Get batch size from first input in calibration shapes. Assumes
            # batch sizes are the same for every input.
            self.batch_size = int(self.shapes[0][0])

        for device_input, shape in zip(self.device_inputs, self.shapes):
            host_batch = np.random.random(shape).astype(np.float32)
            cuda.memcpy_htod(device_input, host_batch)

        return [int(d) for d in self.device_inputs]

    def get_batch_size(self):
        """Return the calibration batch size (-1 until the first batch is built)."""
        return self.batch_size

    def set_shapes(self, input_names):
        """Populate ``self.shapes`` with one concrete shape per input name.

        When a calibration profile is set, its optimum shape is used:
        per the TensorRT Python API, ``IOptimizationProfile.get_shape``
        returns the ``[min, opt, max]`` shapes, and calibration runs with the
        opt shape. Inputs the profile has no shapes for, and all inputs when
        no profile is set, fall back to the network's declared input shape.
        Any dynamic dimension is replaced with 1.

        Args:
            input_names: Names of the network inputs.
        """
        shapes = []
        # The network fallback assumes the order of input_names matches the
        # network input indices.
        for index, name in enumerate(input_names):
            shape = None
            if self.calib_profile:
                profile_shapes = self.calib_profile.get_shape(name)
                if len(profile_shapes) == 3:
                    shape = profile_shapes[1]

            if shape is None:
                shape = self.network.get_input(index).shape

            shapes.append(self._replace_dynamic_dims(name, shape))

        self.shapes = shapes

    @staticmethod
    def _replace_dynamic_dims(name, shape):
        """Return ``shape`` as a tuple of ints with negative dims replaced by 1."""
        static_shape = tuple(1 if dim < 0 else int(dim) for dim in shape)
        if any(dim < 0 for dim in shape):
            logger.warning("[{}] has dynamic shape: {}. Set to {} instead.".format(name, shape, static_shape))
        return static_shape

    def read_calibration_cache(self):
        """Return the cached calibration data, or ``None`` if no cache exists."""
        # If there is a cache, use it instead of calibrating again.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                logger.info("Using calibration cache to save time: {:}".format(self.cache_file))
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """Write the calibration data ``cache`` to ``self.cache_file``."""
        with open(self.cache_file, "wb") as f:
            logger.info("Caching calibration data for future use: {:}".format(self.cache_file))
            f.write(cache)

int8/calibration/onnx_to_tensorrt.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323

2424
import tensorrt as trt
2525

26-
from ImagenetCalibrator import ImagenetCalibrator, get_calibration_files, get_int8_calibrator # local module
27-
2826
TRT_LOGGER = trt.Logger()
2927
logging.basicConfig(level=logging.DEBUG,
3028
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -126,6 +124,7 @@ def main():
126124
parser.add_argument("--calibration-batch-size", help="(INT8 ONLY) The batch size to use during calibration.", type=int, default=32)
127125
parser.add_argument("--max-calibration-size", help="(INT8 ONLY) The max number of data to calibrate on from --calibration-data.", type=int, default=512)
128126
parser.add_argument("-p", "--preprocess_func", type=str, default=None, help="(INT8 ONLY) Function defined in 'processing.py' to use for pre-processing calibration data.")
127+
parser.add_argument("-s", "--simple", action="store_true", help="Use SimpleCalibrator with random data instead of ImagenetCalibrator for INT8 calibration.")
129128
args, _ = parser.parse_known_args()
130129

131130
# Adjust logging verbosity
@@ -169,19 +168,6 @@ def main():
169168
logger.info("Setting {}".format(builder_flag_map[flag]))
170169
config.set_flag(builder_flag_map[flag])
171170

172-
if args.fp16 and not builder.platform_has_fast_fp16:
173-
logger.warning("FP16 not supported on this platform.")
174-
175-
if args.int8 and not builder.platform_has_fast_int8:
176-
logger.warning("INT8 not supported on this platform.")
177-
178-
if args.int8:
179-
config.int8_calibrator = get_int8_calibrator(args.calibration_cache,
180-
args.calibration_data,
181-
args.max_calibration_size,
182-
args.preprocess_func,
183-
args.calibration_batch_size)
184-
185171
# Fill network attributes with information by parsing model
186172
with open(args.onnx, "rb") as f:
187173
if not parser.parse(f.read()):
@@ -202,6 +188,26 @@ def main():
202188
# Implicit Batch Network
203189
else:
204190
builder.max_batch_size = args.max_batch_size
191+
opt_profiles = []
192+
193+
# Precision flags
194+
if args.fp16 and not builder.platform_has_fast_fp16:
195+
logger.warning("FP16 not supported on this platform.")
196+
197+
if args.int8 and not builder.platform_has_fast_int8:
198+
logger.warning("INT8 not supported on this platform.")
199+
200+
if args.int8:
201+
if args.simple:
202+
from SimpleCalibrator import SimpleCalibrator # local module
203+
config.int8_calibrator = SimpleCalibrator(network, config)
204+
else:
205+
from ImagenetCalibrator import ImagenetCalibrator, get_int8_calibrator # local module
206+
config.int8_calibrator = get_int8_calibrator(args.calibration_cache,
207+
args.calibration_data,
208+
args.max_calibration_size,
209+
args.preprocess_func,
210+
args.calibration_batch_size)
205211

206212
logger.info("Building Engine...")
207213
with builder.build_engine(network, config) as engine, open(args.output, "wb") as f:

0 commit comments

Comments
 (0)