xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision d71ae5a4db6382e7f06317b8d368875286fe9008)
1 #include <petsc/private/deviceimpl.h>
2 #include <petsc/private/randomimpl.h>
3 #include <petscdevice_cuda.h>
4 #include <curand.h>
5 
6 typedef struct {
7   curandGenerator_t gen;
8 } PetscRandom_CURAND;
9 
10 PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
11 {
12   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
13 
14   PetscFunctionBegin;
15   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
16   PetscFunctionReturn(0);
17 }
18 
19 PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);
20 
21 PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
22 {
23   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
24   size_t              nn     = n < 0 ? (size_t)(-2 * n) : n; /* handle complex case */
25 
26   PetscFunctionBegin;
27 #if defined(PETSC_USE_REAL_SINGLE)
28   PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
29 #else
30   PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
31 #endif
32   if (r->iset) PetscCall(PetscRandomCurandScale_Private(r, nn, val, (PetscBool)(n < 0)));
33   PetscFunctionReturn(0);
34 }
35 
36 PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
37 {
38   PetscFunctionBegin;
39 #if defined(PETSC_USE_COMPLEX)
40   /* pass negative size to flag complex scaling (if needed) */
41   PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
42 #else
43   PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
44 #endif
45   PetscFunctionReturn(0);
46 }
47 
48 PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
49 {
50   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
51 
52   PetscFunctionBegin;
53   PetscCallCURAND(curandDestroyGenerator(curand->gen));
54   PetscCall(PetscFree(r->data));
55   PetscFunctionReturn(0);
56 }
57 
58 static struct _PetscRandomOps PetscRandomOps_Values = {
59   PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
60   PetscDesignatedInitializer(getvalue, NULL),
61   PetscDesignatedInitializer(getvaluereal, NULL),
62   PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
63   PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
64   PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
65 };
66 
67 /*MC
68    PETSCCURAND - access to the CUDA random number generator from a `PetscRandom` object
69 
  PETSc must be configured with the option --with-cuda to use this random number generator.
71 
72   Level: beginner
73 
74 .seealso: `PetscRandomCreate()`, `PetscRandomSetType()`, `PetscRandomType`
75 M*/
76 
77 PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
78 {
79   PetscRandom_CURAND *curand;
80 
81   PetscFunctionBegin;
82   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
83   PetscCall(PetscNew(&curand));
84   PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
85   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
86   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
87   PetscCall(PetscMemcpy(r->ops, &PetscRandomOps_Values, sizeof(PetscRandomOps_Values)));
88   PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
89   r->data = curand;
90   r->seed = 1234ULL; /* taken from example */
91   PetscCall(PetscRandomSeed_CURAND(r));
92   PetscFunctionReturn(0);
93 }
94