xref: /petsc/src/sys/classes/random/impls/curand/curand2.cu (revision df4cd43f92eaa320656440c40edb1046daee8f75)
#include <petsc/private/randomimpl.h>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/tuple.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <exception>
5 
#if defined(PETSC_USE_COMPLEX)
/*
  complexscalelw - Thrust functor that rescales one PetscReal component of an
  array whose entries alternate between two scalings:

    even index:  v -> v * Re(width) + Re(low)
    odd  index:  v -> v * Im(width) + Im(low)

  The argument is a (value, index) tuple; the index only selects which pair of
  coefficients applies. Presumably the array is a PetscScalar buffer viewed as
  interleaved real/imaginary PetscReals -- confirm against callers.

  Deliberately not derived from thrust::unary_function: that base is deprecated
  in recent Thrust/CCCL releases and is not required by thrust::transform.
*/
struct complexscalelw {
  PetscReal rl, rw; /* real parts of low and width */
  PetscReal il, iw; /* imaginary parts of low and width */

  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal v   = thrust::get<0>(x);
    const size_t    idx = thrust::get<1>(x);

    /* odd slots use the imaginary coefficients, even slots the real ones */
    return (idx % 2) ? v * iw + il : v * rw + rl;
  }
};
#endif
22 
/*
  realscalelw - Thrust functor applying the affine map x -> x * w + l to a
  single PetscReal (w = width, l = low of the target interval).

  Deliberately not derived from thrust::unary_function: that base is deprecated
  in recent Thrust/CCCL releases and is not required by thrust::transform.
*/
struct realscalelw {
  PetscReal l, w; /* offset (low) and scale factor (width) */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};
30 
/*
  PetscRandomCurandScale_Private - Rescales, in place on the device, n PetscReal
  values stored at val using the interval held in r (v -> v * width + low).

  Input Parameters:
+ r     - the random context; nothing is done unless r->iset is set
          (presumably set when an interval was configured -- confirm)
. n     - number of PetscReal entries at val
. val   - device pointer to the values; overwritten with the scaled values
- isneg - when PETSC_TRUE, even-indexed entries are scaled with the real parts of
          low/width and odd-indexed entries with the imaginary parts; only
          supported when PETSc is built with complex scalars

  Notes:
  Thrust signals CUDA failures by throwing C++ exceptions (thrust::system_error).
  They are caught here and turned into a PETSc error code so that no exception
  escapes this error-code-returning routine.
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(PETSC_SUCCESS);
  try {
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);

    if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
      /* pair each entry with its position so the functor can pick real vs. imaginary coefficients */
      auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));

      thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
#else
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
#endif
    } else {
      thrust::transform(pval, pval + n, pval, realscalelw(PetscRealPart(r->low), PetscRealPart(r->width)));
    }
  } catch (const std::exception &e) {
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", e.what());
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
51