xref: /petsc/src/vec/is/sf/impls/basic/nvshmem/sfnvshmem.cu (revision daa037dfd3c3bec8dc8659548d2b20b07c1dc6de)
1 #include <petsc/private/cudavecimpl.h>
2 #include <../src/vec/is/sf/impls/basic/sfpack.h>
3 #include <mpi.h>
4 #include <nvshmem.h>
5 #include <nvshmemx.h>
6 
/* Lazily initialize NVSHMEM (bootstrapped with PETSC_COMM_WORLD) on first use */
PetscErrorCode PetscNvshmemInitializeCheck(void)
{
  PetscFunctionBegin;
  /* NVSHMEM does not provide a routine to query whether it is initialized, so we track it with our own flag */
  if (!PetscNvshmemInitialized) {
    nvshmemx_init_attr_t attr;

    PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* CUDA must be up before NVSHMEM init */
    attr.mpi_comm = &PETSC_COMM_WORLD;
    PetscCall(nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM,&attr));
    PetscNvshmemInitialized = PETSC_TRUE;
    PetscBeganNvshmem       = PETSC_TRUE; /* remember that PETSc (not the user) started NVSHMEM */
  }
  PetscFunctionReturn(0);
}
20 
/* Allocate <size> bytes from the NVSHMEM symmetric heap, initializing NVSHMEM first if needed.

   Raises PETSC_ERR_ARG_WRONG if the symmetric-heap allocation fails (returns NULL).
*/
PetscErrorCode PetscNvshmemMalloc(size_t size, void** ptr)
{
  void *buf;

  PetscFunctionBegin;
  PetscCall(PetscNvshmemInitializeCheck());
  buf = nvshmem_malloc(size);
  PetscCheckFalse(!buf,PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"nvshmem_malloc() failed to allocate %zu bytes",size);
  *ptr = buf;
  PetscFunctionReturn(0);
}
29 
/* Allocate <size> zero-initialized bytes from the NVSHMEM symmetric heap, initializing NVSHMEM first if needed.

   Raises PETSC_ERR_ARG_WRONG if the symmetric-heap allocation fails (returns NULL).
*/
PetscErrorCode PetscNvshmemCalloc(size_t size, void**ptr)
{
  void *buf;

  PetscFunctionBegin;
  PetscCall(PetscNvshmemInitializeCheck());
  buf = nvshmem_calloc(size,1); /* <size> elements of 1 byte each */
  PetscCheckFalse(!buf,PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"nvshmem_calloc() failed to allocate %zu bytes",size);
  *ptr = buf;
  PetscFunctionReturn(0);
}
38 
/* Return <ptr> to the NVSHMEM symmetric heap; nvshmem_free() reports no error code */
PetscErrorCode PetscNvshmemFree_Private(void* ptr)
{
  PetscFunctionBegin;
  nvshmem_free(ptr);
  PetscFunctionReturn(0);
}
45 
/* Finalize the NVSHMEM library; nvshmem_finalize() reports no error code */
PetscErrorCode PetscNvshmemFinalize(void)
{
  PetscFunctionBegin;
  nvshmem_finalize();
  PetscFunctionReturn(0);
}
52 
/* Free the NVSHMEM related fields of the SF (host display arrays and their device mirrors) */
PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic *bas = (PetscSF_Basic*)sf->data;

  PetscFunctionBegin;
  /* Root-side fields */
  PetscCall(PetscFree2(sf->rootsigdisp,sf->rootbufdisp));
  PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootbufdisp_d));
  PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootsigdisp_d));
  PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->ranks_d));
  PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->roffset_d));

  /* Leaf-side fields */
  PetscCall(PetscFree2(bas->leafsigdisp,bas->leafbufdisp));
  PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->leafbufdisp_d));
  PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->leafsigdisp_d));
  PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->iranks_d));
  PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->ioffset_d));
  PetscFunctionReturn(0);
}
72 
/* Set up NVSHMEM related fields for an SF of type SFBASIC (only after PetscSFSetUp_Basic() has already set up the dependent fields).

   Roots and leaves exchange, in four rounds of MPI messages, the displacements at which each side may
   access the other's signals and buffers in the NVSHMEM symmetric heap; the resulting host arrays are
   then mirrored to the device so NVSHMEM kernels can use them.
*/
static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
  PetscInt       i,nRemoteRootRanks,nRemoteLeafRanks;
  PetscMPIInt    tag;
  MPI_Comm       comm;
  MPI_Request    *rootreqs,*leafreqs;
  PetscInt       tmp,stmp[4],rtmp[4]; /* tmps for send/recv buffers */

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)sf,&comm));
  PetscCall(PetscObjectGetNewTag((PetscObject)sf,&tag));

  nRemoteRootRanks      = sf->nranks-sf->ndranks;
  nRemoteLeafRanks      = bas->niranks-bas->ndiranks;
  sf->nRemoteRootRanks  = nRemoteRootRanks;
  bas->nRemoteLeafRanks = nRemoteLeafRanks;

  PetscCall(PetscMalloc2(nRemoteLeafRanks,&rootreqs,nRemoteRootRanks,&leafreqs));

  /* Compute global maxima of the rank counts and remote buffer lengths */
  stmp[0] = nRemoteRootRanks;
  stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
  stmp[2] = nRemoteLeafRanks;
  stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];

  PetscCall(MPIU_Allreduce(stmp,rtmp,4,MPIU_INT,MPI_MAX,comm));

  sf->nRemoteRootRanksMax   = rtmp[0];
  sf->leafbuflen_rmax       = rtmp[1];
  bas->nRemoteLeafRanksMax  = rtmp[2];
  bas->rootbuflen_rmax      = rtmp[3];

  /* Total four rounds of MPI communications to set up the nvshmem fields */

  /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
  PetscCall(PetscMalloc2(nRemoteRootRanks,&sf->rootsigdisp,nRemoteRootRanks,&sf->rootbufdisp));
  for (i=0; i<nRemoteRootRanks; i++) PetscCallMPI(MPI_Irecv(&sf->rootsigdisp[i],1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm,&leafreqs[i])); /* Leaves recv */
  for (i=0; i<nRemoteLeafRanks; i++) PetscCallMPI(MPI_Send(&i,1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm)); /* Roots send. Note i changes, so we use MPI_Send. */
  PetscCallMPI(MPI_Waitall(nRemoteRootRanks,leafreqs,MPI_STATUSES_IGNORE));

  for (i=0; i<nRemoteRootRanks; i++) PetscCallMPI(MPI_Irecv(&sf->rootbufdisp[i],1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm,&leafreqs[i])); /* Leaves recv */
  for (i=0; i<nRemoteLeafRanks; i++) {
    tmp  = bas->ioffset[i+bas->ndiranks] - bas->ioffset[bas->ndiranks];
    PetscCallMPI(MPI_Send(&tmp,1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm));  /* Roots send. Note tmp changes, so we use MPI_Send. */
  }
  PetscCallMPI(MPI_Waitall(nRemoteRootRanks,leafreqs,MPI_STATUSES_IGNORE));

  /* Mirror the root-side arrays to the device, asynchronously on PETSc's default stream */
  PetscCallCUDA(cudaMalloc((void**)&sf->rootbufdisp_d,nRemoteRootRanks*sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void**)&sf->rootsigdisp_d,nRemoteRootRanks*sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void**)&sf->ranks_d,nRemoteRootRanks*sizeof(PetscMPIInt)));
  PetscCallCUDA(cudaMalloc((void**)&sf->roffset_d,(nRemoteRootRanks+1)*sizeof(PetscInt)));

  PetscCallCUDA(cudaMemcpyAsync(sf->rootbufdisp_d,sf->rootbufdisp,nRemoteRootRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->rootsigdisp_d,sf->rootsigdisp,nRemoteRootRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->ranks_d,sf->ranks+sf->ndranks,nRemoteRootRanks*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->roffset_d,sf->roffset+sf->ndranks,(nRemoteRootRanks+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream));

  /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
  PetscCall(PetscMalloc2(nRemoteLeafRanks,&bas->leafsigdisp,nRemoteLeafRanks,&bas->leafbufdisp));
  for (i=0; i<nRemoteLeafRanks; i++) PetscCallMPI(MPI_Irecv(&bas->leafsigdisp[i],1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm,&rootreqs[i]));
  for (i=0; i<nRemoteRootRanks; i++) PetscCallMPI(MPI_Send(&i,1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm));
  PetscCallMPI(MPI_Waitall(nRemoteLeafRanks,rootreqs,MPI_STATUSES_IGNORE));

  for (i=0; i<nRemoteLeafRanks; i++) PetscCallMPI(MPI_Irecv(&bas->leafbufdisp[i],1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm,&rootreqs[i]));
  for (i=0; i<nRemoteRootRanks; i++) {
    tmp  = sf->roffset[i+sf->ndranks] - sf->roffset[sf->ndranks];
    PetscCallMPI(MPI_Send(&tmp,1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm));
  }
  PetscCallMPI(MPI_Waitall(nRemoteLeafRanks,rootreqs,MPI_STATUSES_IGNORE));

  /* Mirror the leaf-side arrays to the device */
  PetscCallCUDA(cudaMalloc((void**)&bas->leafbufdisp_d,nRemoteLeafRanks*sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void**)&bas->leafsigdisp_d,nRemoteLeafRanks*sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void**)&bas->iranks_d,nRemoteLeafRanks*sizeof(PetscMPIInt)));
  PetscCallCUDA(cudaMalloc((void**)&bas->ioffset_d,(nRemoteLeafRanks+1)*sizeof(PetscInt)));

  PetscCallCUDA(cudaMemcpyAsync(bas->leafbufdisp_d,bas->leafbufdisp,nRemoteLeafRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->leafsigdisp_d,bas->leafsigdisp,nRemoteLeafRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->iranks_d,bas->iranks+bas->ndiranks,nRemoteLeafRanks*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->ioffset_d,bas->ioffset+bas->ndiranks,(nRemoteLeafRanks+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream));

  PetscCall(PetscFree2(rootreqs,leafreqs));
  PetscFunctionReturn(0);
}
158 
/* Check whether this SF and the given root/leaf data can use NVSHMEM for communication.

   Collective on the SF's communicator. The returned <use_nvshmem> must be the same over comm, since
   an SFLink must be collectively either NVSHMEM or MPI. NVSHMEM fields of the SF are set up on demand
   the first time NVSHMEM is actually usable.
*/
PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,PetscBool *use_nvshmem)
{
  MPI_Comm         comm;
  PetscBool        isBasic;
  PetscMPIInt      result = MPI_UNEQUAL;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)sf,&comm));
  /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
     Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
  */
  if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
    /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
    PetscCall(PetscObjectTypeCompare((PetscObject)sf,PETSCSFBASIC,&isBasic));
    if (isBasic) PetscCallMPI(MPI_Comm_compare(PETSC_COMM_WORLD,comm,&result));
    if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

    /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
       and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
       inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
    */
    if (sf->use_nvshmem) {
      PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
      PetscCall(MPIU_Allreduce(MPI_IN_PLACE,&hasNullRank,1,MPIU_INT,MPI_LOR,comm));
      if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
    }
    sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
  }

  /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
  if (sf->use_nvshmem) {
    PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
    PetscInt allCuda = oneCuda; /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
   #if defined(PETSC_USE_DEBUG)  /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
    PetscCallMPI(MPI_Allreduce(&oneCuda,&allCuda,1,MPIU_INT,MPI_LAND,comm));
    PetscCheckFalse(allCuda != oneCuda,comm,PETSC_ERR_SUP,"root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
   #endif
    if (allCuda) {
      PetscCall(PetscNvshmemInitializeCheck());
      if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
        PetscCall(PetscSFSetUp_Basic_NVSHMEM(sf));
        sf->setup_nvshmem = PETSC_TRUE;
      }
      *use_nvshmem = PETSC_TRUE;
    } else {
      *use_nvshmem = PETSC_FALSE;
    }
  } else {
    *use_nvshmem = PETSC_FALSE;
  }
  PetscFunctionReturn(0);
}
212 
/* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication:
   make remoteCommStream wait until all work queued so far on link->stream (e.g., packing) is done. */
static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
  /* Length of the send-side remote buffer: root buf for ROOT2LEAF, leaf buf for LEAF2ROOT */
  PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF)? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE];

  PetscFunctionBegin;
  if (buflen) { /* nothing to send => no dependence needed */
    PetscCallCUDA(cudaEventRecord(link->dataReady,link->stream));
    PetscCallCUDA(cudaStreamWaitEvent(link->remoteCommStream,link->dataReady,0));
  }
  PetscFunctionReturn(0);
}
227 
/* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication:
   make link->stream wait until the remote communication queued on remoteCommStream is done. */
static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
  /* Length of the receive-side remote buffer: leaf buf for ROOT2LEAF, root buf for LEAF2ROOT */
  PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF)? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE];

  PetscFunctionBegin;
  /* If unpack to non-null device buffer, build the endRemoteComm dependance */
  if (buflen) {
    PetscCallCUDA(cudaEventRecord(link->endRemoteComm,link->remoteCommStream));
    PetscCallCUDA(cudaStreamWaitEvent(link->stream,link->endRemoteComm,0));
  }
  PetscFunctionReturn(0);
}
243 
/* Send/Put signals to remote ranks; launched with at least <n> threads

 Input parameters:
  + n        - Number of remote ranks
  . sig      - Signal address in symmetric heap
  . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
  . ranks    - remote ranks
  - newval   - Set signals to this value
*/
__global__ static void NvshmemSendSignals(PetscInt n,uint64_t *sig,PetscInt *sigdisp,PetscMPIInt *ranks,uint64_t newval)
{
  const int tid = blockIdx.x*blockDim.x + threadIdx.x;

  if (tid < n) nvshmemx_uint64_signal(sig+sigdisp[tid],newval,ranks[tid]); /* one remote signal per thread */
}
260 
/* Wait until local signals equal to the expected value and then set them to a new value.
   Intended to be launched with a single thread (<<<1,1>>>); see the note in the #else branch.

 Input parameters:
  + n        - Number of signals
  . sig      - Local signal address
  . expval   - expected value
  - newval   - Set signals to this new value
*/
__global__ static void NvshmemWaitSignals(PetscInt n,uint64_t *sig,uint64_t expval,uint64_t newval)
{
#if 0
  /* Alternative: one thread per signal. Kept for reference but disabled. */
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if (i < n) {
    nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
    sig[i] = newval;
  }
#else
  /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
  nvshmem_uint64_wait_until_all(sig,n,NULL/*no mask*/,NVSHMEM_CMP_EQ,expval);
  for (int i=0; i<n; i++) sig[i] = newval;
#endif
}
283 
284 /* ===========================================================================================================
285 
286    A set of routines to support receiver initiated communication using the get method
287 
288     The getting protocol is:
289 
290     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
291     All signal variables have an initial value 0.
292 
293     Sender:                                 |  Receiver:
  1.  Wait ssig to be 0, then set it to 1   |
  2.  Pack data into stand alone sbuf       |
296   3.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it 0
297                                             |   2. Get data from remote sbuf to local rbuf
298                                             |   3. Put 1 to sender's ssig
299                                             |   4. Unpack data from local rbuf
300    ===========================================================================================================*/
/* PrePack operation -- since the sender will overwrite the send buffer which receivers might still be
   getting data from, the sender waits for signals (from receivers) indicating they have finished getting.
*/
PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic *bas      = (PetscSF_Basic*)sf->data;
  /* In ROOT2LEAF, leaf ranks get from my root buf, so I watch my rootSendSig; symmetric in LEAF2ROOT */
  uint64_t      *sendsig  = (direction == PETSCSF_ROOT2LEAF) ? link->rootSendSig : link->leafSendSig;
  PetscInt      ngetters  = (direction == PETSCSF_ROOT2LEAF) ? bas->nRemoteLeafRanks : sf->nRemoteRootRanks;

  PetscFunctionBegin;
  if (ngetters) {
    /* Wait for the signals to become 0 (getters are done), then set them to 1 (buffer in use again) */
    NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(ngetters,sendsig,0,1);
    PetscCallCUDA(cudaGetLastError());
  }
  PetscFunctionReturn(0);
}
325 
/* n thread blocks; each takes charge of one remote rank. Blocks whose rank IS locally accessible do
   nothing -- those ranks are served on the host with nvshmemx_getmem_nbi_on_stream(). */
__global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks,PetscMPIInt *srcranks,const char *src,PetscInt *srcdisp,char *dst,PetscInt *dstdisp,PetscInt unitbytes)
{
  const int         b  = blockIdx.x;
  const PetscMPIInt pe = srcranks[b];

  if (nvshmem_ptr(src,pe)) return; /* locally accessible: handled elsewhere */
  nvshmem_getmem_nbi(dst+(dstdisp[b]-dstdisp[0])*unitbytes,src+srcdisp[b]*unitbytes,(dstdisp[b+1]-dstdisp[b])*unitbytes,pe);
}
337 
/* Start communication -- Get data in the given direction.

   Each rank first grants getting permission to the ranks that get from it, then waits for permission
   from the ranks it gets from, and finally starts nonblocking gets: a device kernel for remotely
   accessible PEs, and the host stream API for locally accessible PEs. The gets are finished in
   PetscSFLinkGetDataEnd_NVSHMEM().
*/
PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;

  PetscInt          nsrcranks,ndstranks,nLocallyAccessible = 0;

  char              *src,*dst;
  PetscInt          *srcdisp_h,*dstdisp_h;
  PetscInt          *srcdisp_d,*dstdisp_d;
  PetscMPIInt       *srcranks_h;
  PetscMPIInt       *srcranks_d,*dstranks_d;
  uint64_t          *dstsig;
  PetscInt          *dstsigdisp_d;

  PetscFunctionBegin;
  PetscCall(PetscSFLinkBuildDependenceBegin(sf,link,direction));
  if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
    nsrcranks    = sf->nRemoteRootRanks;
    src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */

    srcdisp_h    = sf->rootbufdisp;       /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
    srcdisp_d    = sf->rootbufdisp_d;
    srcranks_h   = sf->ranks+sf->ndranks; /* my (remote) root ranks */
    srcranks_d   = sf->ranks_d;

    ndstranks    = bas->nRemoteLeafRanks;
    dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */

    dstdisp_h    = sf->roffset+sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_d    = sf->roffset_d;
    dstranks_d   = bas->iranks_d; /* my (remote) leaf ranks */

    dstsig       = link->leafRecvSig;
    dstsigdisp_d = bas->leafsigdisp_d;
  } else { /* src is leaf, dst is root; we will move data from src to dst */
    nsrcranks    = bas->nRemoteLeafRanks;
    src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */

    srcdisp_h    = bas->leafbufdisp;          /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
    srcdisp_d    = bas->leafbufdisp_d;
    srcranks_h   = bas->iranks+bas->ndiranks; /* my (remote) leaf ranks */
    srcranks_d   = bas->iranks_d;

    ndstranks    = sf->nRemoteRootRanks;
    dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */

    dstdisp_h    = bas->ioffset+bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_d    = bas->ioffset_d;
    dstranks_d   = sf->ranks_d; /* my (remote) root ranks */

    dstsig       = link->rootRecvSig;
    dstsigdisp_d = sf->rootsigdisp_d;
  }

  /* After Pack operation -- src tells dst ranks that they are allowed to get data */
  if (ndstranks) {
    NvshmemSendSignals<<<(ndstranks+255)/256,256,0,link->remoteCommStream>>>(ndstranks,dstsig,dstsigdisp_d,dstranks_d,1); /* set signals to 1 */
    PetscCallCUDA(cudaGetLastError());
  }

  /* dst waits for signals (permissions) from src ranks to start getting data */
  if (nsrcranks) {
    NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(nsrcranks,dstsig,1,0); /* wait the signals to be 1, then set them to 0 */
    PetscCallCUDA(cudaGetLastError());
  }

  /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

  /* Count number of locally accessible src ranks, which should be a small number */
  for (int i=0; i<nsrcranks; i++) {if (nvshmem_ptr(src,srcranks_h[i])) nLocallyAccessible++;}

  /* Get data from remotely accessible PEs */
  if (nLocallyAccessible < nsrcranks) {
    GetDataFromRemotelyAccessible<<<nsrcranks,1,0,link->remoteCommStream>>>(nsrcranks,srcranks_d,src,srcdisp_d,dst,dstdisp_d,link->unitbytes);
    PetscCallCUDA(cudaGetLastError());
  }

  /* Get data from locally accessible PEs with the host API, which uses CUDA copy-engines */
  if (nLocallyAccessible) {
    for (int i=0; i<nsrcranks; i++) {
      int pe = srcranks_h[i];
      if (nvshmem_ptr(src,pe)) {
        size_t nelems = (dstdisp_h[i+1]-dstdisp_h[i])*link->unitbytes;
        nvshmemx_getmem_nbi_on_stream(dst+(dstdisp_h[i]-dstdisp_h[0])*link->unitbytes,src+srcdisp_h[i]*link->unitbytes,nelems,pe,link->remoteCommStream);
      }
    }
  }
  PetscFunctionReturn(0);
}
429 
/* Finish the Get communication (can be done before Unpack).
   The receiver quiets its nonblocking gets, then tells its senders (by zeroing their send signals)
   that they are allowed to reuse their send buffers.
*/
PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  uint64_t          *srcsig;
  PetscInt          nsrcranks,*srcsigdisp;
  PetscMPIInt       *srcranks;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
    nsrcranks   = sf->nRemoteRootRanks;
    srcsig      = link->rootSendSig;     /* I want to set their root signal */
    srcsigdisp  = sf->rootsigdisp_d;     /* offset of each root signal */
    srcranks    = sf->ranks_d;           /* ranks of the n root ranks */
  } else { /* LEAF2ROOT, root ranks are getting data */
    nsrcranks   = bas->nRemoteLeafRanks;
    srcsig      = link->leafSendSig;
    srcsigdisp  = bas->leafsigdisp_d;
    srcranks    = bas->iranks_d;
  }

  if (nsrcranks) {
    nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking gets, so that we can unpack afterwards */
    PetscCallCUDA(cudaGetLastError());
    NvshmemSendSignals<<<(nsrcranks+511)/512,512,0,link->remoteCommStream>>>(nsrcranks,srcsig,srcsigdisp,srcranks,0); /* set signals to 0 */
    PetscCallCUDA(cudaGetLastError());
  }
  PetscCall(PetscSFLinkBuildDependenceEnd(sf,link,direction));
  PetscFunctionReturn(0);
}
463 
464 /* ===========================================================================================================
465 
466    A set of routines to support sender initiated communication using the put-based method (the default)
467 
468     The putting protocol is:
469 
470     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
471     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
472     is in nvshmem space.
473 
474     Sender:                                 |  Receiver:
475                                             |
476   1.  Pack data into sbuf                   |
477   2.  Wait ssig be 0, then set it to 1      |
478   3.  Put data to remote stand-alone rbuf   |
479   4.  Fence // make sure 5 happens after 3  |
480   5.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it 0
481                                             |   2. Unpack data from local rbuf
482                                             |   3. Put 0 to sender's ssig
483    ===========================================================================================================*/
484 
/* n thread blocks; each takes charge of one remote rank. For a rank that is NOT locally accessible,
   wait for permission (its send-signal entry becomes 0), claim the buffer portion (set the entry to 1),
   and start a nonblocking put. Locally accessible ranks are served on the host with stream puts. */
__global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,char *dst,PetscInt *dstdisp,const char *src,PetscInt *srcdisp,uint64_t *srcsig,PetscInt unitbytes)
{
  const int         b  = blockIdx.x;
  const PetscMPIInt pe = dstranks[b];

  if (nvshmem_ptr(dst,pe)) return; /* locally accessible: handled elsewhere */
  nvshmem_uint64_wait_until(srcsig+b,NVSHMEM_CMP_EQ,0); /* Wait until the sig = 0 */
  srcsig[b] = 1;
  nvshmem_putmem_nbi(dst+dstdisp[b]*unitbytes,src+(srcdisp[b]-srcdisp[0])*unitbytes,(srcdisp[b+1]-srcdisp[b])*unitbytes,pe);
}
498 
/* One-thread kernel that takes charge of all locally accessible destination ranks: wait for each
   rank's send-signal entry to become 0, then set it to 1 to claim that portion of the send buffer */
__global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *srcsig,const char *dst)
{
  for (int i=0; i<ndstranks; i++) {
    if (!nvshmem_ptr(dst,dstranks[i])) continue; /* remote ranks are served by WaitAndPutDataToRemotelyAccessible() */
    nvshmem_uint64_wait_until(srcsig+i,NVSHMEM_CMP_EQ,0); /* Wait until the sig = 0 */
    srcsig[i] = 1;
  }
}
510 
/* Put data in the given direction.

   The sender waits for permission signals and then puts data to its destination ranks with
   nonblocking puts: a device kernel for remotely accessible PEs, and the host stream API (which uses
   CUDA copy-engines and is much faster) for locally accessible PEs. The puts are finished in
   PetscSFLinkPutDataEnd_NVSHMEM().
*/
PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscInt          ndstranks,nLocallyAccessible = 0;
  char              *src,*dst;
  PetscInt          *srcdisp_h,*dstdisp_h;
  PetscInt          *srcdisp_d,*dstdisp_d;
  PetscMPIInt       *dstranks_h;
  PetscMPIInt       *dstranks_d;
  uint64_t          *srcsig;

  PetscFunctionBegin;
  PetscCall(PetscSFLinkBuildDependenceBegin(sf,link,direction));
  if (direction == PETSCSF_ROOT2LEAF) { /* put data in rootbuf to leafbuf  */
    ndstranks    = bas->nRemoteLeafRanks; /* number of (remote) leaf ranks */
    src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
    dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h    = bas->ioffset+bas->ndiranks;  /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
    srcdisp_d    = bas->ioffset_d;
    srcsig       = link->rootSendSig;

    dstdisp_h    = bas->leafbufdisp;            /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
    dstdisp_d    = bas->leafbufdisp_d;
    dstranks_h   = bas->iranks+bas->ndiranks;   /* remote leaf ranks */
    dstranks_d   = bas->iranks_d;
  } else { /* put data in leafbuf to rootbuf */
    ndstranks    = sf->nRemoteRootRanks;
    src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h    = sf->roffset+sf->ndranks; /* offsets of leafbuf */
    srcdisp_d    = sf->roffset_d;
    srcsig       = link->leafSendSig;

    dstdisp_h    = sf->rootbufdisp;         /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
    dstdisp_d    = sf->rootbufdisp_d;
    dstranks_h   = sf->ranks+sf->ndranks;   /* remote root ranks */
    dstranks_d   = sf->ranks_d;
  }

  /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */

  /* Count number of locally accessible neighbors, which should be a small number */
  for (int i=0; i<ndstranks; i++) {if (nvshmem_ptr(dst,dstranks_h[i])) nLocallyAccessible++;}

  /* For remotely accessible PEs, send data to them in one kernel call */
  if (nLocallyAccessible < ndstranks) {
    WaitAndPutDataToRemotelyAccessible<<<ndstranks,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,dst,dstdisp_d,src,srcdisp_d,srcsig,link->unitbytes);
    PetscCallCUDA(cudaGetLastError());
  }

  /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
  if (nLocallyAccessible) {
    WaitSignalsFromLocallyAccessible<<<1,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,srcsig,dst);
    for (int i=0; i<ndstranks; i++) {
      int pe = dstranks_h[i];
      if (nvshmem_ptr(dst,pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
        size_t nelems = (srcdisp_h[i+1]-srcdisp_h[i])*link->unitbytes;
         /* Initiate the nonblocking communication */
        nvshmemx_putmem_nbi_on_stream(dst+dstdisp_h[i]*link->unitbytes,src+(srcdisp_h[i]-srcdisp_h[0])*link->unitbytes,nelems,pe,link->remoteCommStream);
      }
    }
  }

  if (nLocallyAccessible) {
    nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
  }
  PetscFunctionReturn(0);
}
583 
/* A one-thread kernel that finishes a put round: raise a signal at every destination rank, then wait
   for signals from all source ranks. dstsig lives in the symmetric heap, so the same address serves
   both the remote notifications and the local waits. */
__global__ static void PutDataEnd(PetscInt nsrcranks,PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *dstsig,PetscInt *dstsigdisp)
{
  int i;

  /* TODO: should we also finish (quiet) the non-blocking remote puts here? */

  /* 1. Send a signal (set to 1) to each dst rank.
     According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
     For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
  */
  for (i=0; i<ndstranks; i++) nvshmemx_uint64_signal(dstsig+dstsigdisp[i],1,dstranks[i]);

  /* 2. Wait for signals from src ranks (if any): wait sigs to be 1, then reset them to 0 */
  if (nsrcranks) {
    nvshmem_uint64_wait_until_all(dstsig,nsrcranks,NULL/*no mask*/,NVSHMEM_CMP_EQ,1);
    for (i=0; i<nsrcranks; i++) dstsig[i] = 0;
  }
}
602 
/* Finish the Put communication -- a receiver waits until it can access its receive buffer
   (i.e., all its senders have put their data and signalled), while notifying its own receivers. */
PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscMPIInt       *dstranks;
  uint64_t          *dstsig;
  PetscInt          nsrcranks,ndstranks,*dstsigdisp;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
    nsrcranks    = sf->nRemoteRootRanks;

    ndstranks    = bas->nRemoteLeafRanks;
    dstranks     = bas->iranks_d;       /* leaf ranks */
    dstsig       = link->leafRecvSig;   /* I will set my leaf ranks's RecvSig */
    dstsigdisp   = bas->leafsigdisp_d;  /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
  } else { /* LEAF2ROOT */
    nsrcranks    = bas->nRemoteLeafRanks;

    ndstranks    = sf->nRemoteRootRanks;
    dstranks     = sf->ranks_d;
    dstsig       = link->rootRecvSig;
    dstsigdisp   = sf->rootsigdisp_d;
  }

  if (nsrcranks || ndstranks) {
    PutDataEnd<<<1,1,0,link->remoteCommStream>>>(nsrcranks,ndstranks,dstranks,dstsig,dstsigdisp);
    PetscCallCUDA(cudaGetLastError());
  }
  PetscCall(PetscSFLinkBuildDependenceEnd(sf,link,direction));
  PetscFunctionReturn(0);
}
636 
/* PostUnpack operation -- A receiver tells its senders that they are allowed to put data to here (it implies recv buf is free to take new data) */
PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas       = (PetscSF_Basic*)sf->data;
  PetscBool         root2leaf  = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  uint64_t          *sig;
  PetscInt          n,*sigdisp;
  PetscMPIInt       *ranks;

  PetscFunctionBegin;
  /* Select the sender side: in ROOT2LEAF my senders are the remote root ranks,
     in LEAF2ROOT they are the remote leaf ranks. For each sender we need its
     rank, its SendSig array, and the offset of my signal within that array. */
  n       = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  sig     = root2leaf ? link->rootSendSig    : link->leafSendSig;
  sigdisp = root2leaf ? sf->rootsigdisp_d    : bas->leafsigdisp_d;
  ranks   = root2leaf ? sf->ranks_d          : bas->iranks_d;

  if (n) {
    /* Reset each sender's signal to 0 on the remote-communication stream, one thread per sender */
    NvshmemSendSignals<<<(n+255)/256,256,0,link->remoteCommStream>>>(n,sig,sigdisp,ranks,0);
    PetscCallCUDA(cudaGetLastError());
  }
  PetscFunctionReturn(0);
}
664 
/* Destructor when the link uses nvshmem for communication

   Releases the CUDA events/stream and the NVSHMEM symmetric-heap buffers and
   signal arrays created for this link in PetscSFLinkCreate_NVSHMEM().
*/
static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf,PetscSFLink link)
{
  PetscFunctionBegin;
  PetscCallCUDA(cudaEventDestroy(link->dataReady));
  PetscCallCUDA(cudaEventDestroy(link->endRemoteComm));
  PetscCallCUDA(cudaStreamDestroy(link->remoteCommStream));

  /* nvshmem does not need buffers on host, which should be NULL */
  PetscCall(PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
  PetscCall(PetscNvshmemFree(link->leafSendSig));
  PetscCall(PetscNvshmemFree(link->leafRecvSig));
  PetscCall(PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
  PetscCall(PetscNvshmemFree(link->rootSendSig));
  PetscCall(PetscNvshmemFree(link->rootRecvSig));
  PetscFunctionReturn(0);
}
684 
/* Create (or fetch from the cache) an SF communication link that uses NVSHMEM.

   Decides whether root/leafdata can be used directly as remote send/recv buffers
   (depending on sfop, op, memtype and contiguity), reuses a cached link when the
   MPI datatype matches, otherwise builds a new link: pack routines, zeroed NVSHMEM
   signal arrays, a high-priority non-blocking stream and the events used to order
   remote communication against PETSc's compute stream.
*/
PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf,MPI_Datatype unit,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,MPI_Op op,PetscSFOperation sfop,PetscSFLink *mylink)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscSFLink       *p,link;
  PetscBool         match,rootdirect[2],leafdirect[2];
  int               greatestPriority;

  PetscFunctionBegin;
  /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
     We only care root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers in local communication with NVSHMEM.
  */
  if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;  /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
    }
  } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
    }
  } else { /* PETSCSF_FETCH */
    rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always need a separate rootbuf */
    leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share mpi requests */
  }

  /* Look for free nvshmem links in cache */
  for (p=&bas->avail; (link=*p); p=&link->next) {
    if (link->use_nvshmem) {
      PetscCall(MPIPetsc_Type_compare(unit,link->unit,&match));
      if (match) {
        *p = link->next; /* Remove from available list */
        goto found;
      }
    }
  }
  PetscCall(PetscNew(&link));
  PetscCall(PetscSFLinkSetUp_Host(sf,link,unit)); /* Compute link->unitbytes, dup link->unit etc. */
  if (sf->backend == PETSCSF_BACKEND_CUDA) PetscCall(PetscSFLinkSetUp_CUDA(sf,link,unit)); /* Setup pack routines, streams etc */
 #if defined(PETSC_HAVE_KOKKOS)
  else if (sf->backend == PETSCSF_BACKEND_KOKKOS) PetscCall(PetscSFLinkSetUp_Kokkos(sf,link,unit));
 #endif

  link->rootdirect[PETSCSF_LOCAL]  = PETSC_TRUE; /* For the local part we directly use root/leafdata */
  link->leafdirect[PETSCSF_LOCAL]  = PETSC_TRUE;

  /* Init signals to zero */
  if (!link->rootSendSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootSendSig));
  if (!link->rootRecvSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootRecvSig));
  if (!link->leafSendSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafSendSig));
  if (!link->leafRecvSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafRecvSig));

  link->use_nvshmem                = PETSC_TRUE;
  link->rootmtype                  = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
  link->leafmtype                  = PETSC_MEMTYPE_DEVICE;
  /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
  link->Destroy                    = PetscSFLinkDestroy_NVSHMEM;
  if (sf->use_nvshmem_get) { /* get-based protocol */
    link->PrePack                  = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
    link->StartCommunication       = PetscSFLinkGetDataBegin_NVSHMEM;
    link->FinishCommunication      = PetscSFLinkGetDataEnd_NVSHMEM;
  } else { /* put-based protocol */
    link->StartCommunication       = PetscSFLinkPutDataBegin_NVSHMEM;
    link->FinishCommunication      = PetscSFLinkPutDataEnd_NVSHMEM;
    link->PostUnpack               = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
  }

  /* Remote communication runs on its own non-blocking stream at the highest priority,
     so it can overlap with (and not be serialized behind) local compute work */
  PetscCallCUDA(cudaDeviceGetStreamPriorityRange(NULL,&greatestPriority));
  PetscCallCUDA(cudaStreamCreateWithPriority(&link->remoteCommStream,cudaStreamNonBlocking,greatestPriority));

  PetscCallCUDA(cudaEventCreateWithFlags(&link->dataReady,cudaEventDisableTiming));
  PetscCallCUDA(cudaEventCreateWithFlags(&link->endRemoteComm,cudaEventDisableTiming));

found:
  if (rootdirect[PETSCSF_REMOTE]) {
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)rootdata + bas->rootstart[PETSCSF_REMOTE]*link->unitbytes;
  } else {
    if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
      PetscCall(PetscNvshmemMalloc(bas->rootbuflen_rmax*link->unitbytes,(void**)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
    }
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  if (leafdirect[PETSCSF_REMOTE]) {
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)leafdata + sf->leafstart[PETSCSF_REMOTE]*link->unitbytes;
  } else {
    if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
      PetscCall(PetscNvshmemMalloc(sf->leafbuflen_rmax*link->unitbytes,(void**)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
    }
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
  link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
  link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
  link->leafdata                   = leafdata;
  link->next                       = bas->inuse;
  bas->inuse                       = link;
  *mylink                          = link;
  PetscFunctionReturn(0);
}
793 
#if defined(PETSC_USE_REAL_SINGLE)
/* Element-wise sum reduction of src into dst over all PEs (NVSHMEM_TEAM_WORLD),
   enqueued on PETSc's default CUDA stream. Precision follows PetscReal. */
PetscErrorCode PetscNvshmemSum(PetscInt count,float *dst,const float *src)
{
  PetscMPIInt       n; /* NVSHMEM takes an int count; assume nvshmem's int is MPI's int and guard against PetscInt overflow */

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count,&n));
  nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,n,PetscDefaultCudaStream);
  PetscFunctionReturn(0);
}

/* Element-wise max reduction of src into dst over all PEs, on the default CUDA stream */
PetscErrorCode PetscNvshmemMax(PetscInt count,float *dst,const float *src)
{
  PetscMPIInt       n;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count,&n));
  nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,n,PetscDefaultCudaStream);
  PetscFunctionReturn(0);
}
#elif defined(PETSC_USE_REAL_DOUBLE)
/* Element-wise sum reduction of src into dst over all PEs, on the default CUDA stream */
PetscErrorCode PetscNvshmemSum(PetscInt count,double *dst,const double *src)
{
  PetscMPIInt       n;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count,&n));
  nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,n,PetscDefaultCudaStream);
  PetscFunctionReturn(0);
}

/* Element-wise max reduction of src into dst over all PEs, on the default CUDA stream */
PetscErrorCode PetscNvshmemMax(PetscInt count,double *dst,const double *src)
{
  PetscMPIInt       n;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count,&n));
  nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,n,PetscDefaultCudaStream);
  PetscFunctionReturn(0);
}
#endif
835