xref: /petsc/src/mat/impls/cufft/cufft.cu (revision df4cd43f92eaa320656440c40edb1046daee8f75)
1 
2 /*
3     Provides an interface to the CUFFT package.
4     Testing examples can be found in ~src/mat/tests
5 */
6 
7 #include <petscdevice_cuda.h>
8 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
9 
10 typedef struct {
11   PetscInt      ndim;
12   PetscInt     *dim;
13   cufftHandle   p_forward, p_backward;
14   cufftComplex *devArray;
15 } Mat_CUFFT;
16 
17 PetscErrorCode MatMult_SeqCUFFT(Mat A, Vec x, Vec y)
18 {
19   Mat_CUFFT    *cufft    = (Mat_CUFFT *)A->data;
20   cufftComplex *devArray = cufft->devArray;
21   PetscInt      ndim = cufft->ndim, *dim = cufft->dim;
22   PetscScalar  *x_array, *y_array;
23 
24   PetscFunctionBegin;
25   PetscCall(VecGetArray(x, &x_array));
26   PetscCall(VecGetArray(y, &y_array));
27   if (!cufft->p_forward) {
28     /* create a plan, then execute it */
29     switch (ndim) {
30     case 1:
31       PetscCallCUFFT(cufftPlan1d(&cufft->p_forward, dim[0], CUFFT_C2C, 1));
32       break;
33     case 2:
34       PetscCallCUFFT(cufftPlan2d(&cufft->p_forward, dim[0], dim[1], CUFFT_C2C));
35       break;
36     case 3:
37       PetscCallCUFFT(cufftPlan3d(&cufft->p_forward, dim[0], dim[1], dim[2], CUFFT_C2C));
38       break;
39     default:
40       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %" PetscInt_FMT "-dimensional transform", ndim);
41     }
42   }
43   /* transfer to GPU memory */
44   PetscCallCUDA(cudaMemcpy(devArray, x_array, sizeof(cufftComplex) * dim[ndim], cudaMemcpyHostToDevice));
45   /* execute transform */
46   PetscCallCUFFT(cufftExecC2C(cufft->p_forward, devArray, devArray, CUFFT_FORWARD));
47   /* transfer from GPU memory */
48   PetscCallCUDA(cudaMemcpy(y_array, devArray, sizeof(cufftComplex) * dim[ndim], cudaMemcpyDeviceToHost));
49   PetscCall(VecRestoreArray(y, &y_array));
50   PetscCall(VecRestoreArray(x, &x_array));
51   PetscFunctionReturn(PETSC_SUCCESS);
52 }
53 
54 PetscErrorCode MatMultTranspose_SeqCUFFT(Mat A, Vec x, Vec y)
55 {
56   Mat_CUFFT    *cufft    = (Mat_CUFFT *)A->data;
57   cufftComplex *devArray = cufft->devArray;
58   PetscInt      ndim = cufft->ndim, *dim = cufft->dim;
59   PetscScalar  *x_array, *y_array;
60 
61   PetscFunctionBegin;
62   PetscCall(VecGetArray(x, &x_array));
63   PetscCall(VecGetArray(y, &y_array));
64   if (!cufft->p_backward) {
65     /* create a plan, then execute it */
66     switch (ndim) {
67     case 1:
68       PetscCallCUFFT(cufftPlan1d(&cufft->p_backward, dim[0], CUFFT_C2C, 1));
69       break;
70     case 2:
71       PetscCallCUFFT(cufftPlan2d(&cufft->p_backward, dim[0], dim[1], CUFFT_C2C));
72       break;
73     case 3:
74       PetscCallCUFFT(cufftPlan3d(&cufft->p_backward, dim[0], dim[1], dim[2], CUFFT_C2C));
75       break;
76     default:
77       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %" PetscInt_FMT "-dimensional transform", ndim);
78     }
79   }
80   /* transfer to GPU memory */
81   PetscCallCUDA(cudaMemcpy(devArray, x_array, sizeof(cufftComplex) * dim[ndim], cudaMemcpyHostToDevice));
82   /* execute transform */
83   PetscCallCUFFT(cufftExecC2C(cufft->p_forward, devArray, devArray, CUFFT_INVERSE));
84   /* transfer from GPU memory */
85   PetscCallCUDA(cudaMemcpy(y_array, devArray, sizeof(cufftComplex) * dim[ndim], cudaMemcpyDeviceToHost));
86   PetscCall(VecRestoreArray(y, &y_array));
87   PetscCall(VecRestoreArray(x, &x_array));
88   PetscFunctionReturn(PETSC_SUCCESS);
89 }
90 
91 PetscErrorCode MatDestroy_SeqCUFFT(Mat A)
92 {
93   Mat_CUFFT *cufft = (Mat_CUFFT *)A->data;
94 
95   PetscFunctionBegin;
96   PetscCall(PetscFree(cufft->dim));
97   if (cufft->p_forward) PetscCallCUFFT(cufftDestroy(cufft->p_forward));
98   if (cufft->p_backward) PetscCallCUFFT(cufftDestroy(cufft->p_backward));
99   PetscCallCUDA(cudaFree(cufft->devArray));
100   PetscCall(PetscFree(A->data));
101   PetscCall(PetscObjectChangeTypeName((PetscObject)A, 0));
102   PetscFunctionReturn(PETSC_SUCCESS);
103 }
104 
105 /*@
106   MatCreateSeqCUFFT - Creates a matrix object that provides `MATSEQCUFFT` via the NVIDIA package CuFFT
107 
108   Collective
109 
110   Input Parameters:
111 + comm - MPI communicator, set to `PETSC_COMM_SELF`
112 . ndim - the ndim-dimensional transform
113 - dim  - array of size `ndim`, dim[i] contains the vector length in the i-dimension
114 
115   Output Parameter:
116 . A - the matrix
117 
118   Options Database Key:
119 . -mat_cufft_plannerflags - set CuFFT planner flags
120 
121   Level: intermediate
122 
123 .seealso: [](chapter_matrices), `Mat`, `MATSEQCUFFT`
124 @*/
125 PetscErrorCode MatCreateSeqCUFFT(MPI_Comm comm, PetscInt ndim, const PetscInt dim[], Mat *A)
126 {
127   Mat_CUFFT *cufft;
128   PetscInt   m = 1;
129 
130   PetscFunctionBegin;
131   PetscCheck(ndim >= 0, PETSC_COMM_SELF, PETSC_ERR_USER, "ndim %" PetscInt_FMT " must be > 0", ndim);
132   if (ndim) PetscValidIntPointer(dim, 3);
133   PetscValidPointer(A, 4);
134   PetscCall(MatCreate(comm, A));
135   for (PetscInt d = 0; d < ndim; ++d) {
136     PetscCheck(dim[d] >= 0, PETSC_COMM_SELF, PETSC_ERR_USER, "dim[%" PetscInt_FMT "]=%" PetscInt_FMT " must be > 0", d, dim[d]);
137     m *= dim[d];
138   }
139   PetscCall(MatSetSizes(*A, m, m, m, m));
140   PetscCall(PetscObjectChangeTypeName((PetscObject)*A, MATSEQCUFFT));
141 
142   PetscCall(PetscNew(&cufft));
143   (*A)->data = (void *)cufft;
144   PetscCall(PetscMalloc1(ndim + 1, &cufft->dim));
145   PetscCall(PetscArraycpy(cufft->dim, dim, ndim));
146 
147   cufft->ndim       = ndim;
148   cufft->p_forward  = 0;
149   cufft->p_backward = 0;
150   cufft->dim[ndim]  = m;
151 
152   /* GPU memory allocation */
153   PetscCallCUDA(cudaMalloc((void **)&cufft->devArray, sizeof(cufftComplex) * m));
154 
155   (*A)->ops->mult          = MatMult_SeqCUFFT;
156   (*A)->ops->multtranspose = MatMultTranspose_SeqCUFFT;
157   (*A)->assembled          = PETSC_TRUE;
158   (*A)->ops->destroy       = MatDestroy_SeqCUFFT;
159 
160   /* get runtime options ...what options????? */
161   PetscOptionsBegin(comm, ((PetscObject)(*A))->prefix, "CUFFT Options", "Mat");
162   PetscOptionsEnd();
163   PetscFunctionReturn(PETSC_SUCCESS);
164 }
165