xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision cd7e8a5e83fb5f23fbe440fa5d826b1569053b5a)
1 #include <petsc/private/deviceimpl.h>
2 #include <petsc/private/randomimpl.h>
3 #include <curand.h>
4 
5 typedef struct {
6   curandGenerator_t gen;
7 } PetscRandom_CURAND;
8 
9 PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
10 {
11   curandStatus_t     cerr;
12   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
13 
14   PetscFunctionBegin;
15   cerr = curandSetPseudoRandomGeneratorSeed(curand->gen,r->seed);CHKERRCURAND(cerr);
16   PetscFunctionReturn(0);
17 }
18 
19 PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom,size_t,PetscReal*,PetscBool);
20 
21 PetscErrorCode  PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
22 {
23   curandStatus_t     cerr;
24   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
25   size_t             nn = n < 0 ? (size_t)(-2*n) : n; /* handle complex case */
26 
27   PetscFunctionBegin;
28 #if defined(PETSC_USE_REAL_SINGLE)
29   cerr = curandGenerateUniform(curand->gen,val,nn);CHKERRCURAND(cerr);
30 #else
31   cerr = curandGenerateUniformDouble(curand->gen,val,nn);CHKERRCURAND(cerr);
32 #endif
33   if (r->iset) {
34     PetscErrorCode ierr = PetscRandomCurandScale_Private(r,nn,val,(PetscBool)(n<0));CHKERRQ(ierr);
35   }
36   PetscFunctionReturn(0);
37 }
38 
39 PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
40 {
41   PetscErrorCode ierr;
42 
43   PetscFunctionBegin;
44 #if defined(PETSC_USE_COMPLEX)
45   /* pass negative size to flag complex scaling (if needed) */
46   ierr = PetscRandomGetValuesReal_CURAND(r,-n,(PetscReal*)val);CHKERRQ(ierr);
47 #else
48   ierr = PetscRandomGetValuesReal_CURAND(r,n,val);CHKERRQ(ierr);
49 #endif
50   PetscFunctionReturn(0);
51 }
52 
53 PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
54 {
55   PetscErrorCode     ierr;
56   curandStatus_t     cerr;
57   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
58 
59   PetscFunctionBegin;
60   cerr = curandDestroyGenerator(curand->gen);CHKERRCURAND(cerr);
61   ierr = PetscFree(r->data);CHKERRQ(ierr);
62   PetscFunctionReturn(0);
63 }
64 
65 static struct _PetscRandomOps PetscRandomOps_Values = {
66   PetscRandomSeed_CURAND,
67   NULL,
68   NULL,
69   PetscRandomGetValues_CURAND,
70   PetscRandomGetValuesReal_CURAND,
71   PetscRandomDestroy_CURAND,
72   NULL
73 };
74 
75 /*MC
76    PETSCCURAND - access to the CUDA random number generator
77 
78   Level: beginner
79 
80 .seealso: PetscRandomCreate(), PetscRandomSetType()
81 M*/
82 
83 PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
84 {
85   PetscErrorCode     ierr;
86   curandStatus_t     cerr;
87   PetscRandom_CURAND *curand;
88 
89   PetscFunctionBegin;
90   ierr = PetscDeviceInitialize(PETSC_DEVICE_CUDA);CHKERRQ(ierr);
91   ierr = PetscNewLog(r,&curand);CHKERRQ(ierr);
92   cerr = curandCreateGenerator(&curand->gen,CURAND_RNG_PSEUDO_DEFAULT);CHKERRCURAND(cerr);
93   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
94   cerr = curandSetGeneratorOrdering(curand->gen,CURAND_ORDERING_PSEUDO_SEEDED);CHKERRCURAND(cerr);
95   ierr = PetscMemcpy(r->ops,&PetscRandomOps_Values,sizeof(PetscRandomOps_Values));CHKERRQ(ierr);
96   ierr = PetscObjectChangeTypeName((PetscObject)r,PETSCCURAND);CHKERRQ(ierr);
97   r->data = curand;
98   r->seed = 1234ULL; /* taken from example */
99   ierr = PetscRandomSeed_CURAND(r);CHKERRQ(ierr);
100   PetscFunctionReturn(0);
101 }
102