#include <petsc/private/randomimpl.h>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>
#include <exception>

#if defined(PETSC_USE_COMPLEX)
/*
  Element-wise affine map for complex data stored as interleaved reals
  (re0, im0, re1, im1, ...). The input tuple is (value, position in the array):
  entries at even positions are mapped with the real parts of low/width,
  entries at odd positions with the imaginary parts.
*/
struct complexscalelw
#if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) // To suppress the warning "thrust::THRUST_200700_860_NS::unary_function is deprecated"
  : public thrust::unary_function<thrust::tuple<PetscReal, size_t>, PetscReal>
#endif
{
  PetscReal rl, rw; /* real parts of low and width */
  PetscReal il, iw; /* imaginary parts of low and width */

  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  /* const so the functor can also be invoked through a const reference by Thrust */
  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal v = thrust::get<0>(x);
    return (thrust::get<1>(x) % 2) ? v * iw + il : v * rw + rl;
  }
};
#endif

/* Element-wise affine map x -> x * w + l for real data. */
struct realscalelw
#if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) // To suppress the warning "thrust::THRUST_200700_860_NS::unary_function is deprecated"
  : public thrust::unary_function<PetscReal, PetscReal>
#endif
{
  PetscReal l, w;

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  /* const so the functor can also be invoked through a const reference by Thrust */
  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};

/*
  PetscRandomCurandScale_Private - Rescales, in place, n reals stored on the device
  to the interval configured on r: each entry x becomes x * width + low.

  Input Parameters:
+ r     - the random context; nothing is done unless an interval was set (r->iset)
. n     - number of PetscReal entries in val
. val   - device pointer to the entries; overwritten with the scaled values
- isneg - PETSC_TRUE for complex data stored as interleaved (real, imaginary) pairs:
          even entries are scaled with the real parts and odd entries with the
          imaginary parts of r->low and r->width. Only valid in complex builds;
          otherwise an error is raised.

  Notes:
  Thrust reports failures by throwing exceptions. They are converted into PETSc
  errors here, because this PETSC_INTERN routine is presumably called from C code
  (TODO confirm against callers), where an exception must not propagate.
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(PETSC_SUCCESS);
  if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
    /* pair every entry with its index so the functor can tell real from imaginary parts */
    auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));
    try {
      thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
    } catch (const std::exception &e) {
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", e.what());
    }
#else
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
#endif
  } else {
    const PetscReal               rl   = PetscRealPart(r->low);
    const PetscReal               rw   = PetscRealPart(r->width);
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
    try {
      thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
    } catch (const std::exception &e) {
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", e.what());
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}