1 #include <petscvec_kokkos.hpp>
2 #include <petsc_kokkos.hpp>
3 #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
4 #include <petscdevice.h>
5 #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h>
6
7 struct PC_PBJacobi_Kokkos {
8 PetscScalarKokkosDualView diag_dual;
9
PC_PBJacobi_KokkosPC_PBJacobi_Kokkos10 PC_PBJacobi_Kokkos(PetscInt len, PetscScalar *diag_ptr_h)
11 {
12 PetscScalarKokkosViewHost diag_h(diag_ptr_h, len);
13 auto diag_d = Kokkos::create_mirror_view_and_copy(PetscGetKokkosExecutionSpace(), diag_h);
14 diag_dual = PetscScalarKokkosDualView(diag_d, diag_h);
15 }
16
UpdatePC_PBJacobi_Kokkos17 PetscErrorCode Update(const PetscScalar *diag_ptr_h)
18 {
19 PetscFunctionBegin;
20 PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call");
21 PetscCallCXX(diag_dual.modify_host()); /* mark the host has newer data */
22 PetscCall(KokkosDualViewSyncDevice(diag_dual, PetscGetKokkosExecutionSpace()));
23 PetscFunctionReturn(PETSC_SUCCESS);
24 }
25 };
26
27 /* Make 'transpose' a template parameter instead of a function input parameter, so that
28 it will be a const in template instantiation and gets optimized out.
29 */
30 template <PetscBool transpose>
PCApplyOrTranspose_PBJacobi_Kokkos(PC pc,Vec x,Vec y)31 static PetscErrorCode PCApplyOrTranspose_PBJacobi_Kokkos(PC pc, Vec x, Vec y)
32 {
33 PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
34 PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
35 ConstPetscScalarKokkosView xv;
36 PetscScalarKokkosView yv;
37 PetscScalarKokkosView Av = pckok->diag_dual.view_device();
38 const PetscInt bs = jac->bs, mbs = jac->mbs, bs2 = bs * bs;
39 const char *label = transpose ? "PCApplyTranspose_PBJacobi_Kokkos" : "PCApply_PBJacobi_Kokkos";
40
41 PetscFunctionBegin;
42 PetscCall(PetscLogGpuTimeBegin());
43 VecErrorIfNotKokkos(x);
44 VecErrorIfNotKokkos(y);
45 PetscCall(VecGetKokkosView(x, &xv));
46 PetscCall(VecGetKokkosViewWrite(y, &yv));
47 PetscCallCXX(Kokkos::parallel_for(
48 label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, bs * mbs), KOKKOS_LAMBDA(PetscInt row) {
49 const PetscScalar *Ap, *xp;
50 PetscScalar *yp;
51 PetscInt i, j, k;
52
53 k = row / bs; /* k-th block */
54 i = row % bs; /* this thread deals with i-th row of the block */
55 Ap = &Av(bs2 * k + i * (transpose ? bs : 1)); /* Ap points to the first entry of i-th row */
56 xp = &xv(bs * k);
57 yp = &yv(bs * k);
58 /* multiply i-th row (column) with x */
59 yp[i] = 0.0;
60 for (j = 0; j < bs; j++) {
61 yp[i] += Ap[0] * xp[j];
62 Ap += (transpose ? 1 : bs); /* block is in column major order */
63 }
64 }));
65 PetscCall(VecRestoreKokkosView(x, &xv));
66 PetscCall(VecRestoreKokkosViewWrite(y, &yv));
67 PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2)); /* FMA on entries in all blocks */
68 PetscCall(PetscLogGpuTimeEnd());
69 PetscFunctionReturn(PETSC_SUCCESS);
70 }
71
PCDestroy_PBJacobi_Kokkos(PC pc)72 static PetscErrorCode PCDestroy_PBJacobi_Kokkos(PC pc)
73 {
74 PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
75
76 PetscFunctionBegin;
77 PetscCallCXX(delete static_cast<PC_PBJacobi_Kokkos *>(jac->spptr));
78 PetscCall(PCDestroy_PBJacobi(pc));
79 PetscFunctionReturn(PETSC_SUCCESS);
80 }
81
PCSetUp_PBJacobi_Kokkos(PC pc,Mat diagPB)82 PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Kokkos(PC pc, Mat diagPB)
83 {
84 PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
85 PetscInt len;
86
87 PetscFunctionBegin;
88 PetscCall(PCSetUp_PBJacobi_Host(pc, diagPB)); /* Compute the inverse on host now. Might worth doing it on device directly */
89 len = jac->bs * jac->bs * jac->mbs;
90 if (!jac->spptr) {
91 PetscCallCXX(jac->spptr = new PC_PBJacobi_Kokkos(len, const_cast<PetscScalar *>(jac->diag)));
92 } else {
93 PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
94 PetscCall(pckok->Update(jac->diag));
95 }
96
97 pc->ops->apply = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_FALSE>;
98 pc->ops->applytranspose = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_TRUE>;
99 pc->ops->destroy = PCDestroy_PBJacobi_Kokkos;
100 PetscFunctionReturn(PETSC_SUCCESS);
101 }
102