Skip to content
This repository was archived by the owner on Mar 14, 2025. It is now read-only.

Commit 493aa38

Browse files
committed
add random seed to args and to generating random inputs
1 parent 7cbd8a0 commit 493aa38

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

inference/infer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,13 @@ def get_random_inputs(
9393
engine: trt.ICudaEngine,
9494
context: trt.IExecutionContext,
9595
input_binding_idxs: List[int],
96+
seed: int = 42,
9697
):
9798
# Input data for inference
9899
host_inputs = []
99100
print("Generating Random Inputs")
101+
print("\tUsing random seed: {}".format(seed))
102+
np.random.seed(seed)
100103
for binding_index in input_binding_idxs:
101104
# If input shape is fixed, we'll just use it
102105
input_shape = context.get_binding_shape(binding_index)
@@ -119,9 +122,10 @@ def get_random_inputs(
119122

120123
def main():
121124
parser = argparse.ArgumentParser()
122-
parser.add_argument(
123-
"-e", "--engine", required=True, type=str, help="Path to TensorRT engine file."
124-
)
125+
parser.add_argument("-e", "--engine", required=True, type=str,
126+
help="Path to TensorRT engine file.")
127+
parser.add_argument("-s", "--seed", type=int,
128+
help="Random seed for reproducibility.")
125129
args = parser.parse_args()
126130

127131
# Load a serialized engine into memory
@@ -142,7 +146,7 @@ def main():
142146
input_names = [engine.get_binding_name(binding_idx) for binding_idx in input_binding_idxs]
143147

144148
# Generate random inputs based on profile shapes
145-
host_inputs = get_random_inputs(engine, context, input_binding_idxs)
149+
host_inputs = get_random_inputs(engine, context, input_binding_idxs, seed=args.seed)
146150

147151
# Allocate device memory for inputs. This can be easily re-used if the
148152
# input shapes don't change

0 commit comments

Comments
 (0)