#include <petsc/private/deviceimpl.h>
#include <petsc/private/randomimpl.h>
#include <petscdevice_cuda.h>

/* Implementation data stored in r->data for the PETSCCURAND type */
typedef struct {
  curandGenerator_t gen; /* cuRAND generator handle; created in PetscRandomCreate_CURAND(), destroyed in PetscRandomDestroy_CURAND() */
} PetscRandom_CURAND;

/*
  PetscRandomSeed_CURAND - pushes the seed stored in r->seed into the cuRAND generator

  Input Parameter:
. r - the random object; r->data must point to a PetscRandom_CURAND
*/
static PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
{
  PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;

  PetscFunctionBegin;
  PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Defined elsewhere; called below on the generated values when r->iset is set */
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);

/*
  PetscRandomGetValuesReal_CURAND - fills an array with uniformly distributed real values using cuRAND

  Input Parameters:
+ r - the random object
- n - number of values; a negative n flags a request for -n complex values,
      in which case 2 * (-n) reals are generated

  Output Parameter:
. val - array receiving the values; handed directly to the cuRAND generate call
        (NOTE(review): presumably must be device-accessible memory - confirm against callers)

  Note:
  When r->iset is set the generated values are passed to PetscRandomCurandScale_Private()
  together with a flag telling it whether the data are complex.
*/
static PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
{
  PetscRandom_CURAND *curand    = (PetscRandom_CURAND *)r->data;
  const PetscBool     iscomplex = (PetscBool)(n < 0);
  /* Widen to size_t BEFORE doubling: computing -2 * n in PetscInt overflows (undefined behavior)
     for 32-bit PetscInt once |n| exceeds 2^30. The negative n comes from negating a count, so -n itself is safe. */
  const size_t nn = iscomplex ? 2 * (size_t)(-n) : (size_t)n;

  PetscFunctionBegin;
#if defined(PETSC_USE_REAL_SINGLE)
  PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
#else
  PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
#endif
  if (r->iset) PetscCall(PetscRandomCurandScale_Private(r, nn, val, iscomplex));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  PetscRandomGetValues_CURAND - fills an array of PetscScalar with random values

  Input Parameters:
+ r - the random object
- n - number of scalars requested

  Output Parameter:
. val - array receiving the values

  Note:
  For complex builds the array is reinterpreted as PetscReal and the count is negated so that
  PetscRandomGetValuesReal_CURAND() generates two reals per scalar and flags complex scaling.
*/
static PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
{
  PetscFunctionBegin;
#if defined(PETSC_USE_COMPLEX)
  /* pass negative size to flag complex scaling (if needed) */
  PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
#else
  PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
#endif
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  PetscRandomDestroy_CURAND - destroys the cuRAND generator and frees the implementation data in r->data
*/
static PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
{
  PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;

  PetscFunctionBegin;
  PetscCallCURAND(curandDestroyGenerator(curand->gen));
  PetscCall(PetscFree(r->data));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Operations table copied into r->ops by PetscRandomCreate_CURAND(); the single-value entries are left NULL */
static struct _PetscRandomOps PetscRandomOps_Values = {
  PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
  PetscDesignatedInitializer(getvalue, NULL),
  PetscDesignatedInitializer(getvaluereal, NULL),
  PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
  PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
  PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
};

/*MC
   PETSCCURAND - access to the CUDA random number generator from a `PetscRandom` object

  Level: beginner

  Note:
  This random number generator is available when PETSc is configured with ``./configure --with-cuda=1``

.seealso: `PetscRandomCreate()`, `PetscRandomSetType()`, `PetscRandomType`
M*/

/*
  PetscRandomCreate_CURAND - sets up a PetscRandom object as type PETSCCURAND

  Initializes the CUDA device, creates a default pseudo-random cuRAND generator bound to the
  stream of the current PetscDeviceContext, installs the operations table, and seeds the
  generator from r->seed.
*/
PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
{
  PetscRandom_CURAND *curand;
  PetscDeviceContext  dctx;
  cudaStream_t       *stream;

  PetscFunctionBegin;
  PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
  PetscCall(PetscDeviceContextGetCurrentContextAssertType_Internal(&dctx, PETSC_DEVICE_CUDA));
  PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
  PetscCall(PetscNew(&curand));
  PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
  PetscCallCURAND(curandSetStream(curand->gen, *stream));
  /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
  PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
  r->ops[0] = PetscRandomOps_Values;
  PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
  /* r->data must be set before seeding: PetscRandomSeed_CURAND() reads the generator through r->data */
  r->data = curand;
  PetscCall(PetscRandomSeed_CURAND(r));
  PetscFunctionReturn(PETSC_SUCCESS);
}