xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision d6685f554fbda8d96c6a5d73ab0e7a4e21a05c51)
1 #include <petsc/private/deviceimpl.h>
2 #include <petsc/private/randomimpl.h>
3 #include <curand.h>
4 
5 typedef struct {
6   curandGenerator_t gen;
7 } PetscRandom_CURAND;
8 
9 PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
10 {
11   curandStatus_t     cerr;
12   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
13 
14   PetscFunctionBegin;
15   cerr = curandSetPseudoRandomGeneratorSeed(curand->gen,r->seed);CHKERRCURAND(cerr);
16   PetscFunctionReturn(0);
17 }
18 
19 PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom,size_t,PetscReal*,PetscBool);
20 
21 PetscErrorCode  PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
22 {
23   curandStatus_t     cerr;
24   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
25   size_t             nn = n < 0 ? (size_t)(-2*n) : n; /* handle complex case */
26 
27   PetscFunctionBegin;
28 #if defined(PETSC_USE_REAL_SINGLE)
29   cerr = curandGenerateUniform(curand->gen,val,nn);CHKERRCURAND(cerr);
30 #else
31   cerr = curandGenerateUniformDouble(curand->gen,val,nn);CHKERRCURAND(cerr);
32 #endif
33   if (r->iset) {
34     PetscErrorCode ierr = PetscRandomCurandScale_Private(r,nn,val,(PetscBool)(n<0));CHKERRQ(ierr);
35   }
36   PetscFunctionReturn(0);
37 }
38 
39 PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
40 {
41   PetscErrorCode ierr;
42 
43   PetscFunctionBegin;
44 #if defined(PETSC_USE_COMPLEX)
45   /* pass negative size to flag complex scaling (if needed) */
46   ierr = PetscRandomGetValuesReal_CURAND(r,-n,(PetscReal*)val);CHKERRQ(ierr);
47 #else
48   ierr = PetscRandomGetValuesReal_CURAND(r,n,val);CHKERRQ(ierr);
49 #endif
50   PetscFunctionReturn(0);
51 }
52 
53 PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
54 {
55   PetscErrorCode     ierr;
56   curandStatus_t     cerr;
57   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
58 
59   PetscFunctionBegin;
60   cerr = curandDestroyGenerator(curand->gen);CHKERRCURAND(cerr);
61   ierr = PetscFree(r->data);CHKERRQ(ierr);
62   PetscFunctionReturn(0);
63 }
64 
/*
   Method table copied into r->ops by PetscRandomCreate_CURAND().

   This implementation provides no single-value getters (getvalue/getvaluereal are
   NULL); values are produced only through the array routines.

   NOTE(review): PetscDesignatedInitializer() presumably degrades to positional
   initialization in C++ builds. That would be why every member is listed explicitly and
   in declaration order, so keep the order and the NULL entries when editing.
*/
static struct _PetscRandomOps PetscRandomOps_Values = {
  PetscDesignatedInitializer(seed,PetscRandomSeed_CURAND),
  PetscDesignatedInitializer(getvalue,NULL),
  PetscDesignatedInitializer(getvaluereal,NULL),
  PetscDesignatedInitializer(getvalues,PetscRandomGetValues_CURAND),
  PetscDesignatedInitializer(getvaluesreal,PetscRandomGetValuesReal_CURAND),
  PetscDesignatedInitializer(destroy,PetscRandomDestroy_CURAND),
};
73 
74 /*MC
75    PETSCCURAND - access to the CUDA random number generator
76 
77   Level: beginner
78 
79 .seealso: PetscRandomCreate(), PetscRandomSetType()
80 M*/
81 
82 PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
83 {
84   PetscErrorCode     ierr;
85   curandStatus_t     cerr;
86   PetscRandom_CURAND *curand;
87 
88   PetscFunctionBegin;
89   ierr = PetscDeviceInitialize(PETSC_DEVICE_CUDA);CHKERRQ(ierr);
90   ierr = PetscNewLog(r,&curand);CHKERRQ(ierr);
91   cerr = curandCreateGenerator(&curand->gen,CURAND_RNG_PSEUDO_DEFAULT);CHKERRCURAND(cerr);
92   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
93   cerr = curandSetGeneratorOrdering(curand->gen,CURAND_ORDERING_PSEUDO_SEEDED);CHKERRCURAND(cerr);
94   ierr = PetscMemcpy(r->ops,&PetscRandomOps_Values,sizeof(PetscRandomOps_Values));CHKERRQ(ierr);
95   ierr = PetscObjectChangeTypeName((PetscObject)r,PETSCCURAND);CHKERRQ(ierr);
96   r->data = curand;
97   r->seed = 1234ULL; /* taken from example */
98   ierr = PetscRandomSeed_CURAND(r);CHKERRQ(ierr);
99   PetscFunctionReturn(0);
100 }
101