xref: /petsc/src/sys/classes/random/impls/curand/curand2.cu (revision 58d68138c660dfb4e9f5b03334792cd4f2ffd7cc)
1 #include <petsc/private/randomimpl.h>
2 #include <thrust/transform.h>
3 #include <thrust/device_ptr.h>
4 #include <thrust/iterator/counting_iterator.h>
5 
#if defined(PETSC_USE_COMPLEX)
/*
  complexscalelw - element-wise affine map used for complex-valued random numbers.

  The functor receives (value, index) pairs. Entries with an odd index are mapped to
  value * Im(width) + Im(low); entries with an even index are mapped to
  value * Re(width) + Re(low).
  NOTE(review): this presumes the buffer stores complex numbers as interleaved
  (real, imaginary) PetscReal pairs -- confirm against the caller.

  No longer derives from thrust::unary_function: that helper is deprecated in CCCL 2.x
  and removed in CCCL 3.x, and thrust::transform does not need it.
*/
struct complexscalelw {
  PetscReal rl, rw; /* real part of low / width */
  PetscReal il, iw; /* imaginary part of low / width */

  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  /* const so Thrust may invoke the functor through a const reference; thrust::get<N>(x)
     is used because newer thrust::tuple (an alias of cuda::std::tuple) has no member get */
  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal v   = thrust::get<0>(x);
    const size_t    idx = thrust::get<1>(x);
    return (idx % 2) ? v * iw + il : v * rw + rl;
  }
};
#endif
21 
/*
  realscalelw - element-wise affine map x -> x * w + l applied to real random values.

  No longer derives from thrust::unary_function: that helper is deprecated in CCCL 2.x
  and removed in CCCL 3.x, and thrust::transform does not need it.
*/
struct realscalelw {
  PetscReal l; /* offset (lower end of the interval) */
  PetscReal w; /* scale factor (interval width) */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  /* const so Thrust may invoke the functor through a const reference */
  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};
29 
/*
  PetscRandomCurandScale_Private - rescale n device-resident PetscReal values in place
  according to the interval stored in the PetscRandom object.

  Input Parameters:
+ r     - the random context; nothing is done unless r->iset is true
. n     - number of PetscReal entries pointed to by val
. val   - device pointer to the values, overwritten with the scaled result
- isneg - when true the values are treated as the complex case: with PETSC_USE_COMPLEX,
          odd/even indices are scaled by the imaginary/real parts of r->low and r->width;
          without complex support this is an error (PETSC_ERR_PLIB)
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (r->iset) {
    thrust::device_ptr<PetscReal> dptr = thrust::device_pointer_cast(val);

    if (!isneg) {
      /* real case: x -> x * Re(width) + Re(low) */
      const PetscReal lo = PetscRealPart(r->low);
      const PetscReal wd = PetscRealPart(r->width);
      thrust::transform(dptr, dptr + n, dptr, realscalelw(lo, wd));
    } else {
#if defined(PETSC_USE_COMPLEX)
      /* pair each value with its index so the functor can pick real vs. imaginary scaling */
      auto first = thrust::make_zip_iterator(thrust::make_tuple(dptr, thrust::counting_iterator<size_t>(0)));
      auto last  = first + n;
      thrust::transform(first, last, dptr, complexscalelw(r->low, r->width));
#else
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
#endif
    }
  }
  PetscFunctionReturn(0);
}
49