xref: /petsc/src/mat/impls/cufft/cufft.cu (revision 66af8762ec03dbef0e079729eb2a1734a35ed7ff)
1 /*
2     Provides an interface to the CUFFT package.
3     Testing examples can be found in ~src/mat/tests
4 */
5 
6 #include <petscdevice_cuda.h>
7 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
8 
9 typedef struct {
10   PetscInt      ndim;
11   PetscInt     *dim;
12   cufftHandle   p_forward, p_backward;
13   cufftComplex *devArray;
14 } Mat_CUFFT;
15 
16 static PetscErrorCode MatMult_SeqCUFFT(Mat A, Vec x, Vec y)
17 {
18   Mat_CUFFT    *cufft    = (Mat_CUFFT *)A->data;
19   cufftComplex *devArray = cufft->devArray;
20   PetscInt      ndim = cufft->ndim, *dim = cufft->dim;
21   PetscScalar  *x_array, *y_array;
22 
23   PetscFunctionBegin;
24   PetscCall(VecGetArray(x, &x_array));
25   PetscCall(VecGetArray(y, &y_array));
26   if (!cufft->p_forward) {
27     /* create a plan, then execute it */
28     switch (ndim) {
29     case 1:
30       PetscCallCUFFT(cufftPlan1d(&cufft->p_forward, dim[0], CUFFT_C2C, 1));
31       break;
32     case 2:
33       PetscCallCUFFT(cufftPlan2d(&cufft->p_forward, dim[0], dim[1], CUFFT_C2C));
34       break;
35     case 3:
36       PetscCallCUFFT(cufftPlan3d(&cufft->p_forward, dim[0], dim[1], dim[2], CUFFT_C2C));
37       break;
38     default:
39       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %" PetscInt_FMT "-dimensional transform", ndim);
40     }
41   }
42   /* transfer to GPU memory */
43   PetscCallCUDA(cudaMemcpy(devArray, x_array, sizeof(cufftComplex) * dim[ndim], cudaMemcpyHostToDevice));
44   /* execute transform */
45   PetscCallCUFFT(cufftExecC2C(cufft->p_forward, devArray, devArray, CUFFT_FORWARD));
46   /* transfer from GPU memory */
47   PetscCallCUDA(cudaMemcpy(y_array, devArray, sizeof(cufftComplex) * dim[ndim], cudaMemcpyDeviceToHost));
48   PetscCall(VecRestoreArray(y, &y_array));
49   PetscCall(VecRestoreArray(x, &x_array));
50   PetscFunctionReturn(PETSC_SUCCESS);
51 }
52 
53 static PetscErrorCode MatMultTranspose_SeqCUFFT(Mat A, Vec x, Vec y)
54 {
55   Mat_CUFFT    *cufft    = (Mat_CUFFT *)A->data;
56   cufftComplex *devArray = cufft->devArray;
57   PetscInt      ndim = cufft->ndim, *dim = cufft->dim;
58   PetscScalar  *x_array, *y_array;
59 
60   PetscFunctionBegin;
61   PetscCall(VecGetArray(x, &x_array));
62   PetscCall(VecGetArray(y, &y_array));
63   if (!cufft->p_backward) {
64     /* create a plan, then execute it */
65     switch (ndim) {
66     case 1:
67       PetscCallCUFFT(cufftPlan1d(&cufft->p_backward, dim[0], CUFFT_C2C, 1));
68       break;
69     case 2:
70       PetscCallCUFFT(cufftPlan2d(&cufft->p_backward, dim[0], dim[1], CUFFT_C2C));
71       break;
72     case 3:
73       PetscCallCUFFT(cufftPlan3d(&cufft->p_backward, dim[0], dim[1], dim[2], CUFFT_C2C));
74       break;
75     default:
76       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %" PetscInt_FMT "-dimensional transform", ndim);
77     }
78   }
79   /* transfer to GPU memory */
80   PetscCallCUDA(cudaMemcpy(devArray, x_array, sizeof(cufftComplex) * dim[ndim], cudaMemcpyHostToDevice));
81   /* execute transform */
82   PetscCallCUFFT(cufftExecC2C(cufft->p_forward, devArray, devArray, CUFFT_INVERSE));
83   /* transfer from GPU memory */
84   PetscCallCUDA(cudaMemcpy(y_array, devArray, sizeof(cufftComplex) * dim[ndim], cudaMemcpyDeviceToHost));
85   PetscCall(VecRestoreArray(y, &y_array));
86   PetscCall(VecRestoreArray(x, &x_array));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
90 static PetscErrorCode MatDestroy_SeqCUFFT(Mat A)
91 {
92   Mat_CUFFT *cufft = (Mat_CUFFT *)A->data;
93 
94   PetscFunctionBegin;
95   PetscCall(PetscFree(cufft->dim));
96   if (cufft->p_forward) PetscCallCUFFT(cufftDestroy(cufft->p_forward));
97   if (cufft->p_backward) PetscCallCUFFT(cufftDestroy(cufft->p_backward));
98   PetscCallCUDA(cudaFree(cufft->devArray));
99   PetscCall(PetscFree(A->data));
100   PetscCall(PetscObjectChangeTypeName((PetscObject)A, 0));
101   PetscFunctionReturn(PETSC_SUCCESS);
102 }
103 
104 /*@
105   MatCreateSeqCUFFT - Creates a matrix object that provides `MATSEQCUFFT` via the NVIDIA package CuFFT
106 
107   Collective
108 
109   Input Parameters:
110 + comm - MPI communicator, set to `PETSC_COMM_SELF`
111 . ndim - the ndim-dimensional transform
112 - dim  - array of size `ndim`, dim[i] contains the vector length in the i-dimension
113 
114   Output Parameter:
115 . A - the matrix
116 
117   Options Database Key:
118 . -mat_cufft_plannerflags - set CuFFT planner flags
119 
120   Level: intermediate
121 
122 .seealso: [](ch_matrices), `Mat`, `MATSEQCUFFT`
123 @*/
124 PetscErrorCode MatCreateSeqCUFFT(MPI_Comm comm, PetscInt ndim, const PetscInt dim[], Mat *A)
125 {
126   Mat_CUFFT *cufft;
127   PetscInt   m = 1;
128 
129   PetscFunctionBegin;
130   PetscCheck(ndim >= 0, PETSC_COMM_SELF, PETSC_ERR_USER, "ndim %" PetscInt_FMT " must be > 0", ndim);
131   if (ndim) PetscAssertPointer(dim, 3);
132   PetscAssertPointer(A, 4);
133   PetscCall(MatCreate(comm, A));
134   for (PetscInt d = 0; d < ndim; ++d) {
135     PetscCheck(dim[d] >= 0, PETSC_COMM_SELF, PETSC_ERR_USER, "dim[%" PetscInt_FMT "]=%" PetscInt_FMT " must be > 0", d, dim[d]);
136     m *= dim[d];
137   }
138   PetscCall(MatSetSizes(*A, m, m, m, m));
139   PetscCall(PetscObjectChangeTypeName((PetscObject)*A, MATSEQCUFFT));
140 
141   PetscCall(PetscNew(&cufft));
142   (*A)->data = (void *)cufft;
143   PetscCall(PetscMalloc1(ndim + 1, &cufft->dim));
144   PetscCall(PetscArraycpy(cufft->dim, dim, ndim));
145 
146   cufft->ndim       = ndim;
147   cufft->p_forward  = 0;
148   cufft->p_backward = 0;
149   cufft->dim[ndim]  = m;
150 
151   /* GPU memory allocation */
152   PetscCallCUDA(cudaMalloc((void **)&cufft->devArray, sizeof(cufftComplex) * m));
153 
154   (*A)->ops->mult          = MatMult_SeqCUFFT;
155   (*A)->ops->multtranspose = MatMultTranspose_SeqCUFFT;
156   (*A)->assembled          = PETSC_TRUE;
157   (*A)->ops->destroy       = MatDestroy_SeqCUFFT;
158 
159   /* get runtime options ...what options????? */
160   PetscOptionsBegin(comm, ((PetscObject)(*A))->prefix, "CUFFT Options", "Mat");
161   PetscOptionsEnd();
162   PetscFunctionReturn(PETSC_SUCCESS);
163 }
164