@@ -93,10 +93,13 @@ def get_random_inputs(
9393 engine : trt .ICudaEngine ,
9494 context : trt .IExecutionContext ,
9595 input_binding_idxs : List [int ],
96+ seed : int = 42 ,
9697):
9798 # Input data for inference
9899 host_inputs = []
99100 print ("Generating Random Inputs" )
101+ print ("\t Using random seed: {}" .format (seed ))
102+ np .random .seed (seed )
100103 for binding_index in input_binding_idxs :
101104 # If input shape is fixed, we'll just use it
102105 input_shape = context .get_binding_shape (binding_index )
@@ -119,9 +122,10 @@ def get_random_inputs(
119122
120123def main ():
121124 parser = argparse .ArgumentParser ()
122- parser .add_argument (
123- "-e" , "--engine" , required = True , type = str , help = "Path to TensorRT engine file."
124- )
125+ parser .add_argument ("-e" , "--engine" , required = True , type = str ,
126+ help = "Path to TensorRT engine file." )
127+ parser .add_argument ("-s" , "--seed" , type = int ,
128+ help = "Random seed for reproducibility." )
125129 args = parser .parse_args ()
126130
127131 # Load a serialized engine into memory
@@ -142,7 +146,7 @@ def main():
142146 input_names = [engine .get_binding_name (binding_idx ) for binding_idx in input_binding_idxs ]
143147
144148 # Generate random inputs based on profile shapes
145- host_inputs = get_random_inputs (engine , context , input_binding_idxs )
149+ host_inputs = get_random_inputs (engine , context , input_binding_idxs , seed = args . seed )
146150
147151 # Allocate device memory for inputs. This can be easily re-used if the
148152 # input shapes don't change
0 commit comments