@@ -20,150 +20,162 @@ namespace libcloudphxx
2020 class rng
2121 {
2222#if !defined(__NVCC__)
23- // serial version using C++11's <random>
24- using engine_t = std::mt19937;
23+ // serial version using C++11's <random>
24+ using engine_t = std::mt19937;
2525 using dist_u01_t = std::uniform_real_distribution<real_t >;
2626 using dist_normal01_t = std::normal_distribution<real_t >;
2727 using dist_un_t = std::uniform_int_distribution<unsigned int >;
28- engine_t engine;
29- dist_u01_t dist_u01;
30- dist_normal01_t dist_normal01;
31- dist_un_t dist_un;
28+ engine_t engine;
29+ dist_u01_t dist_u01;
30+ dist_normal01_t dist_normal01;
31+ dist_un_t dist_un;
3232
33- struct fnctr_u01
34- {
33+ struct fnctr_u01
34+ {
3535 engine_t &engine;
3636 dist_u01_t &dist_u01;
37- real_t operator ()() { return dist_u01 (engine); }
38- };
37+ real_t operator ()() { return dist_u01 (engine); }
38+ };
3939
40- struct fnctr_normal01
41- {
40+ struct fnctr_normal01
41+ {
4242 engine_t &engine;
4343 dist_normal01_t &dist_normal01;
44- real_t operator ()() { return dist_normal01 (engine); }
45- };
44+ real_t operator ()() { return dist_normal01 (engine); }
45+ };
4646
47- struct fnctr_un
48- {
47+ struct fnctr_un
48+ {
4949 engine_t &engine;
5050 dist_un_t &dist_un;
51- real_t operator ()() { return dist_un (engine); }
52- };
51+ real_t operator ()() { return dist_un (engine); }
52+ };
5353
54- public:
54+ public:
5555
5656 // ctor
5757 rng (int seed) : engine(seed), dist_u01(0 ,1 ), dist_normal01(0 ,1 ), dist_un(0 , std::numeric_limits<unsigned int >::max()) {}
5858
59- void generate_n (
60- thrust_device::vector<real_t > &u01,
61- const thrust_size_t n
62- ) {
59+ void reseed (int seed)
60+ {
61+ engine.seed (seed);
62+ }
63+
64+ void generate_n (
65+ thrust_device::vector<real_t > &u01,
66+ const thrust_size_t n
67+ ) {
6368 // note: generate_n copies the third argument!!!
64- std::generate_n (u01.begin (), n, fnctr_u01 ({engine, dist_u01}));
65- }
69+ std::generate_n (u01.begin (), n, fnctr_u01 ({engine, dist_u01}));
70+ }
6671
67- void generate_normal_n (
68- thrust_device::vector<real_t > &normal01,
69- const thrust_size_t n
70- ) {
72+ void generate_normal_n (
73+ thrust_device::vector<real_t > &normal01,
74+ const thrust_size_t n
75+ ) {
7176 // note: generate_n copies the third argument!!!
72- std::generate_n (normal01.begin (), n, fnctr_normal01 ({engine, dist_normal01}));
73- }
77+ std::generate_n (normal01.begin (), n, fnctr_normal01 ({engine, dist_normal01}));
78+ }
7479
75- void generate_n (
76- thrust_device::vector<unsigned int > &un,
77- const thrust_size_t n
78- ) {
80+ void generate_n (
81+ thrust_device::vector<unsigned int > &un,
82+ const thrust_size_t n
83+ ) {
7984 // note: generate_n copies the third argument!!!
80- std::generate_n (un.begin (), n, fnctr_un ({engine, dist_un}));
81- }
85+ std::generate_n (un.begin (), n, fnctr_un ({engine, dist_un}));
86+ }
8287#endif
8388 };
8489
8590 template <typename real_t >
8691 class rng <real_t , CUDA>
8792 {
8893#if defined(__NVCC__)
89- // CUDA parallel version using curand
94+ // CUDA parallel version using curand
9095
91- // private member fields
92- curandGenerator_t gen;
93-
94- public:
96+ // private member fields
97+ curandGenerator_t gen;
98+
99+ public:
95100
96- rng (int seed)
97- {
101+ rng (int seed)
102+ {
98103 {
99- int status = curandCreateGenerator (&gen, CURAND_RNG_PSEUDO_MTGP32);
100- assert (status == CURAND_STATUS_SUCCESS /* && "curandCreateGenerator failed"*/ );
101- _unused (status);
104+ int status = curandCreateGenerator (&gen, CURAND_RNG_PSEUDO_MTGP32);
105+ assert (status == CURAND_STATUS_SUCCESS /* && "curandCreateGenerator failed"*/ );
106+ _unused (status);
102107 }
103108 {
104- int status = curandSetPseudoRandomGeneratorSeed (gen, seed);
105- assert (status == CURAND_STATUS_SUCCESS /* && "curandSetPseudoRandomGeneratorSeed failed"*/ );
109+ int status = curandSetPseudoRandomGeneratorSeed (gen, seed);
110+ assert (status == CURAND_STATUS_SUCCESS /* && "curandSetPseudoRandomGeneratorSeed failed"*/ );
106111 _unused (status);
107- }
112+ }
113+ }
114+
115+ void reseed (int seed)
116+ {
117+ int status = curandSetPseudoRandomGeneratorSeed (gen, seed);
118+ assert (status == CURAND_STATUS_SUCCESS /* && "curandSetPseudoRandomGeneratorSeed failed"*/ );
119+ _unused (status);
108120 }
109121
110- ~rng ()
111- {
112- int status = curandDestroyGenerator (gen);
113- assert (status == CURAND_STATUS_SUCCESS /* && "curandDestroyGenerator failed"*/ );
122+ ~rng ()
123+ {
124+ int status = curandDestroyGenerator (gen);
125+ assert (status == CURAND_STATUS_SUCCESS /* && "curandDestroyGenerator failed"*/ );
114126 _unused (status);
115- }
116-
117- void generate_n (
118- thrust_device::vector<float > &v,
119- const thrust_size_t n
120- )
121- {
122- int status = curandGenerateUniform (gen, thrust::raw_pointer_cast (v.data ()), n);
127+ }
128+
129+ void generate_n (
130+ thrust_device::vector<float > &v,
131+ const thrust_size_t n
132+ )
133+ {
134+ int status = curandGenerateUniform (gen, thrust::raw_pointer_cast (v.data ()), n);
123135 assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
124136 _unused (status);
125137
126- }
138+ }
127139
128- void generate_n (
129- thrust_device::vector<double > &v,
130- const thrust_size_t n
131- )
132- {
133- int status = curandGenerateUniformDouble (gen, thrust::raw_pointer_cast (v.data ()), n);
134- assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
140+ void generate_n (
141+ thrust_device::vector<double > &v,
142+ const thrust_size_t n
143+ )
144+ {
145+ int status = curandGenerateUniformDouble (gen, thrust::raw_pointer_cast (v.data ()), n);
146+ assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
135147 _unused (status);
136- }
137-
138- void generate_normal_n (
139- thrust_device::vector<float > &v,
140- const thrust_size_t n
141- )
142- {
143- int status = curandGenerateNormal (gen, thrust::raw_pointer_cast (v.data ()), n, float (0 ), float (1 ));
144- assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
148+ }
149+
150+ void generate_normal_n (
151+ thrust_device::vector<float > &v,
152+ const thrust_size_t n
153+ )
154+ {
155+ int status = curandGenerateNormal (gen, thrust::raw_pointer_cast (v.data ()), n, float (0 ), float (1 ));
156+ assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
145157 _unused (status);
146- }
147-
148- void generate_normal_n (
149- thrust_device::vector<double > &v,
150- const thrust_size_t n
151- )
152- {
153- int status = curandGenerateNormalDouble (gen, thrust::raw_pointer_cast (v.data ()), n, double (0 ), double (1 ));
154- assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
158+ }
159+
160+ void generate_normal_n (
161+ thrust_device::vector<double > &v,
162+ const thrust_size_t n
163+ )
164+ {
165+ int status = curandGenerateNormalDouble (gen, thrust::raw_pointer_cast (v.data ()), n, double (0 ), double (1 ));
166+ assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
155167 _unused (status);
156- }
157-
158- void generate_n (
159- thrust_device::vector<unsigned int > &v,
160- const thrust_size_t n
161- )
162- {
163- int status = curandGenerate (gen, thrust::raw_pointer_cast (v.data ()), n);
164- assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
168+ }
169+
170+ void generate_n (
171+ thrust_device::vector<unsigned int > &v,
172+ const thrust_size_t n
173+ )
174+ {
175+ int status = curandGenerate (gen, thrust::raw_pointer_cast (v.data ()), n);
176+ assert (status == CURAND_STATUS_SUCCESS /* && "curandGenerateUniform failed"*/ );
165177 _unused (status);
166- }
178+ }
167179#endif
168180 };
169181 };
0 commit comments