#include <petsc/private/randomimpl.h>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/tuple.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <exception>

/*
  Device functors used to map raw generator output x into the user interval: x -> x * width + low.

  The functors deliberately do not derive from thrust::unary_function: that helper is deprecated in
  CCCL 2.x (removed in 3.0) and thrust::transform does not require it.
*/

#if defined(PETSC_USE_COMPLEX)
/*
  Affine scaling for complex values viewed as an array of reals.

  Input is a (value, position) tuple, produced by zipping the data with a counting iterator.
  Entries at even positions are scaled with the real parts of (low, width); entries at odd
  positions with the imaginary parts.
  NOTE(review): this presumes the complex numbers are stored interleaved as (re, im) pairs -- confirm
  against the caller that fills the buffer.
*/
struct complexscalelw {
  PetscReal rl, rw; /* real parts of low and width */
  PetscReal il, iw; /* imaginary parts of low and width */

  /* Initializers listed in member declaration order (rl, rw, il, iw) */
  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  /* thrust::get<N>(t) is used instead of t.get<N>(): the member form no longer exists once
     thrust::tuple became an alias of cuda::std::tuple (CCCL 2.x). */
  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal v = thrust::get<0>(x);
    return (thrust::get<1>(x) % 2) ? v * iw + il : v * rw + rl;
  }
};
#endif

/* Affine scaling for real values: x -> x * w + l */
struct realscalelw {
  PetscReal l, w; /* low end of the interval and its width */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};

/*
  PetscRandomCurandScale_Private - Scales n reals in place on the device from the generator's raw
  output to the interval described by r->low and r->width.

  Input Parameters:
+ r     - the random context; nothing is done when r->iset is false
. n     - number of PetscReal entries in val
. val   - pointer to the data, treated as device memory (it is wrapped with thrust::device_pointer_cast)
- isneg - selects the complex scaling path. NOTE(review): presumably set by the caller when complex
          values were requested through a negative count -- confirm against the caller.

  Notes:
  In a real (non-complex) build, requesting the complex path is an internal error (PETSC_ERR_PLIB).
  Exceptions thrown by Thrust are caught and returned as PETSC_ERR_LIB, so they do not propagate out
  of this error-code-returning routine.
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(0); /* no interval configured: leave the values unscaled */
  {
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);

    if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
      try {
        /* Pair each value with its index so the functor can tell real from imaginary parts */
        auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));
        thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
      } catch (const std::exception &e) {
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", e.what());
      }
#else
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
#endif
    } else {
      const PetscReal rl = PetscRealPart(r->low);
      const PetscReal rw = PetscRealPart(r->width);

      try {
        thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
      } catch (const std::exception &e) {
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", e.what());
      }
    }
  }
  PetscFunctionReturn(0);
}