xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision cd871708d6ae82bd70cc1a9e2138f9b57839fe75)
1 #include <petsc/private/deviceimpl.h>
2 #include <petsc/private/randomimpl.h>
3 #include <petscdevice_cuda.h>
4 
5 typedef struct {
6   curandGenerator_t gen;
7 } PetscRandom_CURAND;
8 
PetscRandomSeed_CURAND(PetscRandom r)9 static PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
10 {
11   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
12 
13   PetscFunctionBegin;
14   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
15   PetscFunctionReturn(PETSC_SUCCESS);
16 }
17 
18 PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);
19 
PetscRandomGetValuesReal_CURAND(PetscRandom r,PetscInt n,PetscReal * val)20 static PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
21 {
22   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
23   size_t              nn     = n < 0 ? (size_t)(-2 * n) : (size_t)n; /* handle complex case */
24 
25   PetscFunctionBegin;
26 #if defined(PETSC_USE_REAL_SINGLE)
27   PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
28 #else
29   PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
30 #endif
31   if (r->iset) PetscCall(PetscRandomCurandScale_Private(r, nn, val, (PetscBool)(n < 0)));
32   PetscFunctionReturn(PETSC_SUCCESS);
33 }
34 
PetscRandomGetValues_CURAND(PetscRandom r,PetscInt n,PetscScalar * val)35 static PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
36 {
37   PetscFunctionBegin;
38 #if defined(PETSC_USE_COMPLEX)
39   /* pass negative size to flag complex scaling (if needed) */
40   PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
41 #else
42   PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
43 #endif
44   PetscFunctionReturn(PETSC_SUCCESS);
45 }
46 
PetscRandomDestroy_CURAND(PetscRandom r)47 static PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
48 {
49   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
50 
51   PetscFunctionBegin;
52   PetscCallCURAND(curandDestroyGenerator(curand->gen));
53   PetscCall(PetscFree(r->data));
54   PetscFunctionReturn(PETSC_SUCCESS);
55 }
56 
57 static struct _PetscRandomOps PetscRandomOps_Values = {
58   PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
59   PetscDesignatedInitializer(getvalue, NULL),
60   PetscDesignatedInitializer(getvaluereal, NULL),
61   PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
62   PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
63   PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
64 };
65 
/*MC
   PETSCCURAND - access to the NVIDIA cuRAND random number generator from a `PetscRandom` object

  Level: beginner

  Note:
  This random number generator is available when PETSc is configured with ``./configure --with-cuda=1``

.seealso: `PetscRandomCreate()`, `PetscRandomSetType()`, `PetscRandomType`
M*/
76 
PetscRandomCreate_CURAND(PetscRandom r)77 PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
78 {
79   PetscRandom_CURAND *curand;
80   PetscDeviceContext  dctx;
81   cudaStream_t       *stream;
82 
83   PetscFunctionBegin;
84   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
85   PetscCall(PetscDeviceContextGetCurrentContextAssertType_Internal(&dctx, PETSC_DEVICE_CUDA));
86   PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
87   PetscCall(PetscNew(&curand));
88   PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
89   PetscCallCURAND(curandSetStream(curand->gen, *stream));
90   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
91   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
92   r->ops[0] = PetscRandomOps_Values;
93   PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
94   r->data = curand;
95   PetscCall(PetscRandomSeed_CURAND(r));
96   PetscFunctionReturn(PETSC_SUCCESS);
97 }
98