xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision e0f5bfbec699682fa3e8b8532b1176849ea4e12a)
1 #include <petsc/private/deviceimpl.h>
2 #include <petsc/private/randomimpl.h>
3 #include <petscdevice_cuda.h>
4 #include <curand.h>
5 
6 typedef struct {
7   curandGenerator_t gen;
8 } PetscRandom_CURAND;
9 
10 PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r) {
11   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
12 
13   PetscFunctionBegin;
14   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
15   PetscFunctionReturn(0);
16 }
17 
18 PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);
19 
20 PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val) {
21   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
22   size_t              nn     = n < 0 ? (size_t)(-2 * n) : n; /* handle complex case */
23 
24   PetscFunctionBegin;
25 #if defined(PETSC_USE_REAL_SINGLE)
26   PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
27 #else
28   PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
29 #endif
30   if (r->iset) PetscCall(PetscRandomCurandScale_Private(r, nn, val, (PetscBool)(n < 0)));
31   PetscFunctionReturn(0);
32 }
33 
34 PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val) {
35   PetscFunctionBegin;
36 #if defined(PETSC_USE_COMPLEX)
37   /* pass negative size to flag complex scaling (if needed) */
38   PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
39 #else
40   PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
41 #endif
42   PetscFunctionReturn(0);
43 }
44 
45 PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r) {
46   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
47 
48   PetscFunctionBegin;
49   PetscCallCURAND(curandDestroyGenerator(curand->gen));
50   PetscCall(PetscFree(r->data));
51   PetscFunctionReturn(0);
52 }
53 
54 static struct _PetscRandomOps PetscRandomOps_Values = {
55   PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
56   PetscDesignatedInitializer(getvalue, NULL),
57   PetscDesignatedInitializer(getvaluereal, NULL),
58   PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
59   PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
60   PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
61 };
62 
63 /*MC
64    PETSCCURAND - access to the CUDA random number generator from a `PetscRandom` object
65 
  PETSc must be configured with the option `--with-cuda` to use this random number generator.
67 
68   Level: beginner
69 
70 .seealso: `PetscRandomCreate()`, `PetscRandomSetType()`, `PetscRandomType`
71 M*/
72 
73 PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r) {
74   PetscRandom_CURAND *curand;
75 
76   PetscFunctionBegin;
77   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
78   PetscCall(PetscNew(&curand));
79   PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
80   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
81   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
82   PetscCall(PetscMemcpy(r->ops, &PetscRandomOps_Values, sizeof(PetscRandomOps_Values)));
83   PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
84   r->data = curand;
85   r->seed = 1234ULL; /* taken from example */
86   PetscCall(PetscRandomSeed_CURAND(r));
87   PetscFunctionReturn(0);
88 }
89