xref: /petsc/src/vec/is/sf/impls/basic/nvshmem/sfnvshmem.cu (revision f5ff9c666c0d37e8a7ec3aa2f8e2aa9e44449bdb)
1 #include <petsc/private/cudavecimpl.h>
2 #include <petsccublas.h>
3 #include <../src/vec/is/sf/impls/basic/sfpack.h>
4 #include <mpi.h>
5 #include <nvshmem.h>
6 #include <nvshmemx.h>
7 
8 PetscErrorCode PetscNvshmemInitializeCheck(void)
9 {
10   PetscErrorCode   ierr;
11 
12   PetscFunctionBegin;
13   if (!PetscNvshmemInitialized) { /* Note NVSHMEM does not provide a routine to check whether it is initialized */
14     nvshmemx_init_attr_t attr;
15     attr.mpi_comm = &PETSC_COMM_WORLD;
16     ierr = PetscCUDAInitializeCheck();CHKERRQ(ierr);
17     ierr = nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM,&attr);CHKERRQ(ierr);
18     PetscNvshmemInitialized = PETSC_TRUE;
19     PetscBeganNvshmem       = PETSC_TRUE;
20   }
21   PetscFunctionReturn(0);
22 }
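/* For reference, the attribute-based NVSHMEM bootstrap used above looks as follows outside of PETSc
   (a minimal sketch assuming an already-initialized MPI communicator; not code PETSc executes):

     nvshmemx_init_attr_t attr;
     MPI_Comm             comm = MPI_COMM_WORLD;            // placeholder communicator
     attr.mpi_comm = &comm;                                  // the attribute carries a pointer to the communicator
     nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM,&attr);  // collective over comm
     ...
     nvshmem_finalize();                                     // collective; PETSc does this in PetscNvshmemFinalize()

   PetscNvshmemInitializeCheck() additionally requires CUDA to be initialized first, hence the
   PetscCUDAInitializeCheck() call before nvshmemx_init_attr(). */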
23 
24 PetscErrorCode PetscNvshmemMalloc(size_t size, void** ptr)
25 {
26   PetscErrorCode ierr;
27 
28   PetscFunctionBegin;
29   ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
30   *ptr = nvshmem_malloc(size);
31   if (!*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_malloc() failed to allocate %zu bytes",size);
32   PetscFunctionReturn(0);
33 }
34 
35 PetscErrorCode PetscNvshmemCalloc(size_t size, void**ptr)
36 {
37   PetscErrorCode ierr;
38 
39   PetscFunctionBegin;
40   ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
41   *ptr = nvshmem_calloc(size,1);
42   if (!*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_calloc() failed to allocate %zu bytes",size);
43   PetscFunctionReturn(0);
44 }
45 
46 PetscErrorCode PetscNvshmemFree_Private(void* ptr)
47 {
48   PetscFunctionBegin;
49   nvshmem_free(ptr);
50   PetscFunctionReturn(0);
51 }
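/* Usage sketch for the helpers above (illustrative only; <n> and <buf> are hypothetical names).
   nvshmem_malloc/calloc/free are collective over all PEs and every PE must pass the same size, so these
   wrappers must be called on every rank of PETSC_COMM_WORLD with matching arguments:

     void   *buf;
     size_t  n = 128*sizeof(PetscScalar);
     ierr = PetscNvshmemMalloc(n,&buf);CHKERRQ(ierr);      // collective symmetric-heap allocation
     ...                                                   // buf may now be a source/target of NVSHMEM put/get
     ierr = PetscNvshmemFree_Private(buf);CHKERRQ(ierr);   // collective free (usually reached via PetscNvshmemFree(), used later in this file)
*/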
52 
53 PetscErrorCode PetscNvshmemFinalize(void)
54 {
55   PetscFunctionBegin;
56   nvshmem_finalize();
57   PetscFunctionReturn(0);
58 }
59 
60 /* Free nvshmem related fields in the SF */
61 PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
62 {
63   PetscErrorCode    ierr;
64   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
65 
66   PetscFunctionBegin;
67   ierr = PetscFree2(bas->leafsigdisp,bas->leafbufdisp);CHKERRQ(ierr);
68   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->leafbufdisp_d);CHKERRQ(ierr);
69   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->leafsigdisp_d);CHKERRQ(ierr);
70   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->iranks_d);CHKERRQ(ierr);
71   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->ioffset_d);CHKERRQ(ierr);
72 
73   ierr = PetscFree2(sf->rootsigdisp,sf->rootbufdisp);CHKERRQ(ierr);
74   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootbufdisp_d);CHKERRQ(ierr);
75   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootsigdisp_d);CHKERRQ(ierr);
76   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->ranks_d);CHKERRQ(ierr);
77   ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->roffset_d);CHKERRQ(ierr);
78   PetscFunctionReturn(0);
79 }
80 
81 /* Set up NVSHMEM-related fields for an SF of type SFBASIC (only after PetscSFSetup_Basic() has already set up the fields this depends on) */
82 static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
83 {
84   PetscErrorCode ierr;
85   cudaError_t    cerr;
86   PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
87   PetscInt       i,nRemoteRootRanks,nRemoteLeafRanks;
88   PetscMPIInt    tag;
89   MPI_Comm       comm;
90   MPI_Request    *rootreqs,*leafreqs;
91   PetscInt       tmp,stmp[4],rtmp[4]; /* tmps for send/recv buffers */
92 
93   PetscFunctionBegin;
94   ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
95   ierr = PetscObjectGetNewTag((PetscObject)sf,&tag);CHKERRQ(ierr);
96 
97   nRemoteRootRanks      = sf->nranks-sf->ndranks;
98   nRemoteLeafRanks      = bas->niranks-bas->ndiranks;
99   sf->nRemoteRootRanks  = nRemoteRootRanks;
100   bas->nRemoteLeafRanks = nRemoteLeafRanks;
101 
102   ierr = PetscMalloc2(nRemoteLeafRanks,&rootreqs,nRemoteRootRanks,&leafreqs);CHKERRQ(ierr);
103 
104   stmp[0] = nRemoteRootRanks;
105   stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
106   stmp[2] = nRemoteLeafRanks;
107   stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];
108 
109   ierr = MPIU_Allreduce(stmp,rtmp,4,MPIU_INT,MPI_MAX,comm);CHKERRMPI(ierr);
110 
111   sf->nRemoteRootRanksMax   = rtmp[0];
112   sf->leafbuflen_rmax       = rtmp[1];
113   bas->nRemoteLeafRanksMax  = rtmp[2];
114   bas->rootbuflen_rmax      = rtmp[3];
115 
116   /* Four rounds of MPI communication in total are needed to set up the NVSHMEM fields */
117 
118   /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
119   ierr = PetscMalloc2(nRemoteRootRanks,&sf->rootsigdisp,nRemoteRootRanks,&sf->rootbufdisp);CHKERRQ(ierr);
120   for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Irecv(&sf->rootsigdisp[i],1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm,&leafreqs[i]);CHKERRMPI(ierr);} /* Leaves recv */
121   for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Send(&i,1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm);CHKERRMPI(ierr);} /* Roots send. Note i changes, so we use MPI_Send. */
122   ierr = MPI_Waitall(nRemoteRootRanks,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
123 
124   for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Irecv(&sf->rootbufdisp[i],1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm,&leafreqs[i]);CHKERRMPI(ierr);} /* Leaves recv */
125   for (i=0; i<nRemoteLeafRanks; i++) {
126     tmp  = bas->ioffset[i+bas->ndiranks] - bas->ioffset[bas->ndiranks];
127     ierr = MPI_Send(&tmp,1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm);CHKERRMPI(ierr);  /* Roots send. Note tmp changes, so we use MPI_Send. */
128   }
129   ierr = MPI_Waitall(nRemoteRootRanks,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
130 
131   cerr = cudaMalloc((void**)&sf->rootbufdisp_d,nRemoteRootRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
132   cerr = cudaMalloc((void**)&sf->rootsigdisp_d,nRemoteRootRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
133   cerr = cudaMalloc((void**)&sf->ranks_d,nRemoteRootRanks*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
134   cerr = cudaMalloc((void**)&sf->roffset_d,(nRemoteRootRanks+1)*sizeof(PetscInt));CHKERRCUDA(cerr);
135 
136   cerr = cudaMemcpyAsync(sf->rootbufdisp_d,sf->rootbufdisp,nRemoteRootRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
137   cerr = cudaMemcpyAsync(sf->rootsigdisp_d,sf->rootsigdisp,nRemoteRootRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
138   cerr = cudaMemcpyAsync(sf->ranks_d,sf->ranks+sf->ndranks,nRemoteRootRanks*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
139   cerr = cudaMemcpyAsync(sf->roffset_d,sf->roffset+sf->ndranks,(nRemoteRootRanks+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
140 
141   /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
142   ierr = PetscMalloc2(nRemoteLeafRanks,&bas->leafsigdisp,nRemoteLeafRanks,&bas->leafbufdisp);CHKERRQ(ierr);
143   for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Irecv(&bas->leafsigdisp[i],1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm,&rootreqs[i]);CHKERRMPI(ierr);}
144   for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Send(&i,1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm);CHKERRMPI(ierr);}
145   ierr = MPI_Waitall(nRemoteLeafRanks,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
146 
147   for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Irecv(&bas->leafbufdisp[i],1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm,&rootreqs[i]);CHKERRMPI(ierr);}
148   for (i=0; i<nRemoteRootRanks; i++) {
149     tmp  = sf->roffset[i+sf->ndranks] - sf->roffset[sf->ndranks];
150     ierr = MPI_Send(&tmp,1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm);CHKERRMPI(ierr);
151   }
152   ierr = MPI_Waitall(nRemoteLeafRanks,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);
153 
154   cerr = cudaMalloc((void**)&bas->leafbufdisp_d,nRemoteLeafRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
155   cerr = cudaMalloc((void**)&bas->leafsigdisp_d,nRemoteLeafRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
156   cerr = cudaMalloc((void**)&bas->iranks_d,nRemoteLeafRanks*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
157   cerr = cudaMalloc((void**)&bas->ioffset_d,(nRemoteLeafRanks+1)*sizeof(PetscInt));CHKERRCUDA(cerr);
158 
159   cerr = cudaMemcpyAsync(bas->leafbufdisp_d,bas->leafbufdisp,nRemoteLeafRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
160   cerr = cudaMemcpyAsync(bas->leafsigdisp_d,bas->leafsigdisp,nRemoteLeafRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
161   cerr = cudaMemcpyAsync(bas->iranks_d,bas->iranks+bas->ndiranks,nRemoteLeafRanks*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
162   cerr = cudaMemcpyAsync(bas->ioffset_d,bas->ioffset+bas->ndiranks,(nRemoteLeafRanks+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
163 
164   ierr = PetscFree2(rootreqs,leafreqs);CHKERRQ(ierr);
165   PetscFunctionReturn(0);
166 }
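/* A small worked example of the exchange above (hypothetical 2-rank layout, for illustration only):
   suppose rank 1's leaves reference roots on rank 0, rank 1 is the 3rd remote leaf rank of rank 0
   (i.e. i=2 in rank 0's send loop), and rank 0's data destined for rank 1 starts at entry 40 of its
   remote root buffer.  After the first two rounds rank 1 then holds, for rank 0,

     sf->rootsigdisp[j] = 2    -> when rank 1 signals rank 0, it writes rank 0's signal array at index 2
     sf->rootbufdisp[j] = 40   -> when rank 1 gets data from rank 0, it reads rank 0's root buffer at offset 40 (in units)

   where j is rank 0's position in rank 1's remote-root list.  Rounds three and four set up
   leafsigdisp[]/leafbufdisp[] the same way with the roles of roots and leaves swapped. */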
167 
168 PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,PetscBool *use_nvshmem)
169 {
170   PetscErrorCode   ierr;
171   MPI_Comm         comm;
172   PetscBool        isBasic;
173   PetscMPIInt      result = MPI_UNEQUAL;
174 
175   PetscFunctionBegin;
176   ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
177   /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
178      Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
179   */
181   if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
182     /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
183     ierr = PetscObjectTypeCompare((PetscObject)sf,PETSCSFBASIC,&isBasic);CHKERRQ(ierr);
184     if (isBasic) {ierr = MPI_Comm_compare(PETSC_COMM_WORLD,comm,&result);CHKERRMPI(ierr);}
185     if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */
186 
187     /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
188        and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
189        inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
190     */
191     if (sf->use_nvshmem) {
192       PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
193       ierr = MPI_Allreduce(MPI_IN_PLACE,&hasNullRank,1,MPIU_INT,MPI_LOR,comm);CHKERRMPI(ierr);
194       if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
195     }
196     sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
197   }
198 
199   /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
200   if (sf->use_nvshmem) {
201     PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
202     PetscInt allCuda = oneCuda; /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
203    #if defined(PETSC_USE_DEBUG)  /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
204     ierr = MPI_Allreduce(&oneCuda,&allCuda,1,MPIU_INT,MPI_LAND,comm);CHKERRMPI(ierr);
205     if (allCuda != oneCuda) SETERRQ(comm,PETSC_ERR_SUP,"root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
206    #endif
207     if (allCuda) {
208       ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
209       if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
210         ierr = PetscSFSetUp_Basic_NVSHMEM(sf);CHKERRQ(ierr);
211         sf->setup_nvshmem = PETSC_TRUE;
212       }
213       *use_nvshmem = PETSC_TRUE;
214     } else {
215       *use_nvshmem = PETSC_FALSE;
216     }
217   } else {
218     *use_nvshmem = PETSC_FALSE;
219   }
220   PetscFunctionReturn(0);
221 }
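/* How a caller is expected to use the check above (a sketch of an assumed caller; the actual dispatch
   lives in the generic PetscSF code, not in this file; <unit>, <op>, <sfop> and <link> come from that caller):

     PetscBool use_nvshmem;
     ierr = PetscSFLinkNvshmemCheck(sf,rootmtype,rootdata,leafmtype,leafdata,&use_nvshmem);CHKERRQ(ierr);
     if (use_nvshmem) {
       ierr = PetscSFLinkCreate_NVSHMEM(sf,unit,rootmtype,rootdata,leafmtype,leafdata,op,sfop,&link);CHKERRQ(ierr);
     } else {
       // fall back to the regular MPI link
     }

   Since the result is collective over the SF's communicator, all ranks take the same branch. */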
222 
223 /* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication */
224 static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
225 {
226   cudaError_t    cerr;
227   PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
228   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF)? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE];
229 
230   PetscFunctionBegin;
231   if (buflen) {
232     cerr = cudaEventRecord(link->dataReady,link->stream);CHKERRCUDA(cerr);
233     cerr = cudaStreamWaitEvent(link->remoteCommStream,link->dataReady,0);CHKERRCUDA(cerr);
234   }
235   PetscFunctionReturn(0);
236 }
237 
238 /* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication */
239 static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
240 {
241   cudaError_t    cerr;
242   PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
243   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF)? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE];
244 
245   PetscFunctionBegin;
246   /* If unpacking into a non-empty device buffer, build the endRemoteComm dependence */
247   if (buflen) {
248     cerr = cudaEventRecord(link->endRemoteComm,link->remoteCommStream);CHKERRCUDA(cerr);
249     cerr = cudaStreamWaitEvent(link->stream,link->endRemoteComm,0);CHKERRCUDA(cerr);
250   }
251   PetscFunctionReturn(0);
252 }
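/* The two helpers above use the standard CUDA event pattern for ordering work across two streams
   (a generic sketch with placeholder names, not code from this file):

     cudaEventRecord(ev,streamA);         // mark the point in streamA that streamB must wait for
     cudaStreamWaitEvent(streamB,ev,0);   // work submitted to streamB afterwards runs only after ev

   At Begin, streamA is link->stream (where pack kernels run) and streamB is link->remoteCommStream;
   at End the roles are swapped, so unpacking on link->stream starts only after the NVSHMEM traffic
   recorded on link->remoteCommStream has completed. */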
253 
254 /* Send/Put signals to remote ranks
255 
256  Input parameters:
257   + n        - Number of remote ranks
258   . sig      - Signal address in symmetric heap
259   . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
260   . ranks    - remote ranks
261   - newval   - Set signals to this value
262 */
263 __global__ static void NvshmemSendSignals(PetscInt n,uint64_t *sig,PetscInt *sigdisp,PetscMPIInt *ranks,uint64_t newval)
264 {
265   int i = blockIdx.x*blockDim.x + threadIdx.x;
266 
267   /* Each thread puts one remote signal */
268   if (i < n) nvshmemx_uint64_signal(sig+sigdisp[i],newval,ranks[i]);
269 }
270 
271 /* Wait until local signals equal the expected value, then set them to a new value
272 
273  Input parameters:
274   + n        - Number of signals
275   . sig      - Local signal address
276   . expval   - expected value
277   - newval   - Set signals to this new value
278 */
279 __global__ static void NvshmemWaitSignals(PetscInt n,uint64_t *sig,uint64_t expval,uint64_t newval)
280 {
281 #if 0
282   /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
283   int i = blockIdx.x*blockDim.x + threadIdx.x;
284   if (i < n) {
285     nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
286     sig[i] = newval;
287   }
288 #else
289   nvshmem_uint64_wait_until_all(sig,n,NULL/*no mask*/,NVSHMEM_CMP_EQ,expval);
290   for (int i=0; i<n; i++) sig[i] = newval;
291 #endif
292 }
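/* Together, the two kernels above form a simple handshake over the symmetric heap (a minimal sketch
   with placeholder names; <sig> is an NVSHMEM-allocated uint64_t array, <sigdisp>/<ranks> are device
   arrays, and each PE launches on its own stream):

     // producer PE: set the consumer's signal (at its offset sigdisp[0]) to 1
     NvshmemSendSignals<<<1,32,0,streamA>>>(1,sig,sigdisp,ranks,1);
     // consumer PE: block until the local signal becomes 1, then reset it to 0 for reuse
     NvshmemWaitSignals<<<1,1,0,streamB>>>(1,sig,1,0);

   The get and put protocols below are built entirely from this pattern. */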
293 
294 /* ===========================================================================================================
295 
296    A set of routines to support receiver initiated communication using the get method
297 
298     The getting protocol is:
299 
300     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
301     All signal variables have an initial value 0.
302 
303     Sender:                                 |  Receiver:
304   1.  Wait ssig to be 0, then set it to 1   |
305   2.  Pack data into stand-alone sbuf       |
306   3.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it to 0
307                                             |   2. Get data from remote sbuf to local rbuf
308                                             |   3. Put 0 to sender's ssig
309                                             |   4. Unpack data from local rbuf
310    ===========================================================================================================*/
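/* In this file the get protocol maps onto the SFLink callbacks as follows (see PetscSFLinkCreate_NVSHMEM below):
     sender step 1                     -> PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM (link->PrePack)
     sender step 3, receiver steps 1-2 -> PetscSFLinkGetDataBegin_NVSHMEM (link->StartCommunication)
     receiver step 3                   -> PetscSFLinkGetDataEnd_NVSHMEM (link->FinishCommunication)
   Packing (sender step 2) and unpacking (receiver step 4) are done by the generic pack/unpack routines
   set up in PetscSFLinkSetUp_CUDA(). */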
311 /* PrePack operation -- since the sender is about to overwrite the send buffer, which receivers may still be
312    getting data from, the sender first waits for signals from the receivers indicating they have finished getting the data
313 */
314 PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
315 {
316   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
317   uint64_t          *sig;
318   PetscInt          n;
319 
320   PetscFunctionBegin;
321   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
322     sig = link->rootSendSig;            /* leaf ranks set my rootSendSig */
323     n   = bas->nRemoteLeafRanks;
324   } else { /* LEAF2ROOT */
325     sig = link->leafSendSig;
326     n   = sf->nRemoteRootRanks;
327   }
328 
329   if (n) {
330     NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(n,sig,0,1); /* wait the signals to be 0, then set them to 1 */
331     cudaError_t cerr = cudaGetLastError();CHKERRCUDA(cerr);
332   }
333   PetscFunctionReturn(0);
334 }
335 
336 /* Launched with nsrcranks thread blocks; each block handles one source rank (blocks whose rank is locally accessible do nothing) */
337 __global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks,PetscMPIInt *srcranks,const char *src,PetscInt *srcdisp,char *dst,PetscInt *dstdisp,PetscInt unitbytes)
338 {
339   int               bid = blockIdx.x;
340   PetscMPIInt       pe  = srcranks[bid];
341 
342   if (!nvshmem_ptr(src,pe)) {
343     PetscInt nelems = (dstdisp[bid+1]-dstdisp[bid])*unitbytes;
344     nvshmem_getmem_nbi(dst+(dstdisp[bid]-dstdisp[0])*unitbytes,src+srcdisp[bid]*unitbytes,nelems,pe);
345   }
346 }
347 
348 /* Start communication -- Get data in the given direction */
349 PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
350 {
351   PetscErrorCode    ierr;
352   cudaError_t       cerr;
353   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
354 
355   PetscInt          nsrcranks,ndstranks,nLocallyAccessible = 0;
356 
357   char              *src,*dst;
358   PetscInt          *srcdisp_h,*dstdisp_h;
359   PetscInt          *srcdisp_d,*dstdisp_d;
360   PetscMPIInt       *srcranks_h;
361   PetscMPIInt       *srcranks_d,*dstranks_d;
362   uint64_t          *dstsig;
363   PetscInt          *dstsigdisp_d;
364 
365   PetscFunctionBegin;
366   ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
367   if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
368     nsrcranks    = sf->nRemoteRootRanks;
369     src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */
370 
371     srcdisp_h    = sf->rootbufdisp;       /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
372     srcdisp_d    = sf->rootbufdisp_d;
373     srcranks_h   = sf->ranks+sf->ndranks; /* my (remote) root ranks */
374     srcranks_d   = sf->ranks_d;
375 
376     ndstranks    = bas->nRemoteLeafRanks;
377     dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */
378 
379     dstdisp_h    = sf->roffset+sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
380     dstdisp_d    = sf->roffset_d;
381     dstranks_d   = bas->iranks_d; /* my (remote) leaf ranks */
382 
383     dstsig       = link->leafRecvSig;
384     dstsigdisp_d = bas->leafsigdisp_d;
385   } else { /* src is leaf, dst is root; we will move data from src to dst */
386     nsrcranks    = bas->nRemoteLeafRanks;
387     src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */
388 
389     srcdisp_h    = bas->leafbufdisp;       /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
390     srcdisp_d    = bas->leafbufdisp_d;
391     srcranks_h   = bas->iranks+bas->ndiranks; /* my (remote) leaf ranks */
392     srcranks_d   = bas->iranks_d;
393 
394     ndstranks    = sf->nRemoteRootRanks;
395     dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */
396 
397     dstdisp_h    = bas->ioffset+bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
398     dstdisp_d    = bas->ioffset_d;
399     dstranks_d   = sf->ranks_d; /* my (remote) root ranks */
400 
401     dstsig       = link->rootRecvSig;
402     dstsigdisp_d = sf->rootsigdisp_d;
403   }
404 
405   /* After Pack operation -- src tells dst ranks that they are allowed to get data */
406   if (ndstranks) {
407     NvshmemSendSignals<<<(ndstranks+255)/256,256,0,link->remoteCommStream>>>(ndstranks,dstsig,dstsigdisp_d,dstranks_d,1); /* set signals to 1 */
408     cerr = cudaGetLastError();CHKERRCUDA(cerr);
409   }
410 
411   /* dst waits for signals (permissions) from src ranks to start getting data */
412   if (nsrcranks) {
413     NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(nsrcranks,dstsig,1,0); /* wait the signals to be 1, then set them to 0 */
414     cerr = cudaGetLastError();CHKERRCUDA(cerr);
415   }
416 
417   /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */
418 
419   /* Count number of locally accessible src ranks, which should be a small number */
420   for (int i=0; i<nsrcranks; i++) {if (nvshmem_ptr(src,srcranks_h[i])) nLocallyAccessible++;}
421 
422   /* Get data from remotely accessible PEs */
423   if (nLocallyAccessible < nsrcranks) {
424     GetDataFromRemotelyAccessible<<<nsrcranks,1,0,link->remoteCommStream>>>(nsrcranks,srcranks_d,src,srcdisp_d,dst,dstdisp_d,link->unitbytes);
425     cerr = cudaGetLastError();CHKERRCUDA(cerr);
426   }
427 
428   /* Get data from locally accessible PEs */
429   if (nLocallyAccessible) {
430     for (int i=0; i<nsrcranks; i++) {
431       int pe = srcranks_h[i];
432       if (nvshmem_ptr(src,pe)) {
433         size_t nelems = (dstdisp_h[i+1]-dstdisp_h[i])*link->unitbytes;
434         nvshmemx_getmem_nbi_on_stream(dst+(dstdisp_h[i]-dstdisp_h[0])*link->unitbytes,src+srcdisp_h[i]*link->unitbytes,nelems,pe,link->remoteCommStream);
435       }
436     }
437   }
438   PetscFunctionReturn(0);
439 }
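/* Offset arithmetic used above, on a hypothetical example: if my i-th source rank's data occupies
   entries dstdisp_h[i]=10 .. dstdisp_h[i+1]=14 of my receive buffer, dstdisp_h[0]=4 (the remote part
   need not start at 0) and link->unitbytes=8, then
     nelems = (14-10)*8 = 32 bytes,
   read from the source buffer at byte srcdisp_h[i]*8, and write my buffer at byte (10-4)*8 = 48.
   The same formula is used both on the host path above and in GetDataFromRemotelyAccessible(). */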
440 
441 /* Finish the communication (can be done before Unpack)
442    Receiver tells its senders that they are allowed to reuse their send buffer (since receiver has got data from their send buffer)
443 */
444 PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
445 {
446   PetscErrorCode    ierr;
447   cudaError_t       cerr;
448   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
449   uint64_t          *srcsig;
450   PetscInt          nsrcranks,*srcsigdisp;
451   PetscMPIInt       *srcranks;
452 
453   PetscFunctionBegin;
454   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
455     nsrcranks   = sf->nRemoteRootRanks;
456     srcsig      = link->rootSendSig;     /* I want to set their root signal */
457     srcsigdisp  = sf->rootsigdisp_d;     /* offset of each root signal */
458     srcranks    = sf->ranks_d;           /* ranks of the n root ranks */
459   } else { /* LEAF2ROOT, root ranks are getting data */
460     nsrcranks   = bas->nRemoteLeafRanks;
461     srcsig      = link->leafSendSig;
462     srcsigdisp  = bas->leafsigdisp_d;
463     srcranks    = bas->iranks_d;
464   }
465 
466   if (nsrcranks) {
467     nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
468     cerr = cudaGetLastError();CHKERRCUDA(cerr);
469     NvshmemSendSignals<<<(nsrcranks+511)/512,512,0,link->remoteCommStream>>>(nsrcranks,srcsig,srcsigdisp,srcranks,0); /* set signals to 0 */
470     cerr = cudaGetLastError();CHKERRCUDA(cerr);
471   }
472   ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
473   PetscFunctionReturn(0);
474 }
475 
476 /* ===========================================================================================================
477 
478    A set of routines to support sender initiated communication using the put-based method (the default)
479 
480     The putting protocol is:
481 
482     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
483     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
484     is in nvshmem space.
485 
486     Sender:                                 |  Receiver:
487                                             |
488   1.  Pack data into sbuf                   |
489   2.  Wait ssig to be 0, then set it to 1   |
490   3.  Put data to remote stand-alone rbuf   |
491   4.  Fence // make sure 5 happens after 3  |
492   5.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it to 0
493                                             |   2. Unpack data from local rbuf
494                                             |   3. Put 0 to sender's ssig
495    ===========================================================================================================*/
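/* In this file the put protocol maps onto the SFLink callbacks as follows (see PetscSFLinkCreate_NVSHMEM below):
     sender steps 2-4               -> PetscSFLinkPutDataBegin_NVSHMEM (link->StartCommunication)
     sender step 5, receiver step 1 -> PetscSFLinkPutDataEnd_NVSHMEM (link->FinishCommunication)
     receiver step 3                -> PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM (link->PostUnpack)
   Packing (sender step 1) and unpacking (receiver step 2) again use the generic pack/unpack routines. */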
496 
497 /* Launched with ndstranks thread blocks; each block handles one destination rank (blocks whose rank is locally accessible do nothing) */
498 __global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,char *dst,PetscInt *dstdisp,const char *src,PetscInt *srcdisp,uint64_t *srcsig,PetscInt unitbytes)
499 {
500   int               bid = blockIdx.x;
501   PetscMPIInt       pe  = dstranks[bid];
502 
503   if (!nvshmem_ptr(dst,pe)) {
504     PetscInt nelems = (srcdisp[bid+1]-srcdisp[bid])*unitbytes;
505     nvshmem_uint64_wait_until(srcsig+bid,NVSHMEM_CMP_EQ,0); /* Wait until the sig = 0 */
506     srcsig[bid] = 1;
507     nvshmem_putmem_nbi(dst+dstdisp[bid]*unitbytes,src+(srcdisp[bid]-srcdisp[0])*unitbytes,nelems,pe);
508   }
509 }
510 
511 /* A one-thread kernel that waits on the send signals for all locally accessible destination ranks */
512 __global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *srcsig,const char *dst)
513 {
514   for (int i=0; i<ndstranks; i++) {
515     int pe = dstranks[i];
516     if (nvshmem_ptr(dst,pe)) {
517       nvshmem_uint64_wait_until(srcsig+i,NVSHMEM_CMP_EQ,0); /* Wait until the sig = 0 */
518       srcsig[i] = 1;
519     }
520   }
521 }
522 
523 /* Put data in the given direction  */
524 PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
525 {
526   PetscErrorCode    ierr;
527   cudaError_t       cerr;
528   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
529   PetscInt          ndstranks,nLocallyAccessible = 0;
530   char              *src,*dst;
531   PetscInt          *srcdisp_h,*dstdisp_h;
532   PetscInt          *srcdisp_d,*dstdisp_d;
533   PetscMPIInt       *dstranks_h;
534   PetscMPIInt       *dstranks_d;
535   uint64_t          *srcsig;
536 
537   PetscFunctionBegin;
538   ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
539   if (direction == PETSCSF_ROOT2LEAF) { /* put data in rootbuf to leafbuf  */
540     ndstranks    = bas->nRemoteLeafRanks; /* number of (remote) leaf ranks */
541     src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
542     dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
543 
544     srcdisp_h    = bas->ioffset+bas->ndiranks;  /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
545     srcdisp_d    = bas->ioffset_d;
546     srcsig       = link->rootSendSig;
547 
548     dstdisp_h    = bas->leafbufdisp;            /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
549     dstdisp_d    = bas->leafbufdisp_d;
550     dstranks_h   = bas->iranks+bas->ndiranks;   /* remote leaf ranks */
551     dstranks_d   = bas->iranks_d;
552   } else { /* put data in leafbuf to rootbuf */
553     ndstranks    = sf->nRemoteRootRanks;
554     src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
555     dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
556 
557     srcdisp_h    = sf->roffset+sf->ndranks; /* offsets of leafbuf */
558     srcdisp_d    = sf->roffset_d;
559     srcsig       = link->leafSendSig;
560 
561     dstdisp_h    = sf->rootbufdisp;         /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
562     dstdisp_d    = sf->rootbufdisp_d;
563     dstranks_h   = sf->ranks+sf->ndranks;   /* remote root ranks */
564     dstranks_d   = sf->ranks_d;
565   }
566 
567   /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */
568 
569   /* Count number of locally accessible neighbors, which should be a small number */
570   for (int i=0; i<ndstranks; i++) {if (nvshmem_ptr(dst,dstranks_h[i])) nLocallyAccessible++;}
571 
572   /* For remotely accessible PEs, send data to them in one kernel call */
573   if (nLocallyAccessible < ndstranks) {
574     WaitAndPutDataToRemotelyAccessible<<<ndstranks,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,dst,dstdisp_d,src,srcdisp_d,srcsig,link->unitbytes);
575     cerr = cudaGetLastError();CHKERRCUDA(cerr);
576   }
577 
578   /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
579   if (nLocallyAccessible) {
580     WaitSignalsFromLocallyAccessible<<<1,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,srcsig,dst);
581     for (int i=0; i<ndstranks; i++) {
582       int pe = dstranks_h[i];
583       if (nvshmem_ptr(dst,pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
584         size_t nelems = (srcdisp_h[i+1]-srcdisp_h[i])*link->unitbytes;
585          /* Initiate the nonblocking communication */
586         nvshmemx_putmem_nbi_on_stream(dst+dstdisp_h[i]*link->unitbytes,src+(srcdisp_h[i]-srcdisp_h[0])*link->unitbytes,nelems,pe,link->remoteCommStream);
587       }
588     }
589   }
590 
591   if (nLocallyAccessible) {
592     nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
593   }
594   PetscFunctionReturn(0);
595 }
596 
597 /* A one-thread kernel; the single thread signals all destination ranks and then waits on signals from all source ranks */
598 __global__ static void PutDataEnd(PetscInt nsrcranks,PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *dstsig,PetscInt *dstsigdisp)
599 {
600   /* TODO: Shall we finish the non-blocking remote puts here first? */
601 
602   /* 1. Send a signal to each dst rank */
603 
604   /* According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
605      For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
606   */
607   for (int i=0; i<ndstranks; i++) {nvshmemx_uint64_signal(dstsig+dstsigdisp[i],1,dstranks[i]);} /* set sig to 1 */
608 
609   /* 2. Wait for signals from src ranks (if any) */
610   if (nsrcranks) {
611     nvshmem_uint64_wait_until_all(dstsig,nsrcranks,NULL/*no mask*/,NVSHMEM_CMP_EQ,1); /* wait sigs to be 1, then set them to 0 */
612     for (int i=0; i<nsrcranks; i++) dstsig[i] = 0;
613   }
614 }
615 
616 /* Finish the communication -- A receiver waits until it can access its receive buffer */
617 PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
618 {
619   PetscErrorCode    ierr;
620   cudaError_t       cerr;
621   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
622   PetscMPIInt       *dstranks;
623   uint64_t          *dstsig;
624   PetscInt          nsrcranks,ndstranks,*dstsigdisp;
625 
626   PetscFunctionBegin;
627   if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
628     nsrcranks    = sf->nRemoteRootRanks;
629 
630     ndstranks    = bas->nRemoteLeafRanks;
631     dstranks     = bas->iranks_d;       /* leaf ranks */
632     dstsig       = link->leafRecvSig;   /* I will set my leaf ranks's RecvSig */
633     dstsigdisp   = bas->leafsigdisp_d;  /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
634   } else { /* LEAF2ROOT */
635     nsrcranks    = bas->nRemoteLeafRanks;
636 
637     ndstranks    = sf->nRemoteRootRanks;
638     dstranks     = sf->ranks_d;
639     dstsig       = link->rootRecvSig;
640     dstsigdisp   = sf->rootsigdisp_d;
641   }
642 
643   if (nsrcranks || ndstranks) {
644     PutDataEnd<<<1,1,0,link->remoteCommStream>>>(nsrcranks,ndstranks,dstranks,dstsig,dstsigdisp);
645     cerr = cudaGetLastError();CHKERRCUDA(cerr);
646   }
647   ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
648   PetscFunctionReturn(0);
649 }
650 
651 /* PostUnpack operation -- A receiver tells its senders that they are allowed to put data here again (i.e., the recv buffer is free to take new data) */
652 PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
653 {
654   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
655   uint64_t          *srcsig;
656   PetscInt          nsrcranks,*srcsigdisp_d;
657   PetscMPIInt       *srcranks_d;
658 
659   PetscFunctionBegin;
660   if (direction == PETSCSF_ROOT2LEAF) { /* I allow my root ranks to put data to me */
661     nsrcranks    = sf->nRemoteRootRanks;
662     srcsig       = link->rootSendSig;      /* I want to set their send signals */
663     srcsigdisp_d = sf->rootsigdisp_d;      /* offset of each root signal */
664     srcranks_d   = sf->ranks_d;            /* ranks of the n root ranks */
665   } else { /* LEAF2ROOT */
666     nsrcranks    = bas->nRemoteLeafRanks;
667     srcsig       = link->leafSendSig;
668     srcsigdisp_d = bas->leafsigdisp_d;
669     srcranks_d   = bas->iranks_d;
670   }
671 
672   if (nsrcranks) {
673     NvshmemSendSignals<<<(nsrcranks+255)/256,256,0,link->remoteCommStream>>>(nsrcranks,srcsig,srcsigdisp_d,srcranks_d,0); /* Set remote signals to 0 */
674     cudaError_t cerr = cudaGetLastError();CHKERRCUDA(cerr);
675   }
676   PetscFunctionReturn(0);
677 }
678 
679 /* Destructor when the link uses nvshmem for communication */
680 static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf,PetscSFLink link)
681 {
682   PetscErrorCode    ierr;
683   cudaError_t       cerr;
684 
685   PetscFunctionBegin;
686   cerr = cudaEventDestroy(link->dataReady);CHKERRCUDA(cerr);
687   cerr = cudaEventDestroy(link->endRemoteComm);CHKERRCUDA(cerr);
688   cerr = cudaStreamDestroy(link->remoteCommStream);CHKERRCUDA(cerr);
689 
690   /* NVSHMEM links never allocate host buffers, so those stay NULL and need no freeing here */
691   ierr = PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
692   ierr = PetscNvshmemFree(link->leafSendSig);CHKERRQ(ierr);
693   ierr = PetscNvshmemFree(link->leafRecvSig);CHKERRQ(ierr);
694   ierr = PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
695   ierr = PetscNvshmemFree(link->rootSendSig);CHKERRQ(ierr);
696   ierr = PetscNvshmemFree(link->rootRecvSig);CHKERRQ(ierr);
697   PetscFunctionReturn(0);
698 }
699 
700 PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf,MPI_Datatype unit,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,MPI_Op op,PetscSFOperation sfop,PetscSFLink *mylink)
701 {
702   PetscErrorCode    ierr;
703   cudaError_t       cerr;
704   PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
705   PetscSFLink       *p,link;
706   PetscBool         match,rootdirect[2],leafdirect[2];
707   int               greatestPriority;
708 
709   PetscFunctionBegin;
710   /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
711      We only care about root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers for local communication with NVSHMEM.
712   */
713   if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
714     if (sf->use_nvshmem_get) {
715       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
716       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
717     } else {
718       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
719       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;  /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
720     }
721   } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
722     if (sf->use_nvshmem_get) {
723       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
724       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
725     } else {
726       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
727       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
728     }
729   } else { /* PETSCSF_FETCH */
730     rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always need a separate rootbuf */
731     leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share mpi requests */
732   }
733 
734   /* Look for free nvshmem links in cache */
735   for (p=&bas->avail; (link=*p); p=&link->next) {
736     if (link->use_nvshmem) {
737       ierr = MPIPetsc_Type_compare(unit,link->unit,&match);CHKERRQ(ierr);
738       if (match) {
739         *p = link->next; /* Remove from available list */
740         goto found;
741       }
742     }
743   }
744   ierr = PetscNew(&link);CHKERRQ(ierr);
745   ierr = PetscSFLinkSetUp_Host(sf,link,unit);CHKERRQ(ierr); /* Compute link->unitbytes, dup link->unit etc. */
746   if (sf->backend == PETSCSF_BACKEND_CUDA) {ierr = PetscSFLinkSetUp_CUDA(sf,link,unit);CHKERRQ(ierr);} /* Setup pack routines, streams etc */
747  #if defined(PETSC_HAVE_KOKKOS)
748   else if (sf->backend == PETSCSF_BACKEND_KOKKOS) {ierr = PetscSFLinkSetUp_Kokkos(sf,link,unit);CHKERRQ(ierr);}
749  #endif
750 
751   link->rootdirect[PETSCSF_LOCAL]  = PETSC_TRUE; /* For the local part we directly use root/leafdata */
752   link->leafdirect[PETSCSF_LOCAL]  = PETSC_TRUE;
753 
754   /* Init signals to zero */
755   if (!link->rootSendSig) {ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootSendSig);CHKERRQ(ierr);}
756   if (!link->rootRecvSig) {ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootRecvSig);CHKERRQ(ierr);}
757   if (!link->leafSendSig) {ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafSendSig);CHKERRQ(ierr);}
758   if (!link->leafRecvSig) {ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafRecvSig);CHKERRQ(ierr);}
759 
760   link->use_nvshmem                = PETSC_TRUE;
761   link->rootmtype                  = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
762   link->leafmtype                  = PETSC_MEMTYPE_DEVICE;
763   /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
764   link->Destroy                    = PetscSFLinkDestroy_NVSHMEM;
765   if (sf->use_nvshmem_get) { /* get-based protocol */
766     link->PrePack                  = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
767     link->StartCommunication       = PetscSFLinkGetDataBegin_NVSHMEM;
768     link->FinishCommunication      = PetscSFLinkGetDataEnd_NVSHMEM;
769   } else { /* put-based protocol */
770     link->StartCommunication       = PetscSFLinkPutDataBegin_NVSHMEM;
771     link->FinishCommunication      = PetscSFLinkPutDataEnd_NVSHMEM;
772     link->PostUnpack               = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
773   }
774 
775   cerr = cudaDeviceGetStreamPriorityRange(NULL,&greatestPriority);CHKERRCUDA(cerr);
776   cerr = cudaStreamCreateWithPriority(&link->remoteCommStream,cudaStreamNonBlocking,greatestPriority);CHKERRCUDA(cerr);
777 
778   cerr = cudaEventCreateWithFlags(&link->dataReady,cudaEventDisableTiming);CHKERRCUDA(cerr);
779   cerr = cudaEventCreateWithFlags(&link->endRemoteComm,cudaEventDisableTiming);CHKERRCUDA(cerr);
780 
781 found:
782   if (rootdirect[PETSCSF_REMOTE]) {
783     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)rootdata + bas->rootstart[PETSCSF_REMOTE]*link->unitbytes;
784   } else {
785     if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
786       ierr = PetscNvshmemMalloc(bas->rootbuflen_rmax*link->unitbytes,(void**)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
787     }
788     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
789   }
790 
791   if (leafdirect[PETSCSF_REMOTE]) {
792     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)leafdata + sf->leafstart[PETSCSF_REMOTE]*link->unitbytes;
793   } else {
794     if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
795       ierr = PetscNvshmemMalloc(sf->leafbuflen_rmax*link->unitbytes,(void**)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
796     }
797     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
798   }
799 
800   link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
801   link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
802   link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
803   link->leafdata                   = leafdata;
804   link->next                       = bas->inuse;
805   bas->inuse                       = link;
806   *mylink                          = link;
807   PetscFunctionReturn(0);
808 }
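/* End to end, a broadcast over this link is expected to be driven by the generic SF layer roughly as in
   the sketch below (an assumption about the caller; the driver itself is not part of this file):

     ierr = PetscSFLinkCreate_NVSHMEM(sf,unit,rootmtype,rootdata,leafmtype,leafdata,MPI_REPLACE,PETSCSF_BCAST,&link);CHKERRQ(ierr);
     if (link->PrePack) {ierr = (*link->PrePack)(sf,link,PETSCSF_ROOT2LEAF);CHKERRQ(ierr);}       // get protocol only
     // pack rootdata into link->rootbuf on link->stream (generic pack kernel)
     ierr = (*link->StartCommunication)(sf,link,PETSCSF_ROOT2LEAF);CHKERRQ(ierr);
     ierr = (*link->FinishCommunication)(sf,link,PETSCSF_ROOT2LEAF);CHKERRQ(ierr);
     // unpack link->leafbuf into leafdata on link->stream (generic unpack kernel)
     if (link->PostUnpack) {ierr = (*link->PostUnpack)(sf,link,PETSCSF_ROOT2LEAF);CHKERRQ(ierr);}  // put protocol only
*/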
809 
810 #if defined(PETSC_USE_REAL_SINGLE)
811 PetscErrorCode PetscNvshmemSum(PetscInt count,float *dst,const float *src)
812 {
813   PetscErrorCode    ierr;
814   PetscMPIInt       num; /* Assume nvshmem's int is MPI's int */
815 
816   PetscFunctionBegin;
817   ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
818   nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
819   PetscFunctionReturn(0);
820 }
821 
822 PetscErrorCode PetscNvshmemMax(PetscInt count,float *dst,const float *src)
823 {
824   PetscErrorCode    ierr;
825   PetscMPIInt       num;
826 
827   PetscFunctionBegin;
828   ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
829   nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
830   PetscFunctionReturn(0);
831 }
832 #elif defined(PETSC_USE_REAL_DOUBLE)
833 PetscErrorCode PetscNvshmemSum(PetscInt count,double *dst,const double *src)
834 {
835   PetscErrorCode    ierr;
836   PetscMPIInt       num;
837 
838   PetscFunctionBegin;
839   ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
840   nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
841   PetscFunctionReturn(0);
842 }
843 
844 PetscErrorCode PetscNvshmemMax(PetscInt count,double *dst,const double *src)
845 {
846   PetscErrorCode    ierr;
847   PetscMPIInt       num;
848 
849   PetscFunctionBegin;
850   ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
851   nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
852   PetscFunctionReturn(0);
853 }
854 #endif
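/* Usage sketch for the reductions above (illustrative; <dst>, <src> and <n> are hypothetical).  Both
   buffers must live in the NVSHMEM symmetric heap, every PE in NVSHMEM_TEAM_WORLD must call with the
   same count, and the reduction is merely enqueued on PetscDefaultCudaStream, so synchronize that
   stream before using <dst> on the host:

     PetscScalar *dst,*src;                                  // previously obtained from PetscNvshmemMalloc()
     ierr = PetscNvshmemSum(n,dst,src);CHKERRQ(ierr);        // dst[i] = sum over all PEs of src[i]
     cerr = cudaStreamSynchronize(PetscDefaultCudaStream);CHKERRCUDA(cerr);
*/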
855 
856