xref: /petsc/src/mat/impls/hypre/mhypre_kernels.hpp (revision 9dd11ecf0918283bb567d8b33a92f53ac4ea7840)
1*a4963045SJacob Faibussowitsch #pragma once
2a32e9c99SJunchao Zhang 
3a32e9c99SJunchao Zhang #include <../src/mat/impls/hypre/mhypre.h>
4a32e9c99SJunchao Zhang 
// Zero the specified n rows in rows[] of the hypre CSRMatrix (i, j, a) and replace the diagonal entry with diag.
//
// Launch layout: a 2D grid. The x dimension walks the entries of rows[] and the y dimension walks the
// nonzeros within each selected row. Both dimensions use stride loops, so any grid/block shape is valid.
// No shared memory and no intra-block synchronization are used.
//
// Input Parameters:
//   n    - number of entries in rows[]
//   rows - row indices (into the CSR row-offset array i[]) of the rows to zero
//   i, j - CSR row offsets and column indices
//   a    - CSR values, modified in place
//   diag - value written to the entry of row r whose column index equals r
//
// Notes:
//   The diagonal test compares the column index j[] directly with the row index r. It therefore presumes
//   that rows and columns share the same (local) numbering. NOTE(review): confirm this for the block being passed.
//   A row with no stored entry in column r has all of its entries zeroed, and no diagonal is inserted.
__global__ static void ZeroRows(PetscInt n, const PetscInt rows[], const HYPRE_Int i[], const HYPRE_Int j[], HYPRE_Complex a[], HYPRE_Complex diag)
{
  // Widen to PetscInt before multiplying (avoids 32-bit unsigned wrap-around when PetscInt is 64-bit)
  const PetscInt k0    = (PetscInt)blockDim.x * blockIdx.x + threadIdx.x; // first entry of rows[] handled by this thread
  const PetscInt c0    = (PetscInt)blockDim.y * blockIdx.y + threadIdx.y; // first nonzero (within a row) handled by this thread
  const PetscInt gridx = (PetscInt)gridDim.x * blockDim.x;
  const PetscInt gridy = (PetscInt)gridDim.y * blockDim.y;

  for (PetscInt k = k0; k < n; k += gridx) {
    const PetscInt r     = rows[k]; // r-th row of the matrix
    const PetscInt start = i[r];
    const PetscInt nz    = i[r + 1] - start;
    // The nonzero cursor must restart at c0 for every row. Carrying it over from the previous row
    // would skip the entries of any later row processed by this thread.
    for (PetscInt c = c0; c < nz; c += gridy) {
      if (j[start + c] == r) a[start + c] = diag;
      else a[start + c] = 0.0;
    }
  }
}
21