xref: /petsc/src/vec/is/sf/impls/basic/nvshmem/sfnvshmem.cu (revision 7d5fd1e4d9337468ad3f05b65b7facdcd2dfd2a4)
1 #include <petsc/private/cudavecimpl.h>
2 #include <../src/vec/is/sf/impls/basic/sfpack.h>
3 #include <mpi.h>
4 #include <nvshmem.h>
5 #include <nvshmemx.h>
6 
7 PetscErrorCode PetscNvshmemInitializeCheck(void)
8 {
9   PetscErrorCode   ierr;
10 
11   PetscFunctionBegin;
12   if (!PetscNvshmemInitialized) { /* Note NVSHMEM does not provide a routine to check whether it is initialized */
13     nvshmemx_init_attr_t attr;
14     attr.mpi_comm = &PETSC_COMM_WORLD;
15     ierr = PetscCUDAInitializeCheck();CHKERRQ(ierr);
16     ierr = nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM,&attr);CHKERRQ(ierr);
17     PetscNvshmemInitialized = PETSC_TRUE;
18     PetscBeganNvshmem       = PETSC_TRUE;
19   }
20   PetscFunctionReturn(0);
21 }
22 
23 PetscErrorCode PetscNvshmemMalloc(size_t size, void** ptr)
24 {
25   PetscErrorCode ierr;
26 
27   PetscFunctionBegin;
28   ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
29   *ptr = nvshmem_malloc(size);
30   if (!*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_malloc() failed to allocate %zu bytes",size);
31   PetscFunctionReturn(0);
32 }
33 
34 PetscErrorCode PetscNvshmemCalloc(size_t size, void**ptr)
35 {
36   PetscErrorCode ierr;
37 
38   PetscFunctionBegin;
39   ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
40   *ptr = nvshmem_calloc(size,1);
41   if (!*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_calloc() failed to allocate %zu bytes",size);
42   PetscFunctionReturn(0);
43 }
44 
45 PetscErrorCode PetscNvshmemFree_Private(void* ptr)
46 {
47   PetscFunctionBegin;
48   nvshmem_free(ptr);
49   PetscFunctionReturn(0);
50 }
51 
52 PetscErrorCode PetscNvshmemFinalize(void)
53 {
54   PetscFunctionBegin;
55   nvshmem_finalize();
56   PetscFunctionReturn(0);
57 }
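
/* Illustrative only (not compiled): a minimal usage sketch of the allocation helpers above.
   PetscNvshmemSketchAllocate is a hypothetical name; the helpers follow the usual ierr/CHKERRQ
   pattern of this file. Note that nvshmem_malloc()/nvshmem_free() are collective over all PEs,
   so every rank must make matching calls. */
#if 0
static PetscErrorCode PetscNvshmemSketchAllocate(size_t nbytes)
{
  PetscErrorCode ierr;
  void           *symbuf = NULL;

  PetscFunctionBegin;
  ierr = PetscNvshmemMalloc(nbytes,&symbuf);CHKERRQ(ierr); /* first use also initializes NVSHMEM via PetscNvshmemInitializeCheck() */
  /* ... launch kernels that read/write symbuf, or pass it to nvshmem put/get calls ... */
  ierr = PetscNvshmemFree(symbuf);CHKERRQ(ierr);           /* return the symmetric allocation */
  PetscFunctionReturn(0);
}
#endif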
58 
59 /* Free nvshmem related fields in the SF */
60 PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
61 {
62   PetscErrorCode    ierr;
63   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
64 
65   PetscFunctionBegin;
66   ierr = PetscFree2(bas->leafsigdisp,bas->leafbufdisp);CHKERRQ(ierr);
67   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->leafbufdisp_d);CHKERRQ(ierr);
68   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->leafsigdisp_d);CHKERRQ(ierr);
69   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->iranks_d);CHKERRQ(ierr);
70   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->ioffset_d);CHKERRQ(ierr);
71 
72   ierr = PetscFree2(sf->rootsigdisp,sf->rootbufdisp);CHKERRQ(ierr);
73   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootbufdisp_d);CHKERRQ(ierr);
74   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootsigdisp_d);CHKERRQ(ierr);
75   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->ranks_d);CHKERRQ(ierr);
76   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->roffset_d);CHKERRQ(ierr);
77   PetscFunctionReturn(0);
78 }
79 
80 /* Set up NVSHMEM-related fields for an SF of type SFBASIC (only after PetscSFSetUp_Basic() has already set up the dependent fields) */
81 static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
82 {
83   PetscErrorCode ierr;
84   cudaError_t    cerr;
85   PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
86   PetscInt       i,nRemoteRootRanks,nRemoteLeafRanks;
87   PetscMPIInt    tag;
88   MPI_Comm       comm;
89   MPI_Request    *rootreqs,*leafreqs;
90   PetscInt       tmp,stmp[4],rtmp[4]; /* tmps for send/recv buffers */
91 
92   PetscFunctionBegin;
93   ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
94   ierr = PetscObjectGetNewTag((PetscObject)sf,&tag);CHKERRQ(ierr);
95 
96   nRemoteRootRanks      = sf->nranks-sf->ndranks;
97   nRemoteLeafRanks      = bas->niranks-bas->ndiranks;
98   sf->nRemoteRootRanks  = nRemoteRootRanks;
99   bas->nRemoteLeafRanks = nRemoteLeafRanks;
100 
101   ierr = PetscMalloc2(nRemoteLeafRanks,&rootreqs,nRemoteRootRanks,&leafreqs);CHKERRQ(ierr);
102 
103   stmp[0] = nRemoteRootRanks;
104   stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
105   stmp[2] = nRemoteLeafRanks;
106   stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];
107 
108   ierr = MPIU_Allreduce(stmp,rtmp,4,MPIU_INT,MPI_MAX,comm);CHKERRMPI(ierr);
109 
110   sf->nRemoteRootRanksMax   = rtmp[0];
111   sf->leafbuflen_rmax       = rtmp[1];
112   bas->nRemoteLeafRanksMax  = rtmp[2];
113   bas->rootbuflen_rmax      = rtmp[3];
114 
115   /* In total, four rounds of MPI communication are needed to set up the nvshmem fields */
116 
117   /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
118   ierr = PetscMalloc2(nRemoteRootRanks,&sf->rootsigdisp,nRemoteRootRanks,&sf->rootbufdisp);CHKERRQ(ierr);
119   for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Irecv(&sf->rootsigdisp[i],1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm,&leafreqs[i]);CHKERRMPI(ierr);} /* Leaves recv */
120   for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Send(&i,1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm);CHKERRMPI(ierr);} /* Roots send. Since the send buffer &i changes each iteration, we use blocking MPI_Send */
121   ierr = MPI_Waitall(nRemoteRootRanks,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
122 
123   for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Irecv(&sf->rootbufdisp[i],1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm,&leafreqs[i]);CHKERRMPI(ierr);} /* Leaves recv */
124   for (i=0; i<nRemoteLeafRanks; i++) {
125     tmp  = bas->ioffset[i+bas->ndiranks] - bas->ioffset[bas->ndiranks];
126     ierr = MPI_Send(&tmp,1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm);CHKERRMPI(ierr);  /* Roots send. Note tmp changes, so we use MPI_Send. */
127   }
128   ierr = MPI_Waitall(nRemoteRootRanks,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
129 
130   cerr = cudaMalloc((void**)&sf->rootbufdisp_d,nRemoteRootRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
131   cerr = cudaMalloc((void**)&sf->rootsigdisp_d,nRemoteRootRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
132   cerr = cudaMalloc((void**)&sf->ranks_d,nRemoteRootRanks*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
133   cerr = cudaMalloc((void**)&sf->roffset_d,(nRemoteRootRanks+1)*sizeof(PetscInt));CHKERRCUDA(cerr);
134 
135   cerr = cudaMemcpyAsync(sf->rootbufdisp_d,sf->rootbufdisp,nRemoteRootRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
136   cerr = cudaMemcpyAsync(sf->rootsigdisp_d,sf->rootsigdisp,nRemoteRootRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
137   cerr = cudaMemcpyAsync(sf->ranks_d,sf->ranks+sf->ndranks,nRemoteRootRanks*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
138   cerr = cudaMemcpyAsync(sf->roffset_d,sf->roffset+sf->ndranks,(nRemoteRootRanks+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
139 
140   /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
141   ierr = PetscMalloc2(nRemoteLeafRanks,&bas->leafsigdisp,nRemoteLeafRanks,&bas->leafbufdisp);CHKERRQ(ierr);
142   for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Irecv(&bas->leafsigdisp[i],1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm,&rootreqs[i]);CHKERRMPI(ierr);}
143   for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Send(&i,1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm);CHKERRMPI(ierr);}
144   ierr = MPI_Waitall(nRemoteLeafRanks,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
145 
146   for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Irecv(&bas->leafbufdisp[i],1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm,&rootreqs[i]);CHKERRMPI(ierr);}
147   for (i=0; i<nRemoteRootRanks; i++) {
148     tmp  = sf->roffset[i+sf->ndranks] - sf->roffset[sf->ndranks];
149     ierr = MPI_Send(&tmp,1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm);CHKERRMPI(ierr);
150   }
151   ierr = MPI_Waitall(nRemoteLeafRanks,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
152 
153   cerr = cudaMalloc((void**)&bas->leafbufdisp_d,nRemoteLeafRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
154   cerr = cudaMalloc((void**)&bas->leafsigdisp_d,nRemoteLeafRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
155   cerr = cudaMalloc((void**)&bas->iranks_d,nRemoteLeafRanks*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
156   cerr = cudaMalloc((void**)&bas->ioffset_d,(nRemoteLeafRanks+1)*sizeof(PetscInt));CHKERRCUDA(cerr);
157 
158   cerr = cudaMemcpyAsync(bas->leafbufdisp_d,bas->leafbufdisp,nRemoteLeafRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
159   cerr = cudaMemcpyAsync(bas->leafsigdisp_d,bas->leafsigdisp,nRemoteLeafRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
160   cerr = cudaMemcpyAsync(bas->iranks_d,bas->iranks+bas->ndiranks,nRemoteLeafRanks*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
161   cerr = cudaMemcpyAsync(bas->ioffset_d,bas->ioffset+bas->ndiranks,(nRemoteLeafRanks+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
162 
163   ierr = PetscFree2(rootreqs,leafreqs);CHKERRQ(ierr);
164   PetscFunctionReturn(0);
165 }
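
/* Illustrative only (not compiled): the displacement exchange above follows a simple
   point-to-point pattern, sketched here in isolation with hypothetical names (nneigh,
   neighbors, mydisp, theirdisp). Each rank posts one Irecv per neighbor and then sends
   each neighbor the displacement that this neighbor should use on this rank. */
#if 0
static PetscErrorCode PetscSFSketchExchangeDisp(MPI_Comm comm,PetscMPIInt tag,PetscInt nneigh,const PetscMPIInt *neighbors,const PetscInt *mydisp,PetscInt *theirdisp)
{
  PetscErrorCode ierr;
  MPI_Request    *reqs;
  PetscInt       i;

  PetscFunctionBegin;
  ierr = PetscMalloc1(nneigh,&reqs);CHKERRQ(ierr);
  for (i=0; i<nneigh; i++) {ierr = MPI_Irecv(&theirdisp[i],1,MPIU_INT,neighbors[i],tag,comm,&reqs[i]);CHKERRMPI(ierr);}
  for (i=0; i<nneigh; i++) {ierr = MPI_Send((void*)&mydisp[i],1,MPIU_INT,neighbors[i],tag,comm);CHKERRMPI(ierr);} /* blocking send since the buffer differs per iteration */
  ierr = MPI_Waitall(nneigh,reqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
  ierr = PetscFree(reqs);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
#endif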
166 
167 PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,PetscBool *use_nvshmem)
168 {
169   PetscErrorCode   ierr;
170   MPI_Comm         comm;
171   PetscBool        isBasic;
172   PetscMPIInt      result = MPI_UNEQUAL;
173 
174   PetscFunctionBegin;
175   ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
176   /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
177      Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
178   */
180   if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
181     /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
182     ierr = PetscObjectTypeCompare((PetscObject)sf,PETSCSFBASIC,&isBasic);CHKERRQ(ierr);
183     if (isBasic) {ierr = MPI_Comm_compare(PETSC_COMM_WORLD,comm,&result);CHKERRMPI(ierr);}
184     if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */
185 
186     /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
187        and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
188        inconsistency in the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
189     */
190     if (sf->use_nvshmem) {
191       PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
192       ierr = MPI_Allreduce(MPI_IN_PLACE,&hasNullRank,1,MPIU_INT,MPI_LOR,comm);CHKERRMPI(ierr);
193       if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
194     }
195     sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
196   }
197 
198   /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
199   if (sf->use_nvshmem) {
200     PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
201     PetscInt allCuda = oneCuda; /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
202    #if defined(PETSC_USE_DEBUG)  /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
203     ierr = MPI_Allreduce(&oneCuda,&allCuda,1,MPIU_INT,MPI_LAND,comm);CHKERRMPI(ierr);
204     if (allCuda != oneCuda) SETERRQ(comm,PETSC_ERR_SUP,"root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
205    #endif
206     if (allCuda) {
207       ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
208       if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
209         ierr = PetscSFSetUp_Basic_NVSHMEM(sf);CHKERRQ(ierr);
210         sf->setup_nvshmem = PETSC_TRUE;
211       }
212       *use_nvshmem = PETSC_TRUE;
213     } else {
214       *use_nvshmem = PETSC_FALSE;
215     }
216   } else {
217     *use_nvshmem = PETSC_FALSE;
218   }
219   PetscFunctionReturn(0);
220 }
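
/* Illustrative only (not compiled): a sketch of how a caller might consult the check above
   before choosing a communication path. PetscSFSketchChoosePath is a hypothetical name; the
   parameters mirror those of PetscSFLinkNvshmemCheck(). */
#if 0
static PetscErrorCode PetscSFSketchChoosePath(PetscSF sf,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata)
{
  PetscErrorCode ierr;
  PetscBool      use_nvshmem;

  PetscFunctionBegin;
  ierr = PetscSFLinkNvshmemCheck(sf,rootmtype,rootdata,leafmtype,leafdata,&use_nvshmem);CHKERRQ(ierr);
  if (use_nvshmem) {
    /* create/reuse an NVSHMEM-capable link, e.g. via PetscSFLinkCreate_NVSHMEM() below */
  } else {
    /* fall back to the regular MPI path */
  }
  PetscFunctionReturn(0);
}
#endif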
221 
222 /* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication */
223 static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
224 {
225   cudaError_t    cerr;
226   PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
227   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF)? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE];
228 
229   PetscFunctionBegin;
230   if (buflen) {
231     cerr = cudaEventRecord(link->dataReady,link->stream);CHKERRCUDA(cerr);
232     cerr = cudaStreamWaitEvent(link->remoteCommStream,link->dataReady,0);CHKERRCUDA(cerr);
233   }
234   PetscFunctionReturn(0);
235 }
236 
237 /* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication */
238 static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
239 {
240   cudaError_t    cerr;
241   PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
242   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF)? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE];
243 
244   PetscFunctionBegin;
245   /* If unpacking into a non-empty device buffer, build the endRemoteComm dependence */
246   if (buflen) {
247     cerr = cudaEventRecord(link->endRemoteComm,link->remoteCommStream);CHKERRCUDA(cerr);
248     cerr = cudaStreamWaitEvent(link->stream,link->endRemoteComm,0);CHKERRCUDA(cerr);
249   }
250   PetscFunctionReturn(0);
251 }
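
/* Illustrative only (not compiled): the two helpers above bracket the NVSHMEM traffic so that
   work queued on link->remoteCommStream starts only after the pack on link->stream, and work
   queued on link->stream afterwards starts only after the remote traffic. A conceptual sketch
   of the intended ordering (hypothetical wrapper name): */
#if 0
static PetscErrorCode PetscSFSketchRemoteComm(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr); /* remoteCommStream waits on dataReady recorded on link->stream */
  /* ... queue signal/put/get kernels and nvshmemx_*_on_stream() calls on link->remoteCommStream ... */
  ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);   /* link->stream waits on endRemoteComm recorded on remoteCommStream */
  PetscFunctionReturn(0);
}
#endif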
252 
253 /* Send/Put signals to remote ranks
254 
255  Input parameters:
256   + n        - Number of remote ranks
257   . sig      - Signal address in symmetric heap
258   . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
259   . ranks    - remote ranks
260   - newval   - Set signals to this value
261 */
262 __global__ static void NvshmemSendSignals(PetscInt n,uint64_t *sig,PetscInt *sigdisp,PetscMPIInt *ranks,uint64_t newval)
263 {
264   int i = blockIdx.x*blockDim.x + threadIdx.x;
265 
266   /* Each thread puts one remote signal */
267   if (i < n) nvshmemx_uint64_signal(sig+sigdisp[i],newval,ranks[i]);
268 }
269 
270 /* Wait until the local signals equal the expected value, then set them to a new value
271 
272  Input parameters:
273   + n        - Number of signals
274   . sig      - Local signal address
275   . expval   - expected value
276   - newval   - Set signals to this new value
277 */
278 __global__ static void NvshmemWaitSignals(PetscInt n,uint64_t *sig,uint64_t expval,uint64_t newval)
279 {
280 #if 0
281   /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
282   int i = blockIdx.x*blockDim.x + threadIdx.x;
283   if (i < n) {
284     nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
285     sig[i] = newval;
286   }
287 #else
288   nvshmem_uint64_wait_until_all(sig,n,NULL/*no mask*/,NVSHMEM_CMP_EQ,expval);
289   for (int i=0; i<n; i++) sig[i] = newval;
290 #endif
291 }
292 
293 /* ===========================================================================================================
294 
295    A set of routines to support receiver initiated communication using the get method
296 
297     The getting protocol is:
298 
299     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
300     All signal variables have an initial value 0.
301 
302     Sender:                                 |  Receiver:
303   1.  Wait ssig be 0, then set it to 1
304   2.  Pack data into stand-alone sbuf       |
305   3.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it to 0
306                                             |   2. Get data from remote sbuf to local rbuf
307                                             |   3. Put 0 to sender's ssig
308                                             |   4. Unpack data from local rbuf
309    ===========================================================================================================*/
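
/* Illustrative only (not compiled): the receiver side of the getting protocol above, written as a
   single-thread device sketch with hypothetical names (nsrc, rsig, ssig, ssigdisp, srcpes, ...).
   The real implementation below splits these steps across PetscSFLinkGetDataBegin/End_NVSHMEM and
   handles locally accessible PEs separately with the stream-based host API. */
#if 0
__global__ static void GetProtocolReceiverSketch(PetscInt nsrc,uint64_t *rsig,uint64_t *ssig,PetscInt *ssigdisp,PetscMPIInt *srcpes,char *rbuf,const char *sbuf,PetscInt *sbufdisp,PetscInt *rbufdisp,PetscInt unitbytes)
{
  nvshmem_uint64_wait_until_all(rsig,nsrc,NULL,NVSHMEM_CMP_EQ,1);                  /* 1. wait for permission from every sender */
  for (int i=0; i<nsrc; i++) rsig[i] = 0;
  for (int i=0; i<nsrc; i++) {                                                     /* 2. get each sender's chunk into the local recv buf */
    PetscInt nbytes = (rbufdisp[i+1]-rbufdisp[i])*unitbytes;
    nvshmem_getmem_nbi(rbuf+(rbufdisp[i]-rbufdisp[0])*unitbytes,sbuf+sbufdisp[i]*unitbytes,nbytes,srcpes[i]);
  }
  nvshmem_quiet();                                                                 /* finish the non-blocking gets */
  for (int i=0; i<nsrc; i++) nvshmemx_uint64_signal(ssig+ssigdisp[i],0,srcpes[i]); /* 3. tell each sender its sbuf is free again */
  /* 4. the unpack from rbuf is a separate kernel launched afterwards */
}
#endif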
310 /* PrePack operation -- since sender will overwrite the send buffer which the receiver might be getting data from.
311    Sender waits for signals (from receivers) indicating receivers have finished getting data
312 */
313 PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
314 {
315   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
316   uint64_t          *sig;
317   PetscInt          n;
318 
319   PetscFunctionBegin;
320   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
321     sig = link->rootSendSig;            /* leaf ranks set my rootSendSig */
322     n   = bas->nRemoteLeafRanks;
323   } else { /* LEAF2ROOT */
324     sig = link->leafSendSig;
325     n   = sf->nRemoteRootRanks;
326   }
327 
328   if (n) {
329     NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(n,sig,0,1); /* wait for the signals to be 0, then set them to 1 */
330     cudaError_t cerr = cudaGetLastError();CHKERRCUDA(cerr);
331   }
332   PetscFunctionReturn(0);
333 }
334 
335 /* n thread blocks; each takes charge of one remote rank */
336 __global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks,PetscMPIInt *srcranks,const char *src,PetscInt *srcdisp,char *dst,PetscInt *dstdisp,PetscInt unitbytes)
337 {
338   int               bid = blockIdx.x;
339   PetscMPIInt       pe  = srcranks[bid];
340 
341   if (!nvshmem_ptr(src,pe)) {
342     PetscInt nelems = (dstdisp[bid+1]-dstdisp[bid])*unitbytes;
343     nvshmem_getmem_nbi(dst+(dstdisp[bid]-dstdisp[0])*unitbytes,src+srcdisp[bid]*unitbytes,nelems,pe);
344   }
345 }
346 
347 /* Start communication -- Get data in the given direction */
348 PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
349 {
350   PetscErrorCode    ierr;
351   cudaError_t       cerr;
352   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
353 
354   PetscInt          nsrcranks,ndstranks,nLocallyAccessible = 0;
355 
356   char              *src,*dst;
357   PetscInt          *srcdisp_h,*dstdisp_h;
358   PetscInt          *srcdisp_d,*dstdisp_d;
359   PetscMPIInt       *srcranks_h;
360   PetscMPIInt       *srcranks_d,*dstranks_d;
361   uint64_t          *dstsig;
362   PetscInt          *dstsigdisp_d;
363 
364   PetscFunctionBegin;
365   ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
366   if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
367     nsrcranks    = sf->nRemoteRootRanks;
368     src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */
369 
370     srcdisp_h    = sf->rootbufdisp;       /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
371     srcdisp_d    = sf->rootbufdisp_d;
372     srcranks_h   = sf->ranks+sf->ndranks; /* my (remote) root ranks */
373     srcranks_d   = sf->ranks_d;
374 
375     ndstranks    = bas->nRemoteLeafRanks;
376     dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */
377 
378     dstdisp_h    = sf->roffset+sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
379     dstdisp_d    = sf->roffset_d;
380     dstranks_d   = bas->iranks_d; /* my (remote) leaf ranks */
381 
382     dstsig       = link->leafRecvSig;
383     dstsigdisp_d = bas->leafsigdisp_d;
384   } else { /* src is leaf, dst is root; we will move data from src to dst */
385     nsrcranks    = bas->nRemoteLeafRanks;
386     src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */
387 
388     srcdisp_h    = bas->leafbufdisp;       /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
389     srcdisp_d    = bas->leafbufdisp_d;
390     srcranks_h   = bas->iranks+bas->ndiranks; /* my (remote) leaf ranks */
391     srcranks_d   = bas->iranks_d;
392 
393     ndstranks    = sf->nRemoteRootRanks;
394     dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */
395 
396     dstdisp_h    = bas->ioffset+bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
397     dstdisp_d    = bas->ioffset_d;
398     dstranks_d   = sf->ranks_d; /* my (remote) root ranks */
399 
400     dstsig       = link->rootRecvSig;
401     dstsigdisp_d = sf->rootsigdisp_d;
402   }
403 
404   /* After Pack operation -- src tells dst ranks that they are allowed to get data */
405   if (ndstranks) {
406     NvshmemSendSignals<<<(ndstranks+255)/256,256,0,link->remoteCommStream>>>(ndstranks,dstsig,dstsigdisp_d,dstranks_d,1); /* set signals to 1 */
407     cerr = cudaGetLastError();CHKERRCUDA(cerr);
408   }
409 
410   /* dst waits for signals (permissions) from src ranks to start getting data */
411   if (nsrcranks) {
412     NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(nsrcranks,dstsig,1,0); /* wait for the signals to be 1, then set them to 0 */
413     cerr = cudaGetLastError();CHKERRCUDA(cerr);
414   }
415 
416   /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */
417 
418   /* Count number of locally accessible src ranks, which should be a small number */
419   for (int i=0; i<nsrcranks; i++) {if (nvshmem_ptr(src,srcranks_h[i])) nLocallyAccessible++;}
420 
421   /* Get data from remotely accessible PEs */
422   if (nLocallyAccessible < nsrcranks) {
423     GetDataFromRemotelyAccessible<<<nsrcranks,1,0,link->remoteCommStream>>>(nsrcranks,srcranks_d,src,srcdisp_d,dst,dstdisp_d,link->unitbytes);
424     cerr = cudaGetLastError();CHKERRCUDA(cerr);
425   }
426 
427   /* Get data from locally accessible PEs */
428   if (nLocallyAccessible) {
429     for (int i=0; i<nsrcranks; i++) {
430       int pe = srcranks_h[i];
431       if (nvshmem_ptr(src,pe)) {
432         size_t nelems = (dstdisp_h[i+1]-dstdisp_h[i])*link->unitbytes;
433         nvshmemx_getmem_nbi_on_stream(dst+(dstdisp_h[i]-dstdisp_h[0])*link->unitbytes,src+srcdisp_h[i]*link->unitbytes,nelems,pe,link->remoteCommStream);
434       }
435     }
436   }
437   PetscFunctionReturn(0);
438 }
439 
440 /* Finish the communication (can be done before Unpack)
441    Receiver tells its senders that they are allowed to reuse their send buffer (since receiver has got data from their send buffer)
442 */
443 PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
444 {
445   PetscErrorCode    ierr;
446   cudaError_t       cerr;
447   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
448   uint64_t          *srcsig;
449   PetscInt          nsrcranks,*srcsigdisp;
450   PetscMPIInt       *srcranks;
451 
452   PetscFunctionBegin;
453   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
454     nsrcranks   = sf->nRemoteRootRanks;
455     srcsig      = link->rootSendSig;     /* I want to set their root signal */
456     srcsigdisp  = sf->rootsigdisp_d;     /* offset of each root signal */
457     srcranks    = sf->ranks_d;           /* ranks of the n root ranks */
458   } else { /* LEAF2ROOT, root ranks are getting data */
459     nsrcranks   = bas->nRemoteLeafRanks;
460     srcsig      = link->leafSendSig;
461     srcsigdisp  = bas->leafsigdisp_d;
462     srcranks    = bas->iranks_d;
463   }
464 
465   if (nsrcranks) {
466     nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
467     cerr = cudaGetLastError();CHKERRCUDA(cerr);
468     NvshmemSendSignals<<<(nsrcranks+511)/512,512,0,link->remoteCommStream>>>(nsrcranks,srcsig,srcsigdisp,srcranks,0); /* set signals to 0 */
469     cerr = cudaGetLastError();CHKERRCUDA(cerr);
470   }
471   ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
472   PetscFunctionReturn(0);
473 }
474 
475 /* ===========================================================================================================
476 
477    A set of routines to support sender initiated communication using the put-based method (the default)
478 
479     The putting protocol is:
480 
481     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
482     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
483     is in nvshmem space.
484 
485     Sender:                                 |  Receiver:
486                                             |
487   1.  Pack data into sbuf                   |
488   2.  Wait ssig be 0, then set it to 1      |
489   3.  Put data to remote stand-alone rbuf   |
490   4.  Fence // make sure 5 happens after 3  |
491   5.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it to 0
492                                             |   2. Unpack data from local rbuf
493                                             |   3. Put 0 to sender's ssig
494    ===========================================================================================================*/
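
/* Illustrative only (not compiled): the sender side of the putting protocol above, as a
   single-thread device sketch with hypothetical names (ndst, ssig, rsig, rsigdisp, dstpes, ...).
   The real implementation below instead uses the stream-based host API for locally accessible
   PEs, which goes through the CUDA copy engines and is faster. */
#if 0
__global__ static void PutProtocolSenderSketch(PetscInt ndst,uint64_t *ssig,uint64_t *rsig,PetscInt *rsigdisp,PetscMPIInt *dstpes,char *rbuf,const char *sbuf,PetscInt *sbufdisp,PetscInt *rbufdisp,PetscInt unitbytes)
{
  for (int i=0; i<ndst; i++) {                                                     /* 1. data was already packed into sbuf */
    PetscInt nbytes = (sbufdisp[i+1]-sbufdisp[i])*unitbytes;
    nvshmem_uint64_wait_until(ssig+i,NVSHMEM_CMP_EQ,0);                            /* 2. wait until the receiver has freed its rbuf */
    ssig[i] = 1;
    nvshmem_putmem_nbi(rbuf+rbufdisp[i]*unitbytes,sbuf+(sbufdisp[i]-sbufdisp[0])*unitbytes,nbytes,dstpes[i]); /* 3. put the data */
  }
  nvshmem_fence();                                                                 /* 4. order the signal puts after the data puts */
  for (int i=0; i<ndst; i++) nvshmemx_uint64_signal(rsig+rsigdisp[i],1,dstpes[i]); /* 5. notify each receiver */
}
#endif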
495 
496 /* n thread blocks; each takes charge of one remote rank */
497 __global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,char *dst,PetscInt *dstdisp,const char *src,PetscInt *srcdisp,uint64_t *srcsig,PetscInt unitbytes)
498 {
499   int               bid = blockIdx.x;
500   PetscMPIInt       pe  = dstranks[bid];
501 
502   if (!nvshmem_ptr(dst,pe)) {
503     PetscInt nelems = (srcdisp[bid+1]-srcdisp[bid])*unitbytes;
504     nvshmem_uint64_wait_until(srcsig+bid,NVSHMEM_CMP_EQ,0); /* Wait until the sig = 0 */
505     srcsig[bid] = 1;
506     nvshmem_putmem_nbi(dst+dstdisp[bid]*unitbytes,src+(srcdisp[bid]-srcdisp[0])*unitbytes,nelems,pe);
507   }
508 }
509 
510 /* A one-thread kernel that takes charge of all locally accessible PEs */
511 __global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *srcsig,const char *dst)
512 {
513   for (int i=0; i<ndstranks; i++) {
514     int pe = dstranks[i];
515     if (nvshmem_ptr(dst,pe)) {
516       nvshmem_uint64_wait_until(srcsig+i,NVSHMEM_CMP_EQ,0); /* Wait until the sig = 0 */
517       srcsig[i] = 1;
518     }
519   }
520 }
521 
522 /* Put data in the given direction  */
523 PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
524 {
525   PetscErrorCode    ierr;
526   cudaError_t       cerr;
527   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
528   PetscInt          ndstranks,nLocallyAccessible = 0;
529   char              *src,*dst;
530   PetscInt          *srcdisp_h,*dstdisp_h;
531   PetscInt          *srcdisp_d,*dstdisp_d;
532   PetscMPIInt       *dstranks_h;
533   PetscMPIInt       *dstranks_d;
534   uint64_t          *srcsig;
535 
536   PetscFunctionBegin;
537   ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
538   if (direction == PETSCSF_ROOT2LEAF) { /* put data in rootbuf to leafbuf  */
539     ndstranks    = bas->nRemoteLeafRanks; /* number of (remote) leaf ranks */
540     src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
541     dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
542 
543     srcdisp_h    = bas->ioffset+bas->ndiranks;  /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
544     srcdisp_d    = bas->ioffset_d;
545     srcsig       = link->rootSendSig;
546 
547     dstdisp_h    = bas->leafbufdisp;            /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
548     dstdisp_d    = bas->leafbufdisp_d;
549     dstranks_h   = bas->iranks+bas->ndiranks;   /* remote leaf ranks */
550     dstranks_d   = bas->iranks_d;
551   } else { /* put data in leafbuf to rootbuf */
552     ndstranks    = sf->nRemoteRootRanks;
553     src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
554     dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
555 
556     srcdisp_h    = sf->roffset+sf->ndranks; /* offsets of leafbuf */
557     srcdisp_d    = sf->roffset_d;
558     srcsig       = link->leafSendSig;
559 
560     dstdisp_h    = sf->rootbufdisp;         /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
561     dstdisp_d    = sf->rootbufdisp_d;
562     dstranks_h   = sf->ranks+sf->ndranks;   /* remote root ranks */
563     dstranks_d   = sf->ranks_d;
564   }
565 
566   /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */
567 
568   /* Count number of locally accessible neighbors, which should be a small number */
569   for (int i=0; i<ndstranks; i++) {if (nvshmem_ptr(dst,dstranks_h[i])) nLocallyAccessible++;}
570 
571   /* For remotely accessible PEs, send data to them in one kernel call */
572   if (nLocallyAccessible < ndstranks) {
573     WaitAndPutDataToRemotelyAccessible<<<ndstranks,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,dst,dstdisp_d,src,srcdisp_d,srcsig,link->unitbytes);
574     cerr = cudaGetLastError();CHKERRCUDA(cerr);
575   }
576 
577   /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
578   if (nLocallyAccessible) {
579     WaitSignalsFromLocallyAccessible<<<1,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,srcsig,dst);
580     for (int i=0; i<ndstranks; i++) {
581       int pe = dstranks_h[i];
582       if (nvshmem_ptr(dst,pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
583         size_t nelems = (srcdisp_h[i+1]-srcdisp_h[i])*link->unitbytes;
584          /* Initiate the nonblocking communication */
585         nvshmemx_putmem_nbi_on_stream(dst+dstdisp_h[i]*link->unitbytes,src+(srcdisp_h[i]-srcdisp_h[0])*link->unitbytes,nelems,pe,link->remoteCommStream);
586       }
587     }
588   }
589 
590   if (nLocallyAccessible) {
591     nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
592   }
593   PetscFunctionReturn(0);
594 }
595 
596 /* A one-thread kernel. The thread takes in charge all remote PEs */
597 __global__ static void PutDataEnd(PetscInt nsrcranks,PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *dstsig,PetscInt *dstsigdisp)
598 {
599   /* TODO: Should we finish the non-blocking remote puts here? */
600 
601   /* 1. Send a signal to each dst rank */
602 
603   /* According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
604      For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
605   */
606   for (int i=0; i<ndstranks; i++) {nvshmemx_uint64_signal(dstsig+dstsigdisp[i],1,dstranks[i]);} /* set sig to 1 */
607 
608   /* 2. Wait for signals from src ranks (if any) */
609   if (nsrcranks) {
610     nvshmem_uint64_wait_until_all(dstsig,nsrcranks,NULL/*no mask*/,NVSHMEM_CMP_EQ,1); /* wait sigs to be 1, then set them to 0 */
611     for (int i=0; i<nsrcranks; i++) dstsig[i] = 0;
612   }
613 }
614 
615 /* Finish the communication -- A receiver waits until it can access its receive buffer */
616 PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
617 {
618   PetscErrorCode    ierr;
619   cudaError_t       cerr;
620   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
621   PetscMPIInt       *dstranks;
622   uint64_t          *dstsig;
623   PetscInt          nsrcranks,ndstranks,*dstsigdisp;
624 
625   PetscFunctionBegin;
626   if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
627     nsrcranks    = sf->nRemoteRootRanks;
628 
629     ndstranks    = bas->nRemoteLeafRanks;
630     dstranks     = bas->iranks_d;       /* leaf ranks */
631     dstsig       = link->leafRecvSig;   /* I will set my leaf ranks's RecvSig */
632     dstsigdisp   = bas->leafsigdisp_d;  /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
633   } else { /* LEAF2ROOT */
634     nsrcranks    = bas->nRemoteLeafRanks;
635 
636     ndstranks    = sf->nRemoteRootRanks;
637     dstranks     = sf->ranks_d;
638     dstsig       = link->rootRecvSig;
639     dstsigdisp   = sf->rootsigdisp_d;
640   }
641 
642   if (nsrcranks || ndstranks) {
643     PutDataEnd<<<1,1,0,link->remoteCommStream>>>(nsrcranks,ndstranks,dstranks,dstsig,dstsigdisp);
644     cerr = cudaGetLastError();CHKERRCUDA(cerr);
645   }
646   ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
647   PetscFunctionReturn(0);
648 }
649 
650 /* PostUnpack operation -- A receiver tells its senders that they are allowed to put data here (i.e., its recv buf is free to take new data) */
651 PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
652 {
653   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
654   uint64_t          *srcsig;
655   PetscInt          nsrcranks,*srcsigdisp_d;
656   PetscMPIInt       *srcranks_d;
657 
658   PetscFunctionBegin;
659   if (direction == PETSCSF_ROOT2LEAF) { /* I allow my root ranks to put data to me */
660     nsrcranks    = sf->nRemoteRootRanks;
661     srcsig       = link->rootSendSig;      /* I want to set their send signals */
662     srcsigdisp_d = sf->rootsigdisp_d;      /* offset of each root signal */
663     srcranks_d   = sf->ranks_d;            /* ranks of the n root ranks */
664   } else { /* LEAF2ROOT */
665     nsrcranks    = bas->nRemoteLeafRanks;
666     srcsig       = link->leafSendSig;
667     srcsigdisp_d = bas->leafsigdisp_d;
668     srcranks_d   = bas->iranks_d;
669   }
670 
671   if (nsrcranks) {
672     NvshmemSendSignals<<<(nsrcranks+255)/256,256,0,link->remoteCommStream>>>(nsrcranks,srcsig,srcsigdisp_d,srcranks_d,0); /* Set remote signals to 0 */
673     cudaError_t cerr = cudaGetLastError();CHKERRCUDA(cerr);
674   }
675   PetscFunctionReturn(0);
676 }
677 
678 /* Destructor when the link uses nvshmem for communication */
679 static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf,PetscSFLink link)
680 {
681   PetscErrorCode    ierr;
682   cudaError_t       cerr;
683 
684   PetscFunctionBegin;
685   cerr = cudaEventDestroy(link->dataReady);CHKERRCUDA(cerr);
686   cerr = cudaEventDestroy(link->endRemoteComm);CHKERRCUDA(cerr);
687   cerr = cudaStreamDestroy(link->remoteCommStream);CHKERRCUDA(cerr);
688 
689   /* nvshmem does not need buffers on host, which should be NULL */
690   ierr = PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
691   ierr = PetscNvshmemFree(link->leafSendSig);CHKERRQ(ierr);
692   ierr = PetscNvshmemFree(link->leafRecvSig);CHKERRQ(ierr);
693   ierr = PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
694   ierr = PetscNvshmemFree(link->rootSendSig);CHKERRQ(ierr);
695   ierr = PetscNvshmemFree(link->rootRecvSig);CHKERRQ(ierr);
696   PetscFunctionReturn(0);
697 }
698 
699 PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf,MPI_Datatype unit,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,MPI_Op op,PetscSFOperation sfop,PetscSFLink *mylink)
700 {
701   PetscErrorCode    ierr;
702   cudaError_t       cerr;
703   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
704   PetscSFLink       *p,link;
705   PetscBool         match,rootdirect[2],leafdirect[2];
706   int               greatestPriority;
707 
708   PetscFunctionBegin;
709   /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
710      We only care about root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers for local communication with NVSHMEM.
711   */
712   if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
713     if (sf->use_nvshmem_get) {
714       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
715       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
716     } else {
717       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
718       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;  /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
719     }
720   } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
721     if (sf->use_nvshmem_get) {
722       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
723       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
724     } else {
725       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
726       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
727     }
728   } else { /* PETSCSF_FETCH */
729     rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always needs a separate rootbuf */
730     leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share MPI requests */
731   }
732 
733   /* Look for free nvshmem links in cache */
734   for (p=&bas->avail; (link=*p); p=&link->next) {
735     if (link->use_nvshmem) {
736       ierr = MPIPetsc_Type_compare(unit,link->unit,&match);CHKERRQ(ierr);
737       if (match) {
738         *p = link->next; /* Remove from available list */
739         goto found;
740       }
741     }
742   }
743   ierr = PetscNew(&link);CHKERRQ(ierr);
744   ierr = PetscSFLinkSetUp_Host(sf,link,unit);CHKERRQ(ierr); /* Compute link->unitbytes, dup link->unit etc. */
745   if (sf->backend == PETSCSF_BACKEND_CUDA) {ierr = PetscSFLinkSetUp_CUDA(sf,link,unit);CHKERRQ(ierr);} /* Setup pack routines, streams etc */
746  #if defined(PETSC_HAVE_KOKKOS)
747   else if (sf->backend == PETSCSF_BACKEND_KOKKOS) {ierr = PetscSFLinkSetUp_Kokkos(sf,link,unit);CHKERRQ(ierr);}
748  #endif
749 
750   link->rootdirect[PETSCSF_LOCAL]  = PETSC_TRUE; /* For the local part we directly use root/leafdata */
751   link->leafdirect[PETSCSF_LOCAL]  = PETSC_TRUE;
752 
753   /* Init signals to zero */
754   if (!link->rootSendSig) {ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootSendSig);CHKERRQ(ierr);}
755   if (!link->rootRecvSig) {ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootRecvSig);CHKERRQ(ierr);}
756   if (!link->leafSendSig) {ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafSendSig);CHKERRQ(ierr);}
757   if (!link->leafRecvSig) {ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafRecvSig);CHKERRQ(ierr);}
758 
759   link->use_nvshmem                = PETSC_TRUE;
760   link->rootmtype                  = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
761   link->leafmtype                  = PETSC_MEMTYPE_DEVICE;
762   /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
763   link->Destroy                    = PetscSFLinkDestroy_NVSHMEM;
764   if (sf->use_nvshmem_get) { /* get-based protocol */
765     link->PrePack                  = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
766     link->StartCommunication       = PetscSFLinkGetDataBegin_NVSHMEM;
767     link->FinishCommunication      = PetscSFLinkGetDataEnd_NVSHMEM;
768   } else { /* put-based protocol */
769     link->StartCommunication       = PetscSFLinkPutDataBegin_NVSHMEM;
770     link->FinishCommunication      = PetscSFLinkPutDataEnd_NVSHMEM;
771     link->PostUnpack               = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
772   }
773 
774   cerr = cudaDeviceGetStreamPriorityRange(NULL,&greatestPriority);CHKERRCUDA(cerr);
775   cerr = cudaStreamCreateWithPriority(&link->remoteCommStream,cudaStreamNonBlocking,greatestPriority);CHKERRCUDA(cerr);
776 
777   cerr = cudaEventCreateWithFlags(&link->dataReady,cudaEventDisableTiming);CHKERRCUDA(cerr);
778   cerr = cudaEventCreateWithFlags(&link->endRemoteComm,cudaEventDisableTiming);CHKERRCUDA(cerr);
779 
780 found:
781   if (rootdirect[PETSCSF_REMOTE]) {
782     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)rootdata + bas->rootstart[PETSCSF_REMOTE]*link->unitbytes;
783   } else {
784     if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
785       ierr = PetscNvshmemMalloc(bas->rootbuflen_rmax*link->unitbytes,(void**)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
786     }
787     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
788   }
789 
790   if (leafdirect[PETSCSF_REMOTE]) {
791     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)leafdata + sf->leafstart[PETSCSF_REMOTE]*link->unitbytes;
792   } else {
793     if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
794       ierr = PetscNvshmemMalloc(sf->leafbuflen_rmax*link->unitbytes,(void**)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
795     }
796     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
797   }
798 
799   link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
800   link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
801   link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
802   link->leafdata                   = leafdata;
803   link->next                       = bas->inuse;
804   bas->inuse                       = link;
805   *mylink                          = link;
806   PetscFunctionReturn(0);
807 }
808 
809 #if defined(PETSC_USE_REAL_SINGLE)
810 PetscErrorCode PetscNvshmemSum(PetscInt count,float *dst,const float *src)
811 {
812   PetscErrorCode    ierr;
813   PetscMPIInt       num; /* Assume nvshmem's int is MPI's int */
814 
815   PetscFunctionBegin;
816   ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
817   nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
818   PetscFunctionReturn(0);
819 }
820 
821 PetscErrorCode PetscNvshmemMax(PetscInt count,float *dst,const float *src)
822 {
823   PetscErrorCode    ierr;
824   PetscMPIInt       num;
825 
826   PetscFunctionBegin;
827   ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
828   nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
829   PetscFunctionReturn(0);
830 }
831 #elif defined(PETSC_USE_REAL_DOUBLE)
832 PetscErrorCode PetscNvshmemSum(PetscInt count,double *dst,const double *src)
833 {
834   PetscErrorCode    ierr;
835   PetscMPIInt       num;
836 
837   PetscFunctionBegin;
838   ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
839   nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
840   PetscFunctionReturn(0);
841 }
842 
843 PetscErrorCode PetscNvshmemMax(PetscInt count,double *dst,const double *src)
844 {
845   PetscErrorCode    ierr;
846   PetscMPIInt       num;
847 
848   PetscFunctionBegin;
849   ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
850   nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
851   PetscFunctionReturn(0);
852 }
853 #endif
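
/* Illustrative only (not compiled): a usage sketch of the reduction wrappers above. dst and src
   are hypothetical device pointers that must live in the symmetric heap (e.g. obtained from
   PetscNvshmemMalloc()); the reduction is queued on PetscDefaultCudaStream, so synchronize that
   stream (or order later work on it) before consuming dst elsewhere. */
#if 0
static PetscErrorCode PetscNvshmemSumUsageSketch(PetscInt n,PetscReal *dst,const PetscReal *src)
{
  PetscErrorCode ierr;
  cudaError_t    cerr;

  PetscFunctionBegin;
  ierr = PetscNvshmemSum(n,dst,src);CHKERRQ(ierr); /* dst[i] = sum over all PEs of src[i] */
  cerr = cudaStreamSynchronize(PetscDefaultCudaStream);CHKERRCUDA(cerr); /* make the result visible before reuse */
  PetscFunctionReturn(0);
}
#endif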
854 
855