/*
   This file contains routines for Parallel vector operations.
 */
#include <petsc_kokkos.hpp>
#include <petscvec_kokkos.hpp>
#include <petsc/private/deviceimpl.h>
#include <petsc/private/vecimpl.h>             /* for struct Vec */
#include <../src/vec/vec/impls/mpi/pvecimpl.h> /* for VecCreate/Destroy_MPI */
#include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
#include <petscsf.h>

static PetscErrorCode VecDestroy_MPIKokkos(Vec v)
{
  PetscFunctionBegin;
  delete static_cast<Vec_Kokkos *>(v->spptr);
  PetscCall(VecDestroy_MPI(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecNorm_MPIKokkos(Vec xin, NormType type, PetscReal *z)
{
  PetscFunctionBegin;
  PetscCall(VecNorm_MPI_Default(xin, type, z, VecNorm_SeqKokkos));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecErrorWeightedNorms_MPIKokkos(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc)
{
  PetscFunctionBegin;
  PetscCall(VecErrorWeightedNorms_MPI_Default(U, Y, E, wnormtype, atol, vatol, rtol, vrtol, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc, VecErrorWeightedNorms_SeqKokkos));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* z = y^H x */
static PetscErrorCode VecDot_MPIKokkos(Vec xin, Vec yin, PetscScalar *z)
{
  PetscFunctionBegin;
  PetscCall(VecXDot_MPI_Default(xin, yin, z, VecDot_SeqKokkos));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* z = y^T x */
static PetscErrorCode VecTDot_MPIKokkos(Vec xin, Vec yin, PetscScalar *z)
{
  PetscFunctionBegin;
  PetscCall(VecXDot_MPI_Default(xin, yin, z, VecTDot_SeqKokkos));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecMDot_MPIKokkos(Vec xin, PetscInt nv, const Vec y[], PetscScalar *z)
{
  PetscFunctionBegin;
  PetscCall(VecMXDot_MPI_Default(xin, nv, y, z, VecMDot_SeqKokkos));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecMTDot_MPIKokkos(Vec xin, PetscInt nv, const Vec y[], PetscScalar *z)
{
  PetscFunctionBegin;
  PetscCall(VecMXDot_MPI_Default(xin, nv, y, z, VecMTDot_SeqKokkos));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecMDot_MPIKokkos_GEMV(Vec xin, PetscInt nv, const Vec y[], PetscScalar *z)
{
  PetscFunctionBegin;
  PetscCall(VecMXDot_MPI_Default(xin, nv, y, z, VecMDot_SeqKokkos_GEMV));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecMTDot_MPIKokkos_GEMV(Vec xin, PetscInt nv, const Vec y[], PetscScalar *z)
{
  PetscFunctionBegin;
  PetscCall(VecMXDot_MPI_Default(xin, nv, y, z, VecMTDot_SeqKokkos_GEMV));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecMax_MPIKokkos(Vec xin, PetscInt *idx, PetscReal *z)
{
  const MPI_Op ops[] = {MPIU_MAXLOC, MPIU_MAX};

  PetscFunctionBegin;
  PetscCall(VecMinMax_MPI_Default(xin, idx, z, VecMax_SeqKokkos, ops));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecMin_MPIKokkos(Vec xin, PetscInt *idx, PetscReal *z)
{
  const MPI_Op ops[] = {MPIU_MINLOC, MPIU_MIN};

  PetscFunctionBegin;
  PetscCall(VecMinMax_MPI_Default(xin, idx, z, VecMin_SeqKokkos, ops));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecCreate_MPIKokkos_Common(Vec); // forward declaration

static PetscErrorCode VecDuplicate_MPIKokkos(Vec win, Vec *vv)
{
  Vec         v;
  Vec_Kokkos *veckok;
  Vec_MPI    *wdata = (Vec_MPI *)win->data;

  PetscScalarKokkosDualView w_dual;

  PetscFunctionBegin;
  PetscCallCXX(w_dual = PetscScalarKokkosDualView("w_dual", win->map->n + wdata->nghost)); // Kokkos init's w_dual to zero

  /* Reuse VecDuplicateWithArray_MPI(), which does most of the setup work */
  PetscCall(VecDuplicateWithArray_MPI(win, w_dual.view_host().data(), &v)); /* after the call, v is a VECMPI */
  PetscCall(PetscObjectChangeTypeName((PetscObject)v, VECMPIKOKKOS));
  PetscCall(VecCreate_MPIKokkos_Common(v));
  v->ops[0] = win->ops[0]; // always follow the ops[] table of win

  /* Build the Vec_Kokkos struct */
  veckok         = new Vec_Kokkos(v->map->n, w_dual.view_host().data(), w_dual.view_device().data());
  veckok->w_dual = w_dual;
  v->spptr       = veckok;
  *vv            = v;
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecDotNorm2_MPIKokkos(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm)
{
  PetscFunctionBegin;
  PetscCall(VecDotNorm2_MPI_Default(s, t, dp, nm, VecDotNorm2_SeqKokkos));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecGetSubVector_MPIKokkos(Vec x, IS is, Vec *y)
{
  PetscFunctionBegin;
  PetscCall(VecGetSubVector_Kokkos_Private(x, PETSC_TRUE, is, y));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecSetPreallocationCOO_MPIKokkos(Vec x, PetscCount ncoo, const PetscInt coo_i[])
{
  const auto vecmpi = static_cast<Vec_MPI *>(x->data);
  const auto veckok = static_cast<Vec_Kokkos *>(x->spptr);
  PetscInt   m;

  PetscFunctionBegin;
  PetscCall(VecGetLocalSize(x, &m));
  PetscCall(VecSetPreallocationCOO_MPI(x, ncoo, coo_i));
  PetscCall(veckok->SetUpCOO(vecmpi, m));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode VecSetValuesCOO_MPIKokkos(Vec x, const PetscScalar v[], InsertMode imode)
{
  const auto                  vecmpi  = static_cast<Vec_MPI *>(x->data);
  const auto                  veckok  = static_cast<Vec_Kokkos *>(x->spptr);
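  // The mapping views below are built by VecSetPreallocationCOO_MPIKokkos()/SetUpCOO(); as used in this
  // routine: jmap1(i)..jmap1(i+1) selects (via perm1) the locally provided COO values that add into local
  // entry i; Cperm(i) is the index into v[] of the i-th value packed into sendbuf for remote owners;
  // imap2(i) is the local index of the i-th entry receiving remote contributions, whose values in recvbuf
  // are selected by perm2 over the range jmap2(i)..jmap2(i+1).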
  const PetscCountKokkosView &jmap1   = veckok->jmap1_d;
  const PetscCountKokkosView &perm1   = veckok->perm1_d;
  const PetscCountKokkosView &imap2   = veckok->imap2_d;
  const PetscCountKokkosView &jmap2   = veckok->jmap2_d;
  const PetscCountKokkosView &perm2   = veckok->perm2_d;
  const PetscCountKokkosView &Cperm   = veckok->Cperm_d;
  PetscScalarKokkosView      &sendbuf = veckok->sendbuf_d;
  PetscScalarKokkosView      &recvbuf = veckok->recvbuf_d;
  PetscScalarKokkosView       xv;
  ConstPetscScalarKokkosView  vv;
  PetscMemType                memtype;
  PetscInt                    m;

  PetscFunctionBegin;
  PetscCall(VecGetLocalSize(x, &m));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) { /* If the user provided v[] on host, copy it to device when one is in use */
    vv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscScalarKokkosViewHost(const_cast<PetscScalar *>(v), vecmpi->coo_n));
  } else {
    vv = ConstPetscScalarKokkosView(v, vecmpi->coo_n); /* Directly use v[]'s memory */
  }

  /* Pack entries to be sent to remote processes */
  Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, vecmpi->sendlen), KOKKOS_LAMBDA(const PetscCount i) { sendbuf(i) = vv(Cperm(i)); });
  PetscCall(PetscSFReduceWithMemTypeBegin(vecmpi->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, sendbuf.data(), PETSC_MEMTYPE_KOKKOS, recvbuf.data(), MPI_REPLACE));

  if (imode == INSERT_VALUES) PetscCall(VecGetKokkosViewWrite(x, &xv)); /* write vector */
  else PetscCall(VecGetKokkosView(x, &xv));                             /* read & write vector */

  Kokkos::parallel_for(
    Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, m), KOKKOS_LAMBDA(const PetscCount i) {
      PetscScalar sum = 0.0;
      for (PetscCount k = jmap1(i); k < jmap1(i + 1); k++) sum += vv(perm1(k));
      xv(i) = (imode == INSERT_VALUES ? 0.0 : xv(i)) + sum;
    });

  PetscCall(PetscSFReduceEnd(vecmpi->coo_sf, MPIU_SCALAR, sendbuf.data(), recvbuf.data(), MPI_REPLACE));

  /* Add received remote entries */
  Kokkos::parallel_for(
    Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, vecmpi->nnz2), KOKKOS_LAMBDA(PetscCount i) {
      for (PetscCount k = jmap2(i); k < jmap2(i + 1); k++) xv(imap2(i)) += recvbuf(perm2(k));
    });

  if (imode == INSERT_VALUES) PetscCall(VecRestoreKokkosViewWrite(x, &xv));
  else PetscCall(VecRestoreKokkosView(x, &xv));
  PetscFunctionReturn(PETSC_SUCCESS);
}
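
/*
  A minimal usage sketch of the COO assembly interface implemented by the two routines above; the sizes,
  indices and values are illustrative and error checking (PetscCall) is omitted:

    Vec         x;
    PetscInt    coo_i[] = {0, 1, 1, 7};         // global index of each contribution; repeats are allowed
    PetscScalar v[]     = {1.0, 2.0, 3.0, 4.0}; // one value per entry of coo_i; may live on host or device

    VecCreate(PETSC_COMM_WORLD, &x);
    VecSetSizes(x, PETSC_DECIDE, 8);
    VecSetType(x, VECMPIKOKKOS);
    VecSetPreallocationCOO(x, 4, coo_i); // analyze the indices once
    VecSetValuesCOO(x, v, ADD_VALUES);   // can be called repeatedly with new values
    VecDestroy(&x);
*/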

// Shared by all VecCreate/Duplicate routines for VecMPIKokkos
static PetscErrorCode VecCreate_MPIKokkos_Common(Vec v)
{
  PetscFunctionBegin;
  v->ops->abs             = VecAbs_SeqKokkos;
  v->ops->reciprocal      = VecReciprocal_SeqKokkos;
  v->ops->pointwisemult   = VecPointwiseMult_SeqKokkos;
  v->ops->setrandom       = VecSetRandom_SeqKokkos;
  v->ops->dotnorm2        = VecDotNorm2_MPIKokkos;
  v->ops->waxpy           = VecWAXPY_SeqKokkos;
  v->ops->norm            = VecNorm_MPIKokkos;
  v->ops->min             = VecMin_MPIKokkos;
  v->ops->max             = VecMax_MPIKokkos;
  v->ops->sum             = VecSum_SeqKokkos;
  v->ops->shift           = VecShift_SeqKokkos;
  v->ops->scale           = VecScale_SeqKokkos;
  v->ops->copy            = VecCopy_SeqKokkos;
  v->ops->set             = VecSet_SeqKokkos;
  v->ops->swap            = VecSwap_SeqKokkos;
  v->ops->axpy            = VecAXPY_SeqKokkos;
  v->ops->axpby           = VecAXPBY_SeqKokkos;
  v->ops->maxpy           = VecMAXPY_SeqKokkos;
  v->ops->aypx            = VecAYPX_SeqKokkos;
  v->ops->axpbypcz        = VecAXPBYPCZ_SeqKokkos;
  v->ops->pointwisedivide = VecPointwiseDivide_SeqKokkos;
  v->ops->placearray      = VecPlaceArray_SeqKokkos;
  v->ops->replacearray    = VecReplaceArray_SeqKokkos;
  v->ops->resetarray      = VecResetArray_SeqKokkos;

  v->ops->dot   = VecDot_MPIKokkos;
  v->ops->tdot  = VecTDot_MPIKokkos;
  v->ops->mdot  = VecMDot_MPIKokkos;
  v->ops->mtdot = VecMTDot_MPIKokkos;

  v->ops->dot_local   = VecDot_SeqKokkos;
  v->ops->tdot_local  = VecTDot_SeqKokkos;
  v->ops->mdot_local  = VecMDot_SeqKokkos;
  v->ops->mtdot_local = VecMTDot_SeqKokkos;

  v->ops->norm_local              = VecNorm_SeqKokkos;
  v->ops->duplicate               = VecDuplicate_MPIKokkos;
  v->ops->destroy                 = VecDestroy_MPIKokkos;
  v->ops->getlocalvector          = VecGetLocalVector_SeqKokkos;
  v->ops->restorelocalvector      = VecRestoreLocalVector_SeqKokkos;
  v->ops->getlocalvectorread      = VecGetLocalVector_SeqKokkos;
  v->ops->restorelocalvectorread  = VecRestoreLocalVector_SeqKokkos;
  v->ops->getarraywrite           = VecGetArrayWrite_SeqKokkos;
  v->ops->getarray                = VecGetArray_SeqKokkos;
  v->ops->restorearray            = VecRestoreArray_SeqKokkos;
  v->ops->getarrayandmemtype      = VecGetArrayAndMemType_SeqKokkos;
  v->ops->restorearrayandmemtype  = VecRestoreArrayAndMemType_SeqKokkos;
  v->ops->getarraywriteandmemtype = VecGetArrayWriteAndMemType_SeqKokkos;
  v->ops->getsubvector            = VecGetSubVector_MPIKokkos;
  v->ops->restoresubvector        = VecRestoreSubVector_SeqKokkos;

  v->ops->setpreallocationcoo = VecSetPreallocationCOO_MPIKokkos;
  v->ops->setvaluescoo        = VecSetValuesCOO_MPIKokkos;

  v->ops->errorwnorm = VecErrorWeightedNorms_MPIKokkos;

  v->offloadmask = PETSC_OFFLOAD_KOKKOS; // Mark this as a VECKOKKOS; we use this flag for a cheap VECKOKKOS test.
  PetscFunctionReturn(PETSC_SUCCESS);
}

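// Convert a VECMPI to VECMPIKOKKOS in place: keep the host array already owned by the Vec_MPI, install the
// Kokkos ops table via VecCreate_MPIKokkos_Common(), and attach a Vec_Kokkos built from that host array
// (darray is NULL here, so the device side is set up by the Vec_Kokkos constructor).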
PETSC_INTERN PetscErrorCode VecConvert_MPI_MPIKokkos_inplace(Vec v)
{
  Vec_MPI *vecmpi;

  PetscFunctionBegin;
  PetscCall(PetscKokkosInitializeCheck());
  PetscCall(PetscLayoutSetUp(v->map));
  PetscCall(PetscObjectChangeTypeName((PetscObject)v, VECMPIKOKKOS));
  PetscCall(VecCreate_MPIKokkos_Common(v));
  PetscCheck(!v->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "v->spptr not NULL");
  vecmpi = static_cast<Vec_MPI *>(v->data);
  PetscCallCXX(v->spptr = new Vec_Kokkos(v->map->n, vecmpi->array, NULL));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Duplicate m vectors of a VECMPIKOKKOS at once; without ghosts they share one contiguous allocation so MDot/MAXPY can use GEMV
static PetscErrorCode VecDuplicateVecs_MPIKokkos_GEMV(Vec w, PetscInt m, Vec *V[])
{
  PetscInt64                lda; // use 64-bit as we will do "m * lda"
  PetscScalar              *array_h, *array_d;
  PetscLayout               map;
  Vec_MPI                  *wmpi = (Vec_MPI *)w->data;
  PetscScalarKokkosDualView w_dual;

  PetscFunctionBegin;
  PetscCall(PetscKokkosInitializeCheck()); // as we'll call kokkos_malloc()
  if (wmpi->nghost) {                      // currently only do GEMV optimization for vectors without ghosts
    w->ops->duplicatevecs = VecDuplicateVecs_Default;
    PetscCall(VecDuplicateVecs(w, m, V));
  } else {
    PetscCall(PetscMalloc1(m, V));
    PetscCall(VecGetLayout(w, &map));
    PetscCall(VecGetLocalSizeAligned(w, 64, &lda)); // get in lda the 64-byte aligned local size

    // See comments in VecCreate_SeqKokkos() on why we use DualView to allocate the memory
    PetscCallCXX(w_dual = PetscScalarKokkosDualView("VecDuplicateVecs", m * lda)); // Kokkos init's w_dual to zero
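
    // Layout: vector i uses the contiguous slice [i*lda, i*lda + map->n) of w_dual on both host and device,
    // so the m vectors form one dense block that the GEMV-based MDot/MTDot/MAXPY kernels can operate on.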

    // create the m vectors with raw arrays
    array_h = w_dual.view_host().data();
    array_d = w_dual.view_device().data();
    for (PetscInt i = 0; i < m; i++) {
      Vec v;
      PetscCall(VecCreateMPIKokkosWithLayoutAndArrays_Private(map, &array_h[i * lda], &array_d[i * lda], &v));
      PetscCallCXX(static_cast<Vec_Kokkos *>(v->spptr)->v_dual.modify_host()); // as we only init'ed array_h
      PetscCall(PetscObjectListDuplicate(((PetscObject)w)->olist, &((PetscObject)v)->olist));
      PetscCall(PetscFunctionListDuplicate(((PetscObject)w)->qlist, &((PetscObject)v)->qlist));
      v->ops[0]             = w->ops[0];
      v->stash.donotstash   = w->stash.donotstash;
      v->stash.ignorenegidx = w->stash.ignorenegidx;
      v->stash.bs           = w->stash.bs;
      v->bstash.bs          = w->bstash.bs;
      (*V)[i]               = v;
    }

    // let the first vector own the raw arrays, so when it is destroyed it will free the arrays
    if (m) {
      Vec v = (*V)[0];

      static_cast<Vec_Kokkos *>(v->spptr)->w_dual = w_dual; // stash the memory
      // disable replacearray of the first vector, as freeing its memory also frees others in the group.
      // But replacearray of others is ok, as they don't own their array.
      if (m > 1) v->ops->replacearray = VecReplaceArray_Default_GEMV_Error;
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*MC
   VECMPIKOKKOS - VECMPIKOKKOS = "mpikokkos" - The basic parallel vector, modified to use Kokkos

   Options Database Keys:
. -vec_type mpikokkos - sets the vector type to VECMPIKOKKOS during a call to VecSetFromOptions()

  Level: beginner

.seealso: `VecCreate()`, `VecSetType()`, `VecSetFromOptions()`, `VecCreateMPIKokkosWithArray()`, `VECMPI`, `VecType`, `VecCreateMPI()`
M*/
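
/*
  A minimal creation sketch for a VECMPIKOKKOS vector (sizes illustrative, error checking omitted); with
  VecSetFromOptions() the same type is selected from the command line via -vec_type mpikokkos:

    Vec x;

    VecCreate(PETSC_COMM_WORLD, &x);
    VecSetSizes(x, PETSC_DECIDE, 100);
    VecSetType(x, VECMPIKOKKOS); // or VecSetFromOptions(x) together with -vec_type mpikokkos
    VecSet(x, 1.0);
    VecDestroy(&x);
*/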
PetscErrorCode VecCreate_MPIKokkos(Vec v)
{
  PetscBool                 mdot_use_gemv  = PETSC_TRUE;
  PetscBool                 maxpy_use_gemv = PETSC_FALSE; // default is false as we have seen bad performance with vendor GEMV on tall, skinny matrices
  PetscScalarKokkosDualView v_dual;

  PetscFunctionBegin;
  PetscCall(PetscKokkosInitializeCheck());
  PetscCall(PetscLayoutSetUp(v->map));

  PetscCallCXX(v_dual = PetscScalarKokkosDualView("v_dual", v->map->n)); // Kokkos init's v_dual to zero
  PetscCall(VecCreate_MPI_Private(v, PETSC_FALSE, 0, v_dual.view_host().data()));

  PetscCall(PetscObjectChangeTypeName((PetscObject)v, VECMPIKOKKOS));
  PetscCall(VecCreate_MPIKokkos_Common(v));
  PetscCheck(!v->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "v->spptr not NULL");
  PetscCallCXX(v->spptr = new Vec_Kokkos(v_dual));
  PetscCall(PetscOptionsGetBool(NULL, NULL, "-vec_mdot_use_gemv", &mdot_use_gemv, NULL));
  PetscCall(PetscOptionsGetBool(NULL, NULL, "-vec_maxpy_use_gemv", &maxpy_use_gemv, NULL));

  // allocate multiple vectors together
  if (mdot_use_gemv || maxpy_use_gemv) v->ops[0].duplicatevecs = VecDuplicateVecs_MPIKokkos_GEMV;

  if (mdot_use_gemv) {
    v->ops[0].mdot        = VecMDot_MPIKokkos_GEMV;
    v->ops[0].mtdot       = VecMTDot_MPIKokkos_GEMV;
    v->ops[0].mdot_local  = VecMDot_SeqKokkos_GEMV;
    v->ops[0].mtdot_local = VecMTDot_SeqKokkos_GEMV;
  }

  if (maxpy_use_gemv) v->ops[0].maxpy = VecMAXPY_SeqKokkos_GEMV;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Create a VECMPIKOKKOS with layout and arrays
PetscErrorCode VecCreateMPIKokkosWithLayoutAndArrays_Private(PetscLayout map, const PetscScalar harray[], const PetscScalar darray[], Vec *v)
{
  Vec w;

  PetscFunctionBegin;
  if (map->n > 0) PetscCheck(darray, map->comm, PETSC_ERR_ARG_WRONG, "darray cannot be NULL");
#if defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
  PetscCheck(harray == darray, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "harray and darray must be the same");
#endif
  PetscCall(VecCreateMPIWithLayoutAndArray_Private(map, harray, &w));
  PetscCall(PetscObjectChangeTypeName((PetscObject)w, VECMPIKOKKOS)); // Change it to VECMPIKOKKOS
  PetscCall(VecCreate_MPIKokkos_Common(w));
  PetscCallCXX(w->spptr = new Vec_Kokkos(map->n, const_cast<PetscScalar *>(harray), const_cast<PetscScalar *>(darray)));
  *v = w;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  VecCreateMPIKokkosWithArray - Creates a parallel, array-style vector,
  where the user provides the GPU array space to store the vector values.

  Collective

  Input Parameters:
+ comm   - the MPI communicator to use
. bs     - block size, same meaning as VecSetBlockSize()
. n      - local vector length, cannot be PETSC_DECIDE
. N      - global vector length (or PETSC_DECIDE to have it calculated)
- darray - the user provided GPU array to store the vector values

  Output Parameter:
. v - the vector

  Notes:
  Use VecDuplicate() or VecDuplicateVecs() to form additional vectors of the
  same type as an existing vector.

  If the user-provided array is NULL, then VecKokkosPlaceArray() can be used
  at a later stage to SET the array for storing the vector values.

  PETSc does NOT free the array when the vector is destroyed via VecDestroy().
  The user should not free the array until the vector is destroyed.

  Level: intermediate

.seealso: `VecCreateSeqKokkosWithArray()`, `VecCreateMPIWithArray()`, `VecCreateSeqWithArray()`,
          `VecCreate()`, `VecDuplicate()`, `VecDuplicateVecs()`, `VecCreateGhost()`,
          `VecCreateMPI()`, `VecCreateGhostWithArray()`, `VecPlaceArray()`

@*/
PetscErrorCode VecCreateMPIKokkosWithArray(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar darray[], Vec *v)
{
  Vec          w;
  Vec_Kokkos  *veckok;
  Vec_MPI     *vecmpi;
  PetscScalar *harray;

  PetscFunctionBegin;
  PetscCheck(n != PETSC_DECIDE, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Must set local size of vector");
  PetscCall(PetscKokkosInitializeCheck());
  PetscCall(PetscSplitOwnership(comm, &n, &N));
  PetscCall(VecCreate(comm, &w));
  PetscCall(VecSetSizes(w, n, N));
  PetscCall(VecSetBlockSize(w, bs));
  PetscCall(PetscLayoutSetUp(w->map));

  if (std::is_same<DefaultMemorySpace, HostMirrorMemorySpace>::value) {
    harray = const_cast<PetscScalar *>(darray);
  } else PetscCall(PetscMalloc1(w->map->n, &harray)); /* If device is not the same as host, allocate the host array ourselves */

  PetscCall(VecCreate_MPI_Private(w, PETSC_FALSE /*alloc*/, 0 /*nghost*/, harray)); /* Build an MPI vector with the provided host array */
  vecmpi = static_cast<Vec_MPI *>(w->data);

  if (!std::is_same<DefaultMemorySpace, HostMirrorMemorySpace>::value) vecmpi->array_allocated = harray; /* The host array was allocated by PETSc */

  PetscCall(PetscObjectChangeTypeName((PetscObject)w, VECMPIKOKKOS));
  PetscCall(VecCreate_MPIKokkos_Common(w));
  veckok = new Vec_Kokkos(n, harray, const_cast<PetscScalar *>(darray));
  veckok->v_dual.modify_device(); /* Mark the device as modified */
  w->spptr = static_cast<void *>(veckok);
  *v       = w;
  PetscFunctionReturn(PETSC_SUCCESS);
}
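
/*
  A usage sketch for VecCreateMPIKokkosWithArray(); the device allocation below uses plain Kokkos, and the
  names, sizes, and cleanup order are illustrative assumptions (error checking omitted):

    PetscInt     n = 10;
    PetscScalar *darray;
    Vec          x;

    darray = static_cast<PetscScalar *>(Kokkos::kokkos_malloc(n * sizeof(PetscScalar)));
    VecCreateMPIKokkosWithArray(PETSC_COMM_WORLD, 1, n, PETSC_DECIDE, darray, &x);
    // ... use x ...
    VecDestroy(&x);              // does not free darray
    Kokkos::kokkos_free(darray); // the user owns the device array
*/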

/*
   VecCreateMPIKokkosWithArrays_Private - Creates a Kokkos parallel, array-style vector
   with user-provided arrays on host and device.

   Collective

   Input Parameters:
+  comm - the communicator
.  bs - the block size
.  n - the local vector length
.  N - the global vector length
.  harray - host memory where the vector elements are to be stored
-  darray - device memory where the vector elements are to be stored

   Output Parameter:
.  v - the vector

   Notes:
   If there is no device, then harray and darray must be the same.
   If n is not zero, then harray and darray must be allocated.
   After the call, the created vector is supposed to be in a synchronized state, i.e.,
   we assume harray and darray contain the same data.

   PETSc does NOT free the arrays when the vector is destroyed via VecDestroy().
   The user should not free the arrays until the vector is destroyed.
*/
PetscErrorCode VecCreateMPIKokkosWithArrays_Private(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar harray[], const PetscScalar darray[], Vec *v)
{
  Vec w;

  PetscFunctionBegin;
  PetscCall(PetscKokkosInitializeCheck());
  if (n) {
    PetscAssertPointer(harray, 5);
    PetscCheck(darray, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "darray cannot be NULL");
  }
  if (std::is_same<DefaultMemorySpace, HostMirrorMemorySpace>::value) PetscCheck(harray == darray, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "harray and darray must be the same");
  PetscCall(VecCreateMPIWithArray(comm, bs, n, N, harray, &w));
  PetscCall(PetscObjectChangeTypeName((PetscObject)w, VECMPIKOKKOS)); /* Change it to Kokkos */
  PetscCall(VecCreate_MPIKokkos_Common(w));
  PetscCallCXX(w->spptr = new Vec_Kokkos(n, const_cast<PetscScalar *>(harray), const_cast<PetscScalar *>(darray)));
  *v = w;
  PetscFunctionReturn(PETSC_SUCCESS);
}
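
/*
  Internal usage sketch (mirroring VecDuplicateVecs_MPIKokkos_GEMV() above): backing the vector with a
  DualView provides matching host and device arrays; the names and sizes are illustrative, and the explicit
  deep_copy calls keep both sides holding the same data, as required by the Notes above:

    PetscInt                  n = 10;
    Vec                       x;
    PetscScalarKokkosDualView dual("dual", n);

    Kokkos::deep_copy(dual.view_device(), 0.0); // make both sides hold the same (zero) data
    Kokkos::deep_copy(dual.view_host(), 0.0);
    VecCreateMPIKokkosWithArrays_Private(PETSC_COMM_WORLD, 1, n, PETSC_DECIDE, dual.view_host().data(), dual.view_device().data(), &x);
*/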

/*MC
   VECKOKKOS - VECKOKKOS = "kokkos" - The basic vector, modified to use Kokkos

   Options Database Keys:
. -vec_type kokkos - sets the vector type to VECKOKKOS during a call to VecSetFromOptions()

  Level: beginner

.seealso: `VecCreate()`, `VecSetType()`, `VecSetFromOptions()`, `VecCreateMPIKokkosWithArray()`, `VECMPI`, `VecType`, `VecCreateMPI()`
M*/
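
/*
  A minimal sketch (sizes illustrative, error checking omitted): with VECKOKKOS the concrete type is chosen
  from the communicator size by VecCreate_Kokkos() below, so the same code works on one or many MPI ranks:

    Vec x;

    VecCreate(PETSC_COMM_WORLD, &x);
    VecSetSizes(x, PETSC_DECIDE, 100);
    VecSetType(x, VECKOKKOS); // becomes VECSEQKOKKOS on one rank, VECMPIKOKKOS otherwise
    VecDestroy(&x);
*/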
PetscErrorCode VecCreate_Kokkos(Vec v)
{
  PetscMPIInt size;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)v), &size));
  if (size == 1) PetscCall(VecSetType(v, VECSEQKOKKOS));
  else PetscCall(VecSetType(v, VECMPIKOKKOS));
  PetscFunctionReturn(PETSC_SUCCESS);
}