xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision 1b37a2a7cc4a4fb30c3e967db1c694c0a1013f51)
1 #include <petsc/private/deviceimpl.h>
2 #include <petsc/private/randomimpl.h>
3 #include <petscdevice_cuda.h>
4 #include <curand.h>
5 
6 typedef struct {
7   curandGenerator_t gen;
8 } PetscRandom_CURAND;
9 
10 static PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
11 {
12   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
13 
14   PetscFunctionBegin;
15   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
16   PetscFunctionReturn(PETSC_SUCCESS);
17 }
18 
19 PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);
20 
21 static PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
22 {
23   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
24   size_t              nn     = n < 0 ? (size_t)(-2 * n) : (size_t)n; /* handle complex case */
25 
26   PetscFunctionBegin;
27 #if defined(PETSC_USE_REAL_SINGLE)
28   PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
29 #else
30   PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
31 #endif
32   if (r->iset) PetscCall(PetscRandomCurandScale_Private(r, nn, val, (PetscBool)(n < 0)));
33   PetscFunctionReturn(PETSC_SUCCESS);
34 }
35 
36 static PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
37 {
38   PetscFunctionBegin;
39 #if defined(PETSC_USE_COMPLEX)
40   /* pass negative size to flag complex scaling (if needed) */
41   PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
42 #else
43   PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
44 #endif
45   PetscFunctionReturn(PETSC_SUCCESS);
46 }
47 
48 static PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
49 {
50   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
51 
52   PetscFunctionBegin;
53   PetscCallCURAND(curandDestroyGenerator(curand->gen));
54   PetscCall(PetscFree(r->data));
55   PetscFunctionReturn(PETSC_SUCCESS);
56 }
57 
/*
  Method table copied into r->ops[0] by PetscRandomCreate_CURAND().
  This backend provides seeding, the array-valued getters (getvalues,
  getvaluesreal), and destruction; the single-value getters (getvalue,
  getvaluereal) are left NULL.
  NOTE(review): PetscDesignatedInitializer() may expand to positional
  initializers in some build modes (e.g. C++), so entries should stay in the
  field order of struct _PetscRandomOps — confirm against randomimpl.h before
  reordering or removing entries.
*/
static struct _PetscRandomOps PetscRandomOps_Values = {
  PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
  PetscDesignatedInitializer(getvalue, NULL),
  PetscDesignatedInitializer(getvaluereal, NULL),
  PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
  PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
  PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
};
66 
67 /*MC
   PETSCCURAND - access to the NVIDIA cuRAND random number generator through a `PetscRandom` object
69 
70   Level: beginner
71 
72   Note:
73   This random number generator is available when PETSc is configured with ``./configure --with-cuda=1``
74 
75 .seealso: `PetscRandomCreate()`, `PetscRandomSetType()`, `PetscRandomType`
76 M*/
77 
78 PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
79 {
80   PetscRandom_CURAND *curand;
81 
82   PetscFunctionBegin;
83   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
84   PetscCall(PetscNew(&curand));
85   PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
86   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
87   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
88   r->ops[0] = PetscRandomOps_Values;
89   PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
90   r->data = curand;
91   PetscCall(PetscRandomSeed_CURAND(r));
92   PetscFunctionReturn(PETSC_SUCCESS);
93 }
94