xref: /petsc/src/sys/classes/random/impls/curand/curand2.cu (revision 586b72174521427559d8855856d614fb70566f6f)
1 #include <petsc/private/randomimpl.h>
2 #include <thrust/transform.h>
3 #include <thrust/device_ptr.h>
4 #include <thrust/iterator/counting_iterator.h>
5 
#if defined(PETSC_USE_COMPLEX)
/*
  complexscalelw - Thrust functor that rescales PetscReal entries of an array that holds complex
  values as (real, imaginary) PetscReal pairs.

  Each input is a (sample, index) tuple, where index is the entry's position in the array.
  Even indices are treated as real parts:      sample * Re(width) + Re(low)
  Odd indices are treated as imaginary parts:  sample * Im(width) + Im(low)

  NOTE(review): the samples are presumably uniform draws produced by cuRAND; this functor itself
  only applies the affine map above.
*/
struct complexscalelw
  #if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) // unary_function is deprecated in newer Thrust, so only derive from it on older toolkits
  :
  public thrust::unary_function<thrust::tuple<PetscReal, size_t>, PetscReal>
  #endif
{
  PetscReal rl, rw; /* real parts of low and width */
  PetscReal il, iw; /* imaginary parts of low and width */

  /* Initializers are listed in member declaration order (rl, rw, il, iw) */
  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  /* const: the functor's state is only read while Thrust applies it */
  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal sample = thrust::get<0>(x);
    const size_t    index  = thrust::get<1>(x);

    /* odd index -> imaginary component, even index -> real component */
    return (index % 2) ? sample * iw + il : sample * rw + rl;
  }
};
#endif
27 
/*
  realscalelw - Thrust functor applying the affine map x -> x * w + l to each PetscReal entry,
  i.e. scaling by a width w and shifting by an offset l.
*/
struct realscalelw
#if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) // To suppress the warning "thrust::THRUST_200700_860_NS::unary_function is deprecated"
  :
  public thrust::unary_function<PetscReal, PetscReal>
#endif
{
  PetscReal l, w; /* offset (low) and scale (width) */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  /* const: the functor's state is only read while Thrust applies it */
  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};
40 
/*
  PetscRandomCurandScale_Private - Rescales, in place, n device-resident PetscReal samples using the
  interval stored in the random context (r->low, r->width).

  Input Parameters:
+ r     - the random context; nothing is done unless an interval has been set (r->iset)
. n     - number of PetscReal entries in val
. val   - device pointer to the samples; overwritten with the scaled values
- isneg - PETSC_TRUE when val holds complex values as (real, imaginary) PetscReal pairs, which
          must be scaled component-wise

  Notes:
  Real samples become val[i] * Re(width) + Re(low). In the complex case, even entries use the
  real parts of width/low and odd entries use the imaginary parts.

  In builds with real scalars the complex path is unavailable, and requesting it is an
  internal library error.

  The transforms are issued with Thrust's default execution policy (no explicit stream).
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(PETSC_SUCCESS); /* no interval set: leave the samples untouched */
  if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
    /* Pair each entry with its index so the functor can distinguish real (even) from imaginary (odd) parts */
    auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));
    thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
#else
    /* n is a size_t and never negative; report the actual problem instead of a "negative size" */
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Complex scaling requested for %" PetscInt_FMT " entries, but PETSc was configured with real scalars", (PetscInt)n);
#endif
  } else {
    const PetscReal               rl   = PetscRealPart(r->low);
    const PetscReal               rw   = PetscRealPart(r->width);
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
    thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
61