xref: /petsc/src/ksp/ksp/impls/hpddm/cuda/hpddm.cu (revision e2904c9e7f5770420d8e3513e4dbc610c98cfb7f)
1 #define HPDDM_MIXED_PRECISION 1
2 #include <petsc/private/petschpddm.h>
3 #include <petscdevice_cuda.h>
4 #include <thrust/device_ptr.h>
5 #include <thrust/copy.h>
6 
/*
  KSPSolve_HPDDM_CUDA_Private - Stages the device arrays b and x in host memory, calls
  HPDDM::IterativeMethod::solve() on the host copies, and writes the result back into x.

  Input Parameters:
+ data - HPDDM context; data->op is the operator handed to the solver and data->precision selects the code path
. b    - device array of data->op->getDof() * n entries (copied to the host, never written)
. n    - factor multiplying getDof() to obtain the array length, forwarded unchanged to the solver
         (presumably the number of right-hand sides; confirm against the caller)
- comm - MPI communicator forwarded to the solver

  Input/Output Parameter:
. x - device array of the same length as b; its contents are copied to the host before the solve
      and overwritten with the solver output afterwards

  Notes:
  When data->precision equals PETSC_SCALAR_PRECISION, the arrays are staged as PetscScalar.
  Otherwise they are first converted on the device to the scalar type K and the solve is done in K.

  Host/device traffic is reported through PetscLogGpuToCpu() and PetscLogCpuToGpu().
*/
PetscErrorCode KSPSolve_HPDDM_CUDA_Private(KSP_HPDDM *data, const PetscScalar *b, PetscScalar *x, PetscInt n, MPI_Comm comm)
{
  // K is the scalar type used when the solve does not run in PetscScalar precision
#if PetscDefined(USE_REAL_DOUBLE)
  using K = HPDDM::downscaled_type<PetscScalar>;
#endif
#if PetscDefined(USE_REAL_SINGLE)
  using K = HPDDM::upscaled_type<PetscScalar>;
#endif
  const PetscInt N = data->op->getDof() * n; // entries in each of b and x

  PetscFunctionBegin; // TODO: remove all cudaMemcpy() once HPDDM::IterativeMethod::solve() handles device pointers
  if (data->precision == PETSC_SCALAR_PRECISION) {
    const size_t one_vec = N * sizeof(PetscScalar);
    PetscScalar *staging; // host buffer: [0, N) holds b, [N, 2N) holds x

    PetscCall(PetscMalloc1(2 * N, &staging));
    PetscCallCUDA(cudaMemcpy(staging, b, one_vec, cudaMemcpyDeviceToHost));
    PetscCallCUDA(cudaMemcpy(staging + N, x, one_vec, cudaMemcpyDeviceToHost));
    PetscCall(HPDDM::IterativeMethod::solve(*data->op, staging, staging + N, n, comm));
    // only the second half of the buffer is sent back to the device
    PetscCallCUDA(cudaMemcpy(x, staging + N, one_vec, cudaMemcpyHostToDevice));
    PetscCall(PetscFree(staging));
    PetscCall(PetscLogGpuToCpu(2 * N * sizeof(PetscScalar)));
    PetscCall(PetscLogCpuToGpu(one_vec));
  } else {
    const size_t one_vec  = N * sizeof(K);
    const size_t two_vecs = 2 * N * sizeof(K);
    K           *staging; // host buffer of K: [0, N) holds b, [N, 2N) holds x
    K           *dwork;   // device scratch of K with the same layout as staging

    PetscCall(PetscMalloc1(2 * N, &staging));
    PetscCallCUDA(cudaMalloc((void **)&dwork, two_vecs));
    const thrust::device_ptr<K> drhs = thrust::device_pointer_cast(dwork);
    const thrust::device_ptr<K> dsol = thrust::device_pointer_cast(dwork + N);
    // element-wise PetscScalar -> K conversion on the device, so that a single transfer moves both arrays
    thrust::copy_n(thrust::cuda::par.on(PetscDefaultCudaStream), thrust::device_pointer_cast(b), N, drhs);
    thrust::copy_n(thrust::cuda::par.on(PetscDefaultCudaStream), thrust::device_pointer_cast(x), N, dsol);
    PetscCallCUDA(cudaMemcpy(staging, dwork, two_vecs, cudaMemcpyDeviceToHost));
    PetscCall(HPDDM::IterativeMethod::solve(*data->op, staging, staging + N, n, comm));
    PetscCallCUDA(cudaMemcpy(dwork + N, staging + N, one_vec, cudaMemcpyHostToDevice));
    // K -> PetscScalar conversion on the device, written into the caller's x
    thrust::copy_n(thrust::cuda::par.on(PetscDefaultCudaStream), dsol, N, thrust::device_pointer_cast(x));
    PetscCallCUDA(cudaFree(dwork));
    PetscCall(PetscFree(staging));
    PetscCall(PetscLogGpuToCpu(two_vecs));
    PetscCall(PetscLogCpuToGpu(one_vec));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
52