#include <petsc/private/petschypre.h>
#include <petscdevice_cuda.h>
#include <../src/mat/impls/hypre/mhypre_kernels.hpp>

/*
  MatZeroRows_CUDA - Host-side launcher for the ZeroRows kernel.

  Input Parameters:
+ n    - count passed to the kernel; when it is zero no kernel is launched
. rows - forwarded unchanged to ZeroRows
. i    - forwarded unchanged to ZeroRows (presumably CSR row offsets - confirm against the kernel)
. j    - forwarded unchanged to ZeroRows (presumably CSR column indices - confirm against the kernel)
. a    - forwarded unchanged to ZeroRows (non-const, so presumably modified in place - confirm)
- diag - forwarded unchanged to ZeroRows

  Notes:
  The launch uses thread blocks of 16 x 32 threads and ceil(n / 16) blocks along x,
  on the stream returned by PetscGetCurrentCUDAStream(). Only launch errors are
  checked (via cudaGetLastError()); this routine does not synchronize.
*/
PetscErrorCode MatZeroRows_CUDA(PetscInt n, const PetscInt rows[], const HYPRE_Int i[], const HYPRE_Int j[], HYPRE_Complex a[], HYPRE_Complex diag)
{
  const PetscInt threadsX = 16;
  const PetscInt threadsY = 32;
  cudaStream_t   launchStream;

  PetscFunctionBegin;
  /* Nothing to do for an empty request; also avoids launching a zero-sized grid. */
  if (n == 0) PetscFunctionReturn(PETSC_SUCCESS);

  /* Ceiling division: enough x-blocks so that threadsX * numBlocksX >= n. */
  const PetscInt numBlocksX = (n + threadsX - 1) / threadsX;
  const dim3     grid(numBlocksX, 1);
  const dim3     block(threadsX, threadsY);

  PetscCall(PetscGetCurrentCUDAStream(&launchStream));
  ZeroRows<<<grid, block, 0, launchStream>>>(n, rows, i, j, a, diag);
  /* Kernel launches return no status; pick up configuration/launch failures here. */
  PetscCallCUDA(cudaGetLastError());
  PetscFunctionReturn(PETSC_SUCCESS);
}