xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision 9371c9d470a9602b6d10a8bf50c9b2280a79e45a)
1 #include <petsc/private/deviceimpl.h>
2 #include <petsc/private/randomimpl.h>
3 #include <curand.h>
4 
5 typedef struct {
6   curandGenerator_t gen;
7 } PetscRandom_CURAND;
8 
9 PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r) {
10   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
11 
12   PetscFunctionBegin;
13   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
14   PetscFunctionReturn(0);
15 }
16 
17 PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);
18 
19 PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val) {
20   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
21   size_t              nn     = n < 0 ? (size_t)(-2 * n) : n; /* handle complex case */
22 
23   PetscFunctionBegin;
24 #if defined(PETSC_USE_REAL_SINGLE)
25   PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
26 #else
27   PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
28 #endif
29   if (r->iset) { PetscCall(PetscRandomCurandScale_Private(r, nn, val, (PetscBool)(n < 0))); }
30   PetscFunctionReturn(0);
31 }
32 
33 PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val) {
34   PetscFunctionBegin;
35 #if defined(PETSC_USE_COMPLEX)
36   /* pass negative size to flag complex scaling (if needed) */
37   PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
38 #else
39   PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
40 #endif
41   PetscFunctionReturn(0);
42 }
43 
44 PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r) {
45   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
46 
47   PetscFunctionBegin;
48   PetscCallCURAND(curandDestroyGenerator(curand->gen));
49   PetscCall(PetscFree(r->data));
50   PetscFunctionReturn(0);
51 }
52 
53 static struct _PetscRandomOps PetscRandomOps_Values = {
54   PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
55   PetscDesignatedInitializer(getvalue, NULL),
56   PetscDesignatedInitializer(getvaluereal, NULL),
57   PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
58   PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
59   PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
60 };
61 
62 /*MC
   PETSCCURAND - access to the NVIDIA cuRAND (CUDA) random number generator
64 
  PETSc must be configured with ./configure --with-cuda to use this random number generator.
66 
67   Level: beginner
68 
69 .seealso: `PetscRandomCreate()`, `PetscRandomSetType()`
70 M*/
71 
72 PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r) {
73   PetscRandom_CURAND *curand;
74 
75   PetscFunctionBegin;
76   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
77   PetscCall(PetscNewLog(r, &curand));
78   PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
79   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
80   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
81   PetscCall(PetscMemcpy(r->ops, &PetscRandomOps_Values, sizeof(PetscRandomOps_Values)));
82   PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
83   r->data = curand;
84   r->seed = 1234ULL; /* taken from example */
85   PetscCall(PetscRandomSeed_CURAND(r));
86   PetscFunctionReturn(0);
87 }
88