xref: /petsc/src/mat/impls/cufft/cufft.cu (revision 503c0ea9b45bcfbcebbb1ea5341243bbc69f0bea)
1 
2 /*
3     Provides an interface to the CUFFT package.
4     Testing examples can be found in ~src/mat/tests
5 */
6 
7 #include <petscdevice.h>
8 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
9 
/* Private data of a MATSEQCUFFT matrix (stored in Mat->data) */
typedef struct {
  PetscInt      ndim;       /* number of transform dimensions */
  PetscInt     *dim;        /* ndim extents, followed at dim[ndim] by their product (the matrix size) */
  cufftHandle   p_forward;  /* plan created lazily by MatMult(); 0 until then */
  cufftHandle   p_backward; /* plan created lazily by MatMultTranspose(); 0 until then */
  cufftComplex *devArray;   /* device work buffer holding dim[ndim] complex entries */
} Mat_CUFFT;
16 
/*
  MatMult_SeqCUFFT - Computes y = F x, where F is the (unnormalized) forward ndim-dimensional
  complex-to-complex DFT.

  x is copied into the device work buffer, the forward cuFFT plan is executed in place, and the
  result is copied back into y. The forward plan is created on first use and cached in the
  Mat_CUFFT context for later calls.
*/
PetscErrorCode MatMult_SeqCUFFT(Mat A, Vec x, Vec y)
{
  Mat_CUFFT         *cufft    = (Mat_CUFFT*)A->data;
  cufftComplex      *devArray = cufft->devArray;
  PetscInt           ndim     = cufft->ndim, *dim = cufft->dim;
  const PetscScalar *x_array;
  PetscScalar       *y_array;

  PetscFunctionBegin;
  /* Create the plan before touching the vectors, so that an unsupported ndim errors out
     without leaving x and y with outstanding array accesses */
  if (!cufft->p_forward) {
    switch (ndim) {
    case 1:
      PetscCallCUFFT(cufftPlan1d(&cufft->p_forward, dim[0], CUFFT_C2C, 1));
      break;
    case 2:
      PetscCallCUFFT(cufftPlan2d(&cufft->p_forward, dim[0], dim[1], CUFFT_C2C));
      break;
    case 3:
      PetscCallCUFFT(cufftPlan3d(&cufft->p_forward, dim[0], dim[1], dim[2], CUFFT_C2C));
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %" PetscInt_FMT "-dimensional transform", ndim);
    }
  }
  /* x is only read; y is completely overwritten, so its old contents need not be fetched */
  PetscCall(VecGetArrayRead(x, &x_array));
  PetscCall(VecGetArrayWrite(y, &y_array));
  /* transfer to GPU memory; dim[ndim] is the total number of entries */
  PetscCallCUDA(cudaMemcpy(devArray, x_array, sizeof(cufftComplex)*dim[ndim], cudaMemcpyHostToDevice));
  /* execute transform in place */
  PetscCallCUFFT(cufftExecC2C(cufft->p_forward, devArray, devArray, CUFFT_FORWARD));
  /* transfer from GPU memory */
  PetscCallCUDA(cudaMemcpy(y_array, devArray, sizeof(cufftComplex)*dim[ndim], cudaMemcpyDeviceToHost));
  PetscCall(VecRestoreArrayWrite(y, &y_array));
  PetscCall(VecRestoreArrayRead(x, &x_array));
  PetscFunctionReturn(0);
}
53 
/*
  MatMultTranspose_SeqCUFFT - Computes y = B x, where B is the (unnormalized) inverse
  (CUFFT_INVERSE) ndim-dimensional complex-to-complex DFT.

  x is copied into the device work buffer, the backward cuFFT plan is executed in place, and the
  result is copied back into y. The backward plan is created on first use and cached in the
  Mat_CUFFT context for later calls.

  NOTE(review): the DFT matrix is symmetric, so its unnormalized inverse equals its conjugate
  (Hermitian) transpose rather than its plain transpose; confirm this is what callers expect.
*/
PetscErrorCode MatMultTranspose_SeqCUFFT(Mat A, Vec x, Vec y)
{
  Mat_CUFFT         *cufft    = (Mat_CUFFT*)A->data;
  cufftComplex      *devArray = cufft->devArray;
  PetscInt           ndim     = cufft->ndim, *dim = cufft->dim;
  const PetscScalar *x_array;
  PetscScalar       *y_array;

  PetscFunctionBegin;
  /* Create the plan before touching the vectors, so that an unsupported ndim errors out
     without leaving x and y with outstanding array accesses */
  if (!cufft->p_backward) {
    switch (ndim) {
    case 1:
      PetscCallCUFFT(cufftPlan1d(&cufft->p_backward, dim[0], CUFFT_C2C, 1));
      break;
    case 2:
      PetscCallCUFFT(cufftPlan2d(&cufft->p_backward, dim[0], dim[1], CUFFT_C2C));
      break;
    case 3:
      PetscCallCUFFT(cufftPlan3d(&cufft->p_backward, dim[0], dim[1], dim[2], CUFFT_C2C));
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %" PetscInt_FMT "-dimensional transform", ndim);
    }
  }
  /* x is only read; y is completely overwritten, so its old contents need not be fetched */
  PetscCall(VecGetArrayRead(x, &x_array));
  PetscCall(VecGetArrayWrite(y, &y_array));
  /* transfer to GPU memory; dim[ndim] is the total number of entries */
  PetscCallCUDA(cudaMemcpy(devArray, x_array, sizeof(cufftComplex)*dim[ndim], cudaMemcpyHostToDevice));
  /* execute with the backward plan ensured above: p_forward is only created by MatMult()
     and may not exist yet */
  PetscCallCUFFT(cufftExecC2C(cufft->p_backward, devArray, devArray, CUFFT_INVERSE));
  /* transfer from GPU memory */
  PetscCallCUDA(cudaMemcpy(y_array, devArray, sizeof(cufftComplex)*dim[ndim], cudaMemcpyDeviceToHost));
  PetscCall(VecRestoreArrayWrite(y, &y_array));
  PetscCall(VecRestoreArrayRead(x, &x_array));
  PetscFunctionReturn(0);
}
90 
/*
  MatDestroy_SeqCUFFT - Releases everything owned by a MATSEQCUFFT matrix: the dimension array,
  whichever cuFFT plans have been created, the device work buffer, and the context itself;
  finally clears the object's type name.
*/
PetscErrorCode MatDestroy_SeqCUFFT(Mat A)
{
  Mat_CUFFT *ctx = (Mat_CUFFT*)A->data;

  PetscFunctionBegin;
  PetscCall(PetscFree(ctx->dim));
  /* a zero handle means the corresponding plan was never created */
  if (ctx->p_forward) {
    PetscCallCUFFT(cufftDestroy(ctx->p_forward));
  }
  if (ctx->p_backward) {
    PetscCallCUFFT(cufftDestroy(ctx->p_backward));
  }
  PetscCallCUDA(cudaFree(ctx->devArray));
  PetscCall(PetscFree(A->data));
  PetscCall(PetscObjectChangeTypeName((PetscObject)A, NULL));
  PetscFunctionReturn(0);
}
104 
/*@
  MatCreateSeqCUFFT - Creates a matrix object that provides sequential FFT via the external package CUFFT

  Collective

  Input Parameters:
+ comm - MPI communicator, set to PETSC_COMM_SELF
. ndim - the ndim-dimensional transform, 1 <= ndim <= 3
- dim  - array of size ndim, dim[i] > 0 contains the vector length in the i-dimension

  Output Parameter:
. A - the matrix

  Notes:
  The matrix is square, of size dim[0]*...*dim[ndim-1].
  MatMult() applies the unnormalized forward transform (CUFFT_FORWARD).
  MatMultTranspose() applies the unnormalized inverse transform (CUFFT_INVERSE).
  Each product copies the vector into a device work buffer and copies the result back.
  The cuFFT plans are created on first use.

  PetscScalar must have the same size as cufftComplex (i.e. single-precision complex scalars).

  No options database keys are currently processed.

  Level: intermediate
@*/
PetscErrorCode MatCreateSeqCUFFT(MPI_Comm comm, PetscInt ndim, const PetscInt dim[], Mat *A)
{
  Mat_CUFFT *cufft;
  PetscInt   m = 1;

  PetscFunctionBegin;
  /* Validate every argument before MatCreate() so a bad argument does not leak *A */
  PetscCheck(ndim > 0, PETSC_COMM_SELF, PETSC_ERR_USER, "ndim %" PetscInt_FMT " must be > 0", ndim);
  /* MatMult()/MatMultTranspose() can only build 1-, 2- and 3-dimensional plans */
  PetscCheck(ndim <= 3, PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %" PetscInt_FMT "-dimensional transform", ndim);
  PetscValidIntPointer(dim, 3);
  PetscValidPointer(A, 4);
  /* All host<->device transfers move sizeof(cufftComplex) bytes per PetscScalar entry */
  PetscCheck(sizeof(PetscScalar) == sizeof(cufftComplex), PETSC_COMM_SELF, PETSC_ERR_SUP, "MATSEQCUFFT requires PetscScalar to have the same size as cufftComplex (single-precision complex)");
  for (PetscInt d = 0; d < ndim; ++d) {
    PetscCheck(dim[d] > 0, PETSC_COMM_SELF, PETSC_ERR_USER, "dim[%" PetscInt_FMT "]=%" PetscInt_FMT " must be > 0", d, dim[d]);
    /* errors out instead of silently wrapping when the total size overflows PetscInt */
    PetscCall(PetscIntMultError(m, dim[d], &m));
  }

  PetscCall(MatCreate(comm, A));
  PetscCall(MatSetSizes(*A, m, m, m, m));
  PetscCall(PetscObjectChangeTypeName((PetscObject)*A, MATSEQCUFFT));

  PetscCall(PetscNewLog(*A, &cufft));
  (*A)->data = (void*)cufft;
  /* Install the operations (including destroy) as soon as the context exists, so a failure
     below still lets MatDestroy() release what was allocated */
  (*A)->ops->mult          = MatMult_SeqCUFFT;
  (*A)->ops->multtranspose = MatMultTranspose_SeqCUFFT;
  (*A)->ops->destroy       = MatDestroy_SeqCUFFT;

  /* dim[ndim] stores the total number of entries */
  PetscCall(PetscMalloc1(ndim + 1, &cufft->dim));
  PetscCall(PetscArraycpy(cufft->dim, dim, ndim));
  cufft->ndim       = ndim;
  cufft->p_forward  = 0; /* plans are created lazily on first MatMult()/MatMultTranspose() */
  cufft->p_backward = 0;
  cufft->dim[ndim]  = m;

  /* GPU memory allocation: one work buffer reused by every product */
  PetscCallCUDA(cudaMalloc((void**)&cufft->devArray, sizeof(cufftComplex)*m));

  (*A)->assembled = PETSC_TRUE;

  /* No CUFFT-specific options are processed yet */
  {
    PetscErrorCode ierr;

    ierr = PetscOptionsBegin(comm, ((PetscObject)(*A))->prefix, "CUFFT Options", "Mat");PetscCall(ierr);
    ierr = PetscOptionsEnd();PetscCall(ierr);
  }
  PetscFunctionReturn(0);
}
167