#include <petsc/private/randomimpl.h>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>
#include <exception>

/*
  Note: the functors below intentionally do not derive from thrust::unary_function.
  That base class is deprecated in recent CCCL/Thrust releases (and removed in CCCL 3.0);
  thrust::transform deduces the result type from operator() and does not need it.
*/

#if defined(PETSC_USE_COMPLEX)
/*
  complexscalelw - element-wise affine map on a (value, index) tuple.

  For an even index the value is mapped to value * rw + rl (real parts of low/width);
  for an odd index it is mapped to value * iw + il (imaginary parts of low/width).
  NOTE(review): this presumably assumes the real array holds interleaved (re, im) pairs
  of complex numbers -- confirm against the caller.
*/
struct complexscalelw {
  PetscReal rl, rw; /* real parts of low and width */
  PetscReal il, iw; /* imaginary parts of low and width */

  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal v = thrust::get<0>(x);
    /* odd positions use the imaginary-part parameters, even positions the real-part ones */
    return thrust::get<1>(x) % 2 ? v * iw + il : v * rw + rl;
  }
};
#endif

/*
  realscalelw - element-wise affine map x -> x * w + l.
*/
struct realscalelw {
  PetscReal l, w; /* offset (low) and scale (width) */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};

/*
  PetscRandomCurandScale_Private - Rescales, in place, n random reals stored in device memory
  using the interval parameters (r->low, r->width) of the PetscRandom object.

  Input Parameters:
+ r     - the random context; if r->iset is false the values are left untouched
. n     - number of PetscReal entries in val
. val   - device pointer to the values, overwritten with val[i] * width + low
- isneg - if true, val is scaled element-wise with the real part of low/width at even indices
          and the imaginary part at odd indices (only valid when PETSc is built with complex
          scalars; otherwise an error is raised). NOTE(review): the flag name is historical,
          confirm its meaning against the caller.

  Notes:
  Thrust signals CUDA failures by throwing (e.g. thrust::system_error); such exceptions are
  caught here and converted into a PETSc error so that they never cross the C-style
  error-code interface.
  NOTE(review): presumably the input values are uniform samples in [0, 1) produced by cuRAND
  -- confirm against the caller.
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(PETSC_SUCCESS);
  try {
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
    if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
      /* pair each value with its position so the functor can tell real slots from imaginary ones */
      auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));
      thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
#else
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
#endif
    } else {
      const PetscReal rl = PetscRealPart(r->low);
      const PetscReal rw = PetscRealPart(r->width);
      thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
    }
  } catch (const std::exception &e) {
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error while scaling random values: %s", e.what());
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}