xref: /petsc/src/vec/is/sf/impls/basic/kokkos/sfkok.kokkos.cxx (revision efbe7e8a80d07327753dbe0b33efee01e046af3f)
1 #include <../src/vec/is/sf/impls/basic/sfpack.h>
2 
3 #include <Kokkos_Core.hpp>
4 
5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
6 using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
7 using HostMemorySpace      = Kokkos::HostSpace;
8 
9 typedef Kokkos::View<char*,DeviceMemorySpace>       deviceBuffer_t;
10 typedef Kokkos::View<char*,HostMemorySpace>         HostBuffer_t;
11 
12 typedef Kokkos::View<const char*,DeviceMemorySpace> deviceConstBuffer_t;
13 typedef Kokkos::View<const char*,HostMemorySpace>   HostConstBuffer_t;
14 
15 /*====================================================================================*/
16 /*                             Regular operations                           */
17 /*====================================================================================*/
18 template<typename Type> struct Insert{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = y;             return old;}};
19 template<typename Type> struct Add   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x += y;             return old;}};
20 template<typename Type> struct Mult  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x *= y;             return old;}};
21 template<typename Type> struct Min   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = PetscMin(x,y); return old;}};
22 template<typename Type> struct Max   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = PetscMax(x,y); return old;}};
23 template<typename Type> struct LAND  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x && y;        return old;}};
24 template<typename Type> struct LOR   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x || y;        return old;}};
25 template<typename Type> struct LXOR  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = !x != !y;      return old;}};
26 template<typename Type> struct BAND  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x & y;         return old;}};
27 template<typename Type> struct BOR   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x | y;         return old;}};
28 template<typename Type> struct BXOR  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x ^ y;         return old;}};
29 template<typename PairType> struct Minloc {
30   KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const {
31     PairType old = x;
32     if (y.first < x.first) x = y;
33     else if (y.first == x.first) x.second = PetscMin(x.second,y.second);
34     return old;
35   }
36 };
37 template<typename PairType> struct Maxloc {
38   KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const {
39     PairType old = x;
40     if (y.first > x.first) x = y;
41     else if (y.first == x.first) x.second = PetscMin(x.second,y.second); /* See MPI MAXLOC */
42     return old;
43   }
44 };
45 
46 /*====================================================================================*/
47 /*                             Atomic operations                            */
48 /*====================================================================================*/
49 template<typename Type> struct AtomicInsert  {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_assign(&x,y);}};
50 template<typename Type> struct AtomicAdd     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_add(&x,y);}};
51 template<typename Type> struct AtomicBAND    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_and(&x,y);}};
52 template<typename Type> struct AtomicBOR     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_or (&x,y);}};
53 template<typename Type> struct AtomicBXOR    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_xor(&x,y);}};
54 template<typename Type> struct AtomicLAND    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=~0; Kokkos::atomic_and(&x,y?one:zero);}};
55 template<typename Type> struct AtomicLOR     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=1;  Kokkos::atomic_or (&x,y?one:zero);}};
56 template<typename Type> struct AtomicMult    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_mul(&x,y);}};
57 template<typename Type> struct AtomicMin     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_min(&x,y);}};
58 template<typename Type> struct AtomicMax     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_max(&x,y);}};
59 /* TODO: struct AtomicLXOR  */
60 template<typename Type> struct AtomicFetchAdd{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {return Kokkos::atomic_fetch_add(&x,y);}};
61 
62 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
63 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt,PetscInt tid)
64 {
65   PetscInt        i,j,k,m,n,r;
66   const PetscInt  *offset,*start,*dx,*dy,*X,*Y;
67 
68   n      = opt[0];
69   offset = opt + 1;
70   start  = opt + n + 2;
71   dx     = opt + 2*n + 2;
72   dy     = opt + 3*n + 2;
73   X      = opt + 5*n + 2;
74   Y      = opt + 6*n + 2;
75   for (r=0; r<n; r++) {if (tid < offset[r+1]) break;}
76   m = (tid - offset[r]);
77   k = m/(dx[r]*dy[r]);
78   j = (m - k*dx[r]*dy[r])/dx[r];
79   i = m - k*dx[r]*dy[r] - j*dx[r];
80 
81   return (start[r] + k*X[r]*Y[r] + j*X[r] + i);
82 }
83 
84 /*====================================================================================*/
85 /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
86 /*====================================================================================*/
87 
/* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
   <Type> is PetscReal, which is the primitive type we operate on.
   <bs>   is 16, which says <unit> contains 16 primitive types.
   <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
   <EQ>   is 0, which is (bs == BS ? 1 : 0)

  If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, which makes MBS below a compile-time constant.
  For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
*/
97 template<typename Type,PetscInt BS,PetscInt EQ>
98 static PetscErrorCode Pack(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,const void *data_,void *buf_)
99 {
100   const PetscInt          *iopt = opt ? opt->array : NULL;
101   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS; /* If EQ, then MBS will be a compile-time const */
102   const Type              *data = static_cast<const Type*>(data_);
103   Type                    *buf = static_cast<Type*>(buf_);
104   DeviceExecutionSpace    exec;
105 
106   PetscFunctionBegin;
107   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
108     /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
109        iopt == NULL && idx == NULL ==> the indices are contiguous;
110      */
111     PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS;
112     PetscInt s = tid*MBS;
113     for (int i=0; i<MBS; i++) buf[s+i] = data[t+i];
114   });
115   PetscFunctionReturn(0);
116 }
117 
118 template<typename Type,class Op,PetscInt BS,PetscInt EQ>
119 static PetscErrorCode UnpackAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data_,const void *buf_)
120 {
121   Op                      op;
122   const PetscInt          *iopt = opt ? opt->array : NULL;
123   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS;
124   Type                    *data = static_cast<Type*>(data_);
125   const Type              *buf = static_cast<const Type*>(buf_);
126   DeviceExecutionSpace    exec;
127 
128   PetscFunctionBegin;
129   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
130     PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS;
131     PetscInt s = tid*MBS;
132     for (int i=0; i<MBS; i++) op(data[t+i],buf[s+i]);
133   });
134   PetscFunctionReturn(0);
135 }
136 
137 template<typename Type,class Op,PetscInt BS,PetscInt EQ>
138 static PetscErrorCode FetchAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data,void *buf)
139 {
140   Op                      op;
141   const PetscInt          *ropt = opt ? opt->array : NULL;
142   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS;
143   Type                    *rootdata = static_cast<Type*>(data),*leafbuf=static_cast<Type*>(buf);
144   DeviceExecutionSpace    exec;
145 
146   PetscFunctionBegin;
147   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
148     PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (idx? idx[tid] : start+tid))*MBS;
149     PetscInt l = tid*MBS;
150     for (int i=0; i<MBS; i++) leafbuf[l+i] = op(rootdata[r+i],leafbuf[l+i]);
151   });
152   PetscFunctionReturn(0);
153 }
154 
155 template<typename Type,class Op,PetscInt BS,PetscInt EQ>
156 static PetscErrorCode ScatterAndOp(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_)
157 {
158   PetscInt                srcx=0,srcy=0,srcX=0,srcY=0,dstx=0,dsty=0,dstX=0,dstY=0;
159   const PetscInt          M = (EQ) ? 1 : link->bs/BS, MBS=M*BS;
160   const Type              *src = static_cast<const Type*>(src_);
161   Type                    *dst = static_cast<Type*>(dst_);
162   DeviceExecutionSpace    exec;
163 
164   PetscFunctionBegin;
165   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
166   if (srcOpt)       {srcx = srcOpt->dx[0]; srcy = srcOpt->dy[0]; srcX = srcOpt->X[0]; srcY = srcOpt->Y[0]; srcStart = srcOpt->start[0]; srcIdx = NULL;}
167   else if (!srcIdx) {srcx = srcX = count; srcy = srcY = 1;}
168 
169   if (dstOpt)       {dstx = dstOpt->dx[0]; dsty = dstOpt->dy[0]; dstX = dstOpt->X[0]; dstY = dstOpt->Y[0]; dstStart = dstOpt->start[0]; dstIdx = NULL;}
170   else if (!dstIdx) {dstx = dstX = count; dsty = dstY = 1;}
171 
172   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
173     PetscInt i,j,k,s,t;
174     Op       op;
175     if (!srcIdx) { /* src is in 3D */
176       k = tid/(srcx*srcy);
177       j = (tid - k*srcx*srcy)/srcx;
178       i = tid - k*srcx*srcy - j*srcx;
179       s = srcStart + k*srcX*srcY + j*srcX + i;
180     } else { /* src is contiguous */
181       s = srcIdx[tid];
182     }
183 
184     if (!dstIdx) { /* 3D */
185       k = tid/(dstx*dsty);
186       j = (tid - k*dstx*dsty)/dstx;
187       i = tid - k*dstx*dsty - j*dstx;
188       t = dstStart + k*dstX*dstY + j*dstX + i;
189     } else { /* contiguous */
190       t = dstIdx[tid];
191     }
192 
193     s *= MBS;
194     t *= MBS;
195     for (i=0; i<MBS; i++) op(dst[t+i],src[s+i]);
196   });
197   PetscFunctionReturn(0);
198 }
199 
200 /* Specialization for Insert since we may use memcpy */
201 template<typename Type,PetscInt BS,PetscInt EQ>
202 static PetscErrorCode ScatterAndInsert(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_)
203 {
204   PetscErrorCode          ierr;
205   const Type              *src = static_cast<const Type*>(src_);
206   Type                    *dst = static_cast<Type*>(dst_);
207   DeviceExecutionSpace    exec;
208 
209   PetscFunctionBegin;
210   if (!count) PetscFunctionReturn(0);
211   /*src and dst are contiguous */
212   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
213     size_t sz = count*link->unitbytes;
214     deviceBuffer_t      dbuf(reinterpret_cast<char*>(dst+dstStart*link->bs),sz);
215     deviceConstBuffer_t sbuf(reinterpret_cast<const char*>(src+srcStart*link->bs),sz);
216     Kokkos::deep_copy(exec,dbuf,sbuf);
217   } else {
218     ierr = ScatterAndOp<Type,Insert<Type>,BS,EQ>(link,count,srcStart,srcOpt,srcIdx,src,dstStart,dstOpt,dstIdx,dst);CHKERRQ(ierr);
219   }
220   PetscFunctionReturn(0);
221 }
222 
223 template<typename Type,class Op,PetscInt BS,PetscInt EQ>
224 static PetscErrorCode FetchAndOpLocal(PetscSFLink link,PetscInt count,PetscInt rootstart,PetscSFPackOpt rootopt,const PetscInt *rootidx,void *rootdata_,PetscInt leafstart,PetscSFPackOpt leafopt,const PetscInt *leafidx,const void *leafdata_,void *leafupdate_)
225 {
226   Op                      op;
227   const PetscInt          M = (EQ) ? 1 : link->bs/BS, MBS = M*BS;
228   const PetscInt          *ropt = rootopt ? rootopt->array : NULL;
229   const PetscInt          *lopt = leafopt ? leafopt->array : NULL;
230   Type                    *rootdata = static_cast<Type*>(rootdata_),*leafupdate = static_cast<Type*>(leafupdate_);
231   const Type              *leafdata = static_cast<const Type*>(leafdata_);
232   DeviceExecutionSpace    exec;
233 
234   PetscFunctionBegin;
235   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
236     PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (rootidx? rootidx[tid] : rootstart+tid))*MBS;
237     PetscInt l = (lopt? MapTidToIndex(lopt,tid) : (leafidx? leafidx[tid] : leafstart+tid))*MBS;
238     for (int i=0; i<MBS; i++) leafupdate[l+i] = op(rootdata[r+i],leafdata[l+i]);
239   });
240   PetscFunctionReturn(0);
241 }
242 
243 /*====================================================================================*/
244 /*  Init various types and instantiate pack/unpack function pointers                  */
245 /*====================================================================================*/
246 template<typename Type,PetscInt BS,PetscInt EQ>
247 static void PackInit_RealType(PetscSFLink link)
248 {
249   /* Pack/unpack for remote communication */
250   link->d_Pack              = Pack<Type,BS,EQ>;
251   link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
252   link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>   ,BS,EQ>;
253   link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>  ,BS,EQ>;
254   link->d_UnpackAndMin      = UnpackAndOp<Type,Min<Type>   ,BS,EQ>;
255   link->d_UnpackAndMax      = UnpackAndOp<Type,Max<Type>   ,BS,EQ>;
256   link->d_FetchAndAdd       = FetchAndOp <Type,Add<Type>   ,BS,EQ>;
257   /* Scatter for local communication */
258   link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>; /* Has special optimizations */
259   link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>    ,BS,EQ>;
260   link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>   ,BS,EQ>;
261   link->d_ScatterAndMin     = ScatterAndOp<Type,Min<Type>    ,BS,EQ>;
262   link->d_ScatterAndMax     = ScatterAndOp<Type,Max<Type>    ,BS,EQ>;
263   link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add <Type>,BS,EQ>;
264   /* Atomic versions when there are data-race possibilities */
265   link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>  ,BS,EQ>;
266   link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>     ,BS,EQ>;
267   link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>    ,BS,EQ>;
268   link->da_UnpackAndMin     = UnpackAndOp<Type,AtomicMin<Type>     ,BS,EQ>;
269   link->da_UnpackAndMax     = UnpackAndOp<Type,AtomicMax<Type>     ,BS,EQ>;
270   link->da_FetchAndAdd      = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>;
271 
272   link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
273   link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
274   link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
275   link->da_ScatterAndMin    = ScatterAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
276   link->da_ScatterAndMax    = ScatterAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
277   link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
278 }
279 
280 template<typename Type,PetscInt BS,PetscInt EQ>
281 static void PackInit_IntegerType(PetscSFLink link)
282 {
283   link->d_Pack              = Pack<Type,BS,EQ>;
284   link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type> ,BS,EQ>;
285   link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>    ,BS,EQ>;
286   link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>   ,BS,EQ>;
287   link->d_UnpackAndMin      = UnpackAndOp<Type,Min<Type>    ,BS,EQ>;
288   link->d_UnpackAndMax      = UnpackAndOp<Type,Max<Type>    ,BS,EQ>;
289   link->d_UnpackAndLAND     = UnpackAndOp<Type,LAND<Type>   ,BS,EQ>;
290   link->d_UnpackAndLOR      = UnpackAndOp<Type,LOR<Type>    ,BS,EQ>;
291   link->d_UnpackAndLXOR     = UnpackAndOp<Type,LXOR<Type>   ,BS,EQ>;
292   link->d_UnpackAndBAND     = UnpackAndOp<Type,BAND<Type>   ,BS,EQ>;
293   link->d_UnpackAndBOR      = UnpackAndOp<Type,BOR<Type>    ,BS,EQ>;
294   link->d_UnpackAndBXOR     = UnpackAndOp<Type,BXOR<Type>   ,BS,EQ>;
295   link->d_FetchAndAdd       = FetchAndOp <Type,Add<Type>    ,BS,EQ>;
296 
297   link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>;
298   link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>   ,BS,EQ>;
299   link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>  ,BS,EQ>;
300   link->d_ScatterAndMin     = ScatterAndOp<Type,Min<Type>   ,BS,EQ>;
301   link->d_ScatterAndMax     = ScatterAndOp<Type,Max<Type>   ,BS,EQ>;
302   link->d_ScatterAndLAND    = ScatterAndOp<Type,LAND<Type>  ,BS,EQ>;
303   link->d_ScatterAndLOR     = ScatterAndOp<Type,LOR<Type>   ,BS,EQ>;
304   link->d_ScatterAndLXOR    = ScatterAndOp<Type,LXOR<Type>  ,BS,EQ>;
305   link->d_ScatterAndBAND    = ScatterAndOp<Type,BAND<Type>  ,BS,EQ>;
306   link->d_ScatterAndBOR     = ScatterAndOp<Type,BOR<Type>   ,BS,EQ>;
307   link->d_ScatterAndBXOR    = ScatterAndOp<Type,BXOR<Type>  ,BS,EQ>;
308   link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add<Type>,BS,EQ>;
309 
310   link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>;
311   link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
312   link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
313   link->da_UnpackAndMin     = UnpackAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
314   link->da_UnpackAndMax     = UnpackAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
315   link->da_UnpackAndLAND    = UnpackAndOp<Type,AtomicLAND<Type>  ,BS,EQ>;
316   link->da_UnpackAndLOR     = UnpackAndOp<Type,AtomicLOR<Type>   ,BS,EQ>;
317   link->da_UnpackAndBAND    = UnpackAndOp<Type,AtomicBAND<Type>  ,BS,EQ>;
318   link->da_UnpackAndBOR     = UnpackAndOp<Type,AtomicBOR<Type>   ,BS,EQ>;
319   link->da_UnpackAndBXOR    = UnpackAndOp<Type,AtomicBXOR<Type>  ,BS,EQ>;
320   link->da_FetchAndAdd      = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>;
321 
322   link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
323   link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
324   link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
325   link->da_ScatterAndMin    = ScatterAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
326   link->da_ScatterAndMax    = ScatterAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
327   link->da_ScatterAndLAND   = ScatterAndOp<Type,AtomicLAND<Type>  ,BS,EQ>;
328   link->da_ScatterAndLOR    = ScatterAndOp<Type,AtomicLOR<Type>   ,BS,EQ>;
329   link->da_ScatterAndBAND   = ScatterAndOp<Type,AtomicBAND<Type>  ,BS,EQ>;
330   link->da_ScatterAndBOR    = ScatterAndOp<Type,AtomicBOR<Type>   ,BS,EQ>;
331   link->da_ScatterAndBXOR   = ScatterAndOp<Type,AtomicBXOR<Type>  ,BS,EQ>;
332   link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
333 }
334 
#if defined(PETSC_HAVE_COMPLEX)
/* Install the device kernels for complex units: only Insert, Add, Mult and fetch-and-add are set
   (no Min/Max or logical/bitwise kernels). d_* are the regular kernels, da_* the atomic ones. */
template<typename Type,PetscInt BS,PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link)
{
  link->d_Pack              = Pack<Type,BS,EQ>;

  /* Insert */
  link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
  link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>;
  link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;

  /* Add */
  link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>,BS,EQ>;
  link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>,BS,EQ>;
  link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>,BS,EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>,BS,EQ>;

  /* Mult */
  link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>,BS,EQ>;
  link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>,BS,EQ>;
  link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>,BS,EQ>;
  link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>,BS,EQ>;

  /* Fetch-and-add */
  link->d_FetchAndAdd       = FetchAndOp<Type,Add<Type>,BS,EQ>;
  link->da_FetchAndAdd      = FetchAndOp<Type,AtomicFetchAdd<Type>,BS,EQ>;
  link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add<Type>,BS,EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
}
#endif
361 
362 template<typename Type>
363 static void PackInit_PairType(PetscSFLink link)
364 {
365   link->d_Pack             = Pack<Type,1,1>;
366   link->d_UnpackAndInsert  = UnpackAndOp<Type,Insert<Type>,1,1>;
367   link->d_UnpackAndMaxloc  = UnpackAndOp<Type,Maxloc<Type>,1,1>;
368   link->d_UnpackAndMinloc  = UnpackAndOp<Type,Minloc<Type>,1,1>;
369 
370   link->d_ScatterAndInsert = ScatterAndOp<Type,Insert<Type>,1,1>;
371   link->d_ScatterAndMaxloc = ScatterAndOp<Type,Maxloc<Type>,1,1>;
372   link->d_ScatterAndMinloc = ScatterAndOp<Type,Minloc<Type>,1,1>;
373   /* Atomics for pair types are not implemented yet */
374 }
375 
376 template<typename Type,PetscInt BS,PetscInt EQ>
377 static void PackInit_DumbType(PetscSFLink link)
378 {
379   link->d_Pack             = Pack<Type,BS,EQ>;
380   link->d_UnpackAndInsert  = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
381   link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>;
382   /* Atomics for dumb types are not implemented yet */
383 }
384 
/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug:
  one is not able to repeatedly create and destroy the object. SF's original design was that each
  SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from
  destroying multiple SFLinks with NULL stream and the default execution space object. To avoid
  memory leaks, SF_Kokkos only supports the NULL stream, which is also petsc's default scheme. SF_Kokkos
  does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singleton
  object in Kokkos.
*/
394 /*
395 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
396 {
397   PetscFunctionBegin;
398   PetscFunctionReturn(0);
399 }
400 */
401 
402 /* Some device-specific utilities */
403 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink link)
404 {
405   PetscFunctionBegin;
406   Kokkos::fence();
407   PetscFunctionReturn(0);
408 }
409 
410 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink link)
411 {
412   DeviceExecutionSpace    exec;
413   PetscFunctionBegin;
414   exec.fence();
415   PetscFunctionReturn(0);
416 }
417 
418 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink link,PetscMemType dstmtype,void* dst,PetscMemType srcmtype,const void*src,size_t n)
419 {
420   DeviceExecutionSpace    exec;
421 
422   PetscFunctionBegin;
423   if (!n) PetscFunctionReturn(0);
424   if (dstmtype == PETSC_MEMTYPE_HOST && srcmtype == PETSC_MEMTYPE_HOST) {
425     PetscErrorCode ierr = PetscMemcpy(dst,src,n);CHKERRQ(ierr);
426   } else {
427     if (dstmtype == PETSC_MEMTYPE_DEVICE && srcmtype == PETSC_MEMTYPE_HOST) {
428       deviceBuffer_t       dbuf(static_cast<char*>(dst),n);
429       HostConstBuffer_t    sbuf(static_cast<const char*>(src),n);
430       Kokkos::deep_copy(exec,dbuf,sbuf);
431       PetscErrorCode ierr = PetscLogCpuToGpu(n);CHKERRQ(ierr);
432     } else if (dstmtype == PETSC_MEMTYPE_HOST && srcmtype == PETSC_MEMTYPE_DEVICE) {
433       HostBuffer_t         dbuf(static_cast<char*>(dst),n);
434       deviceConstBuffer_t  sbuf(static_cast<const char*>(src),n);
435       Kokkos::deep_copy(exec,dbuf,sbuf);
436       PetscErrorCode ierr = PetscLogGpuToCpu(n);CHKERRQ(ierr);
437     } else if (dstmtype == PETSC_MEMTYPE_DEVICE && srcmtype == PETSC_MEMTYPE_DEVICE) {
438       deviceBuffer_t       dbuf(static_cast<char*>(dst),n);
439       deviceConstBuffer_t  sbuf(static_cast<const char*>(src),n);
440       Kokkos::deep_copy(exec,dbuf,sbuf);
441     }
442   }
443   PetscFunctionReturn(0);
444 }
445 
446 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype,size_t size,void** ptr)
447 {
448   PetscFunctionBegin;
449   if (mtype == PETSC_MEMTYPE_HOST) {PetscErrorCode ierr = PetscMalloc(size,ptr);CHKERRQ(ierr);}
450   else if (mtype == PETSC_MEMTYPE_DEVICE) {*ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);}
451   else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d", (int)mtype);
452   PetscFunctionReturn(0);
453 }
454 
455 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype,void* ptr)
456 {
457   PetscFunctionBegin;
458   if (mtype == PETSC_MEMTYPE_HOST) {PetscErrorCode ierr = PetscFree(ptr);CHKERRQ(ierr);}
459   else if (mtype == PETSC_MEMTYPE_DEVICE) {Kokkos::kokkos_free<DeviceMemorySpace>(ptr);}
460   else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d",(int)mtype);
461   PetscFunctionReturn(0);
462 }
463 
464 /*====================================================================================*/
465 /*                Main driver to init MPI datatype on device                          */
466 /*====================================================================================*/
467 
468 /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
469 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF sf,PetscSFLink link,MPI_Datatype unit)
470 {
471   PetscErrorCode ierr;
472   PetscInt       nSignedChar=0,nUnsignedChar=0,nInt=0,nPetscInt=0,nPetscReal=0;
473   PetscBool      is2Int,is2PetscInt;
474 #if defined(PETSC_HAVE_COMPLEX)
475   PetscInt       nPetscComplex=0;
476 #endif
477 
478   PetscFunctionBegin;
479   if (link->deviceinited) PetscFunctionReturn(0);
480   ierr = MPIPetsc_Type_compare_contig(unit,MPI_SIGNED_CHAR,  &nSignedChar);CHKERRQ(ierr);
481   ierr = MPIPetsc_Type_compare_contig(unit,MPI_UNSIGNED_CHAR,&nUnsignedChar);CHKERRQ(ierr);
482   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
483   ierr = MPIPetsc_Type_compare_contig(unit,MPI_INT,  &nInt);CHKERRQ(ierr);
484   ierr = MPIPetsc_Type_compare_contig(unit,MPIU_INT, &nPetscInt);CHKERRQ(ierr);
485   ierr = MPIPetsc_Type_compare_contig(unit,MPIU_REAL,&nPetscReal);CHKERRQ(ierr);
486 #if defined(PETSC_HAVE_COMPLEX)
487   ierr = MPIPetsc_Type_compare_contig(unit,MPIU_COMPLEX,&nPetscComplex);CHKERRQ(ierr);
488 #endif
489   ierr = MPIPetsc_Type_compare(unit,MPI_2INT,&is2Int);CHKERRQ(ierr);
490   ierr = MPIPetsc_Type_compare(unit,MPIU_2INT,&is2PetscInt);CHKERRQ(ierr);
491 
492   if (is2Int) {
493     PackInit_PairType<Kokkos::pair<int,int>>(link);
494   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
495     PackInit_PairType<Kokkos::pair<PetscInt,PetscInt>>(link);
496   } else if (nPetscReal) {
497     if      (nPetscReal == 8) PackInit_RealType<PetscReal,8,1>(link); else if (nPetscReal%8 == 0) PackInit_RealType<PetscReal,8,0>(link);
498     else if (nPetscReal == 4) PackInit_RealType<PetscReal,4,1>(link); else if (nPetscReal%4 == 0) PackInit_RealType<PetscReal,4,0>(link);
499     else if (nPetscReal == 2) PackInit_RealType<PetscReal,2,1>(link); else if (nPetscReal%2 == 0) PackInit_RealType<PetscReal,2,0>(link);
500     else if (nPetscReal == 1) PackInit_RealType<PetscReal,1,1>(link); else if (nPetscReal%1 == 0) PackInit_RealType<PetscReal,1,0>(link);
501   } else if (nPetscInt) {
502     if      (nPetscInt == 8) PackInit_IntegerType<PetscInt,8,1>(link); else if (nPetscInt%8 == 0) PackInit_IntegerType<PetscInt,8,0>(link);
503     else if (nPetscInt == 4) PackInit_IntegerType<PetscInt,4,1>(link); else if (nPetscInt%4 == 0) PackInit_IntegerType<PetscInt,4,0>(link);
504     else if (nPetscInt == 2) PackInit_IntegerType<PetscInt,2,1>(link); else if (nPetscInt%2 == 0) PackInit_IntegerType<PetscInt,2,0>(link);
505     else if (nPetscInt == 1) PackInit_IntegerType<PetscInt,1,1>(link); else if (nPetscInt%1 == 0) PackInit_IntegerType<PetscInt,1,0>(link);
506 #if defined(PETSC_USE_64BIT_INDICES)
507   } else if (nInt) {
508     if      (nInt == 8) PackInit_IntegerType<int,8,1>(link); else if (nInt%8 == 0) PackInit_IntegerType<int,8,0>(link);
509     else if (nInt == 4) PackInit_IntegerType<int,4,1>(link); else if (nInt%4 == 0) PackInit_IntegerType<int,4,0>(link);
510     else if (nInt == 2) PackInit_IntegerType<int,2,1>(link); else if (nInt%2 == 0) PackInit_IntegerType<int,2,0>(link);
511     else if (nInt == 1) PackInit_IntegerType<int,1,1>(link); else if (nInt%1 == 0) PackInit_IntegerType<int,1,0>(link);
512 #endif
513   } else if (nSignedChar) {
514     if      (nSignedChar == 8) PackInit_IntegerType<char,8,1>(link); else if (nSignedChar%8 == 0) PackInit_IntegerType<char,8,0>(link);
515     else if (nSignedChar == 4) PackInit_IntegerType<char,4,1>(link); else if (nSignedChar%4 == 0) PackInit_IntegerType<char,4,0>(link);
516     else if (nSignedChar == 2) PackInit_IntegerType<char,2,1>(link); else if (nSignedChar%2 == 0) PackInit_IntegerType<char,2,0>(link);
517     else if (nSignedChar == 1) PackInit_IntegerType<char,1,1>(link); else if (nSignedChar%1 == 0) PackInit_IntegerType<char,1,0>(link);
518   }  else if (nUnsignedChar) {
519     if      (nUnsignedChar == 8) PackInit_IntegerType<unsigned char,8,1>(link); else if (nUnsignedChar%8 == 0) PackInit_IntegerType<unsigned char,8,0>(link);
520     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char,4,1>(link); else if (nUnsignedChar%4 == 0) PackInit_IntegerType<unsigned char,4,0>(link);
521     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char,2,1>(link); else if (nUnsignedChar%2 == 0) PackInit_IntegerType<unsigned char,2,0>(link);
522     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char,1,1>(link); else if (nUnsignedChar%1 == 0) PackInit_IntegerType<unsigned char,1,0>(link);
523 #if defined(PETSC_HAVE_COMPLEX)
524   } else if (nPetscComplex) {
525     if      (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,1>(link); else if (nPetscComplex%8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,0>(link);
526     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,1>(link); else if (nPetscComplex%4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,0>(link);
527     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,1>(link); else if (nPetscComplex%2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,0>(link);
528     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,1>(link); else if (nPetscComplex%1 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,0>(link);
529 #endif
530   } else {
531     MPI_Aint lb,nbyte;
532     ierr = MPI_Type_get_extent(unit,&lb,&nbyte);CHKERRQ(ierr);
533     if (lb != 0) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_SUP,"Datatype with nonzero lower bound %ld\n",(long)lb);
534     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
535       if      (nbyte == 4) PackInit_DumbType<char,4,1>(link); else if (nbyte%4 == 0) PackInit_DumbType<char,4,0>(link);
536       else if (nbyte == 2) PackInit_DumbType<char,2,1>(link); else if (nbyte%2 == 0) PackInit_DumbType<char,2,0>(link);
537       else if (nbyte == 1) PackInit_DumbType<char,1,1>(link); else if (nbyte%1 == 0) PackInit_DumbType<char,1,0>(link);
538     } else {
539       nInt = nbyte / sizeof(int);
540       if      (nInt == 8) PackInit_DumbType<int,8,1>(link); else if (nInt%8 == 0) PackInit_DumbType<int,8,0>(link);
541       else if (nInt == 4) PackInit_DumbType<int,4,1>(link); else if (nInt%4 == 0) PackInit_DumbType<int,4,0>(link);
542       else if (nInt == 2) PackInit_DumbType<int,2,1>(link); else if (nInt%2 == 0) PackInit_DumbType<int,2,0>(link);
543       else if (nInt == 1) PackInit_DumbType<int,1,1>(link); else if (nInt%1 == 0) PackInit_DumbType<int,1,0>(link);
544     }
545   }
546 
547   if (!sf->use_default_stream) {
548    #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
549     SETERRQ(PETSC_COMM_SELF,PETSC_ERR_SUP,"Non-default cuda/hip streams are not supported by the SF Kokkos backend. If it is cuda, use -sf_backend cuda instead");
550    #endif
551   }
552 
553   link->d_SyncDevice = PetscSFLinkSyncDevice_Kokkos;
554   link->d_SyncStream = PetscSFLinkSyncStream_Kokkos;
555   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
556   link->spptr        = NULL; /* Unused now */
557   link->Destroy      = NULL; /* PetscSFLinkDestroy_Kokkos; */
558   link->deviceinited = PETSC_TRUE;
559   PetscFunctionReturn(0);
560 }
561