xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision f97672e55eacc8688507b9471cd7ec2664d7f203)
1 #include <petsc/private/deviceimpl.h>
2 #include <petsc/private/randomimpl.h>
3 #include <curand.h>
4 
5 typedef struct {
6   curandGenerator_t gen;
7 } PetscRandom_CURAND;
8 
9 PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
10 {
11   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
12 
13   PetscFunctionBegin;
14   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen,r->seed));
15   PetscFunctionReturn(0);
16 }
17 
18 PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom,size_t,PetscReal*,PetscBool);
19 
20 PetscErrorCode  PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
21 {
22   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
23   size_t             nn = n < 0 ? (size_t)(-2*n) : n; /* handle complex case */
24 
25   PetscFunctionBegin;
26 #if defined(PETSC_USE_REAL_SINGLE)
27   PetscCallCURAND(curandGenerateUniform(curand->gen,val,nn));
28 #else
29   PetscCallCURAND(curandGenerateUniformDouble(curand->gen,val,nn));
30 #endif
31   if (r->iset) {
32     PetscCall(PetscRandomCurandScale_Private(r,nn,val,(PetscBool)(n<0)));
33   }
34   PetscFunctionReturn(0);
35 }
36 
37 PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
38 {
39   PetscFunctionBegin;
40 #if defined(PETSC_USE_COMPLEX)
41   /* pass negative size to flag complex scaling (if needed) */
42   PetscCall(PetscRandomGetValuesReal_CURAND(r,-n,(PetscReal*)val));
43 #else
44   PetscCall(PetscRandomGetValuesReal_CURAND(r,n,val));
45 #endif
46   PetscFunctionReturn(0);
47 }
48 
49 PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
50 {
51   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
52 
53   PetscFunctionBegin;
54   PetscCallCURAND(curandDestroyGenerator(curand->gen));
55   PetscCall(PetscFree(r->data));
56   PetscFunctionReturn(0);
57 }
58 
59 static struct _PetscRandomOps PetscRandomOps_Values = {
60   PetscDesignatedInitializer(seed,PetscRandomSeed_CURAND),
61   PetscDesignatedInitializer(getvalue,NULL),
62   PetscDesignatedInitializer(getvaluereal,NULL),
63   PetscDesignatedInitializer(getvalues,PetscRandomGetValues_CURAND),
64   PetscDesignatedInitializer(getvaluesreal,PetscRandomGetValuesReal_CURAND),
65   PetscDesignatedInitializer(destroy,PetscRandomDestroy_CURAND),
66 };
67 
68 /*MC
69    PETSCCURAND - access to the CUDA random number generator
70 
71   Level: beginner
72 
73 .seealso: `PetscRandomCreate()`, `PetscRandomSetType()`
74 M*/
75 
76 PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
77 {
78   PetscRandom_CURAND *curand;
79 
80   PetscFunctionBegin;
81   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
82   PetscCall(PetscNewLog(r,&curand));
83   PetscCallCURAND(curandCreateGenerator(&curand->gen,CURAND_RNG_PSEUDO_DEFAULT));
84   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
85   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen,CURAND_ORDERING_PSEUDO_SEEDED));
86   PetscCall(PetscMemcpy(r->ops,&PetscRandomOps_Values,sizeof(PetscRandomOps_Values)));
87   PetscCall(PetscObjectChangeTypeName((PetscObject)r,PETSCCURAND));
88   r->data = curand;
89   r->seed = 1234ULL; /* taken from example */
90   PetscCall(PetscRandomSeed_CURAND(r));
91   PetscFunctionReturn(0);
92 }
93