xref: /petsc/src/vec/is/sf/impls/basic/nvshmem/sfnvshmem.cu (revision 2ff79c18c26c94ed8cb599682f680f231dca6444)
1 #include <petsc/private/cudavecimpl.h>
2 #include <../src/vec/is/sf/impls/basic/sfpack.h>
3 #include <mpi.h>
4 #include <nvshmem.h>
5 #include <nvshmemx.h>
6 
7 PetscErrorCode PetscNvshmemInitializeCheck(void)
8 {
9   PetscFunctionBegin;
10   if (!PetscNvshmemInitialized) { /* Note NVSHMEM does not provide a routine to check whether it is initialized */
11     nvshmemx_init_attr_t attr;
12     attr.mpi_comm = &PETSC_COMM_WORLD;
13     PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
14     PetscCall(nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr));
15     PetscNvshmemInitialized = PETSC_TRUE;
16     PetscBeganNvshmem       = PETSC_TRUE;
17   }
18   PetscFunctionReturn(PETSC_SUCCESS);
19 }
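/* NVSHMEM is brought up lazily: PetscNvshmemMalloc()/PetscNvshmemCalloc() below call this check before
   touching the symmetric heap, so nothing is initialized until the symmetric heap is actually needed.
   PetscBeganNvshmem records that PETSc (and not the user) initialized NVSHMEM, so PETSc knows it should
   call nvshmem_finalize() (via PetscNvshmemFinalize() below) during its own finalization. */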
20 
21 PetscErrorCode PetscNvshmemMalloc(size_t size, void **ptr)
22 {
23   PetscFunctionBegin;
24   PetscCall(PetscNvshmemInitializeCheck());
25   *ptr = nvshmem_malloc(size);
26   PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_MEM, "nvshmem_malloc() failed to allocate %zu bytes", size);
27   PetscFunctionReturn(PETSC_SUCCESS);
28 }
29 
30 PetscErrorCode PetscNvshmemCalloc(size_t size, void **ptr)
31 {
32   PetscFunctionBegin;
33   PetscCall(PetscNvshmemInitializeCheck());
34   *ptr = nvshmem_calloc(size, 1);
35   PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_MEM, "nvshmem_calloc() failed to allocate %zu bytes", size);
36   PetscFunctionReturn(PETSC_SUCCESS);
37 }
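/* Both wrappers return memory on the NVSHMEM symmetric heap. nvshmem_malloc()/nvshmem_calloc() are
   collective over all PEs and the heap must stay symmetric, so every rank is expected to request the same
   size; this is why the SF code below sizes link buffers and signal arrays with the MPI_MAX-reduced
   quantities (leafbuflen_rmax, rootbuflen_rmax, nRemoteRootRanksMax, nRemoteLeafRanksMax) computed in
   PetscSFSetUp_Basic_NVSHMEM() rather than with the local lengths. A minimal usage sketch (names illustrative):

     void *sbuf;
     PetscCall(PetscNvshmemMalloc(maxlen * unitbytes, &sbuf)); // same maxlen on every rank
     ...
     PetscCall(PetscNvshmemFree(sbuf));                        // collective as well
*/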
38 
39 PetscErrorCode PetscNvshmemFree_Private(void *ptr)
40 {
41   PetscFunctionBegin;
42   nvshmem_free(ptr);
43   PetscFunctionReturn(PETSC_SUCCESS);
44 }
45 
46 PetscErrorCode PetscNvshmemFinalize(void)
47 {
48   PetscFunctionBegin;
49   nvshmem_finalize();
50   PetscFunctionReturn(PETSC_SUCCESS);
51 }
52 
53 /* Free nvshmem related fields in the SF */
54 PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
55 {
56   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
57 
58   PetscFunctionBegin;
59   PetscCall(PetscFree2(bas->leafsigdisp, bas->leafbufdisp));
60   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafbufdisp_d));
61   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafsigdisp_d));
62   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->iranks_d));
63   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->ioffset_d));
64 
65   PetscCall(PetscFree2(sf->rootsigdisp, sf->rootbufdisp));
66   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootbufdisp_d));
67   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootsigdisp_d));
68   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->ranks_d));
69   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->roffset_d));
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
73 /* Set up NVSHMEM related fields for an SF of type SFBASIC (only after PetscSFSetUp_Basic() has already set up the fields they depend on) */
74 static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
75 {
77   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
78   PetscInt       i, nRemoteRootRanks, nRemoteLeafRanks;
79   PetscMPIInt    tag;
80   MPI_Comm       comm;
81   MPI_Request   *rootreqs, *leafreqs;
82   PetscInt       tmp, stmp[4], rtmp[4]; /* tmps for send/recv buffers */
83 
84   PetscFunctionBegin;
85   PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
86   PetscCall(PetscObjectGetNewTag((PetscObject)sf, &tag));
87 
88   nRemoteRootRanks      = sf->nranks - sf->ndranks;
89   nRemoteLeafRanks      = bas->niranks - bas->ndiranks;
90   sf->nRemoteRootRanks  = nRemoteRootRanks;
91   bas->nRemoteLeafRanks = nRemoteLeafRanks;
92 
93   PetscCall(PetscMalloc2(nRemoteLeafRanks, &rootreqs, nRemoteRootRanks, &leafreqs));
94 
95   stmp[0] = nRemoteRootRanks;
96   stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
97   stmp[2] = nRemoteLeafRanks;
98   stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];
99 
100   PetscCallMPI(MPIU_Allreduce(stmp, rtmp, 4, MPIU_INT, MPI_MAX, comm));
101 
102   sf->nRemoteRootRanksMax  = rtmp[0];
103   sf->leafbuflen_rmax      = rtmp[1];
104   bas->nRemoteLeafRanksMax = rtmp[2];
105   bas->rootbuflen_rmax     = rtmp[3];
106 
107   /* Total four rounds of MPI communications to set up the nvshmem fields */
108 
109   /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
110   PetscCall(PetscMalloc2(nRemoteRootRanks, &sf->rootsigdisp, nRemoteRootRanks, &sf->rootbufdisp));
111   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPIU_Irecv(&sf->rootsigdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
112   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm));                              /* Roots send. Note i changes, so we use MPI_Send. */
113   PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));
114 
115   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPIU_Irecv(&sf->rootbufdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
116   for (i = 0; i < nRemoteLeafRanks; i++) {
117     tmp = bas->ioffset[i + bas->ndiranks] - bas->ioffset[bas->ndiranks];
118     PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm)); /* Roots send. Note tmp changes, so we use MPI_Send. */
119   }
120   PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));
121 
122   PetscCallCUDA(cudaMalloc((void **)&sf->rootbufdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
123   PetscCallCUDA(cudaMalloc((void **)&sf->rootsigdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
124   PetscCallCUDA(cudaMalloc((void **)&sf->ranks_d, nRemoteRootRanks * sizeof(PetscMPIInt)));
125   PetscCallCUDA(cudaMalloc((void **)&sf->roffset_d, (nRemoteRootRanks + 1) * sizeof(PetscInt)));
126 
127   PetscCallCUDA(cudaMemcpyAsync(sf->rootbufdisp_d, sf->rootbufdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
128   PetscCallCUDA(cudaMemcpyAsync(sf->rootsigdisp_d, sf->rootsigdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
129   PetscCallCUDA(cudaMemcpyAsync(sf->ranks_d, sf->ranks + sf->ndranks, nRemoteRootRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
130   PetscCallCUDA(cudaMemcpyAsync(sf->roffset_d, sf->roffset + sf->ndranks, (nRemoteRootRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
131 
132   /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
133   PetscCall(PetscMalloc2(nRemoteLeafRanks, &bas->leafsigdisp, nRemoteLeafRanks, &bas->leafbufdisp));
134   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPIU_Irecv(&bas->leafsigdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
135   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
136   PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));
137 
138   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPIU_Irecv(&bas->leafbufdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
139   for (i = 0; i < nRemoteRootRanks; i++) {
140     tmp = sf->roffset[i + sf->ndranks] - sf->roffset[sf->ndranks];
141     PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
142   }
143   PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));
144 
145   PetscCallCUDA(cudaMalloc((void **)&bas->leafbufdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
146   PetscCallCUDA(cudaMalloc((void **)&bas->leafsigdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
147   PetscCallCUDA(cudaMalloc((void **)&bas->iranks_d, nRemoteLeafRanks * sizeof(PetscMPIInt)));
148   PetscCallCUDA(cudaMalloc((void **)&bas->ioffset_d, (nRemoteLeafRanks + 1) * sizeof(PetscInt)));
149 
150   PetscCallCUDA(cudaMemcpyAsync(bas->leafbufdisp_d, bas->leafbufdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
151   PetscCallCUDA(cudaMemcpyAsync(bas->leafsigdisp_d, bas->leafsigdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
152   PetscCallCUDA(cudaMemcpyAsync(bas->iranks_d, bas->iranks + bas->ndiranks, nRemoteLeafRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
153   PetscCallCUDA(cudaMemcpyAsync(bas->ioffset_d, bas->ioffset + bas->ndiranks, (nRemoteLeafRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
154 
155   PetscCall(PetscFree2(rootreqs, leafreqs));
156   PetscFunctionReturn(PETSC_SUCCESS);
157 }
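/* Small worked example (illustrative) of what the exchange above establishes: suppose root rank 1 serves
   two remote leaf ranks {0, 2}, and its remote root buffer holds rank 0's data at [0,4) and rank 2's data
   at [4,9). Rank 1 then sends i=0, tmp=0 to rank 0 and i=1, tmp=4 to rank 2. On rank 2, the entry k of
   sf->ranks that refers to root rank 1 ends up with sf->rootsigdisp[k] = 1 and sf->rootbufdisp[k] = 4,
   i.e. "to signal root rank 1, use slot 1 of its signal array; my data starts at offset 4 (in units) of
   its root buffer". The *_d arrays are device copies of the same data for use inside the kernels below. */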
158 
159 PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, PetscBool *use_nvshmem)
160 {
161   MPI_Comm    comm;
162   PetscBool   isBasic;
163   PetscMPIInt result = MPI_UNEQUAL;
164 
165   PetscFunctionBegin;
166   PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
167   /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
168      Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
169   */
171   if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
172     /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
173     PetscCall(PetscObjectTypeCompare((PetscObject)sf, PETSCSFBASIC, &isBasic));
174     if (isBasic) PetscCallMPI(MPI_Comm_compare(PETSC_COMM_WORLD, comm, &result));
175     if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */
176 
177     /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
178        and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
179        inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
180     */
181     if (sf->use_nvshmem) {
182       PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
183       PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &hasNullRank, 1, MPIU_INT, MPI_LOR, comm));
184       if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
185     }
186     sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
187   }
188 
189   /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
190   if (sf->use_nvshmem) {
191     PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
192     PetscInt allCuda = oneCuda;                                                                                          /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
193 #if defined(PETSC_USE_DEBUG)                                                                                             /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
194     PetscCallMPI(MPIU_Allreduce(&oneCuda, &allCuda, 1, MPIU_INT, MPI_LAND, comm));
195     PetscCheck(allCuda == oneCuda, comm, PETSC_ERR_SUP, "root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
196 #endif
197     if (allCuda) {
198       PetscCall(PetscNvshmemInitializeCheck());
199       if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
200         PetscCall(PetscSFSetUp_Basic_NVSHMEM(sf));
201         sf->setup_nvshmem = PETSC_TRUE;
202       }
203       *use_nvshmem = PETSC_TRUE;
204     } else {
205       *use_nvshmem = PETSC_FALSE;
206     }
207   } else {
208     *use_nvshmem = PETSC_FALSE;
209   }
210   PetscFunctionReturn(PETSC_SUCCESS);
211 }
212 
213 /* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication */
214 static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
215 {
217   PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
218   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE];
219 
220   PetscFunctionBegin;
221   if (buflen) {
222     PetscCallCUDA(cudaEventRecord(link->dataReady, link->stream));
223     PetscCallCUDA(cudaStreamWaitEvent(link->remoteCommStream, link->dataReady, 0));
224   }
225   PetscFunctionReturn(PETSC_SUCCESS);
226 }
227 
228 /* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication */
229 static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
230 {
232   PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
233   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE];
234 
235   PetscFunctionBegin;
236   /* If unpack to non-null device buffer, build the endRemoteComm dependence */
237   if (buflen) {
238     PetscCallCUDA(cudaEventRecord(link->endRemoteComm, link->remoteCommStream));
239     PetscCallCUDA(cudaStreamWaitEvent(link->stream, link->endRemoteComm, 0));
240   }
241   PetscFunctionReturn(PETSC_SUCCESS);
242 }
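/* Together these two helpers implement the standard CUDA two-stream handshake around the NVSHMEM calls
   (sketch, illustrative):

     pack on link->stream
       cudaEventRecord(link->dataReady, link->stream);
       cudaStreamWaitEvent(link->remoteCommStream, link->dataReady, 0);
     NVSHMEM signals/puts/gets on link->remoteCommStream
       cudaEventRecord(link->endRemoteComm, link->remoteCommStream);
       cudaStreamWaitEvent(link->stream, link->endRemoteComm, 0);
     unpack on link->stream

   so communication never reads a buffer before packing has finished, and unpacking never starts before the
   communication stream has completed. */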
243 
244 /* Send/Put signals to remote ranks
245 
246  Input parameters:
247   + n        - Number of remote ranks
248   . sig      - Signal address in symmetric heap
249   . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
250   . ranks    - remote ranks
251   - newval   - Set signals to this value
252 */
253 __global__ static void NvshmemSendSignals(PetscInt n, uint64_t *sig, PetscInt *sigdisp, PetscMPIInt *ranks, uint64_t newval)
254 {
255   int i = blockIdx.x * blockDim.x + threadIdx.x;
256 
257   /* Each thread puts one remote signal */
258   if (i < n) nvshmemx_uint64_signal(sig + sigdisp[i], newval, ranks[i]);
259 }
260 
261 /* Wait until local signals equal to the expected value and then set them to a new value
262 
263  Input parameters:
264   + n        - Number of signals
265   . sig      - Local signal address
266   . expval   - expected value
267   - newval   - Set signals to this new value
268 */
269 __global__ static void NvshmemWaitSignals(PetscInt n, uint64_t *sig, uint64_t expval, uint64_t newval)
270 {
271 #if 0
272   /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
273   int i = blockIdx.x*blockDim.x + threadIdx.x;
274   if (i < n) {
275     nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
276     sig[i] = newval;
277   }
278 #else
279   nvshmem_uint64_wait_until_all(sig, n, NULL /*no mask*/, NVSHMEM_CMP_EQ, expval);
280   for (int i = 0; i < n; i++) sig[i] = newval;
281 #endif
282 }
283 
284 /* ===========================================================================================================
285 
286    A set of routines to support receiver initiated communication using the get method
287 
288     The getting protocol is:
289 
290     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
291     All signal variables have an initial value 0.
292 
293     Sender:                                 |  Receiver:
294   1.  Wait ssig to be 0, then set it to 1   |
295   2.  Pack data into stand-alone sbuf       |
296   3.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it to 0
297                                             |   2. Get data from remote sbuf to local rbuf
298                                             |   3. Put 0 to sender's ssig
299                                             |   4. Unpack data from local rbuf
300    ===========================================================================================================*/
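/* Roughly, the steps above map onto the link callbacks installed in PetscSFLinkCreate_NVSHMEM() (get-based
   protocol) as follows:
     Sender step 1      -> PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM()
     Sender step 2      -> the ordinary SF pack on link->stream
     Sender step 3 and
     Receiver steps 1-2 -> StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM()
     Receiver step 3    -> FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM()
     Receiver step 4    -> the ordinary SF unpack on link->stream
*/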
301 /* PrePack operation -- since sender will overwrite the send buffer which the receiver might be getting data from.
302    Sender waits for signals (from receivers) indicating receivers have finished getting data
303 */
304 static PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
305 {
306   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
307   uint64_t      *sig;
308   PetscInt       n;
309 
310   PetscFunctionBegin;
311   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
312     sig = link->rootSendSig;            /* leaf ranks set my rootSendsig */
313     n   = bas->nRemoteLeafRanks;
314   } else { /* LEAF2ROOT */
315     sig = link->leafSendSig;
316     n   = sf->nRemoteRootRanks;
317   }
318 
319   if (n) {
320     NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(n, sig, 0, 1); /* wait the signals to be 0, then set them to 1 */
321     PetscCallCUDA(cudaGetLastError());
322   }
323   PetscFunctionReturn(PETSC_SUCCESS);
324 }
325 
326 /* n thread blocks. Each takes in charge one remote rank */
327 __global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks, PetscMPIInt *srcranks, const char *src, PetscInt *srcdisp, char *dst, PetscInt *dstdisp, PetscInt unitbytes)
328 {
329   int         bid = blockIdx.x;
330   PetscMPIInt pe  = srcranks[bid];
331 
332   if (!nvshmem_ptr(src, pe)) {
333     PetscInt nelems = (dstdisp[bid + 1] - dstdisp[bid]) * unitbytes;
334     nvshmem_getmem_nbi(dst + (dstdisp[bid] - dstdisp[0]) * unitbytes, src + srcdisp[bid] * unitbytes, nelems, pe);
335   }
336 }
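/* nvshmem_ptr(sym_addr, pe) returns a non-NULL local pointer only when PE <pe>'s symmetric memory can be
   reached by direct loads/stores from this PE (e.g., the same GPU or a peer GPU connected over NVLink).
   Hence the kernel above only issues device-side nvshmem_getmem_nbi() for PEs that are NOT locally
   accessible; locally accessible PEs are handled from the host with nvshmemx_getmem_nbi_on_stream() in
   PetscSFLinkGetDataBegin_NVSHMEM() below, which can exploit the CUDA copy engines. */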
337 
338 /* Start communication -- Get data in the given direction */
339 static PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
340 {
342   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
343 
344   PetscInt nsrcranks, ndstranks, nLocallyAccessible = 0;
345 
346   char        *src, *dst;
347   PetscInt    *srcdisp_h, *dstdisp_h;
348   PetscInt    *srcdisp_d, *dstdisp_d;
349   PetscMPIInt *srcranks_h;
350   PetscMPIInt *srcranks_d, *dstranks_d;
351   uint64_t    *dstsig;
352   PetscInt    *dstsigdisp_d;
353 
354   PetscFunctionBegin;
355   PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
356   if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
357     nsrcranks = sf->nRemoteRootRanks;
358     src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */
359 
360     srcdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
361     srcdisp_d  = sf->rootbufdisp_d;
362     srcranks_h = sf->ranks + sf->ndranks; /* my (remote) root ranks */
363     srcranks_d = sf->ranks_d;
364 
365     ndstranks = bas->nRemoteLeafRanks;
366     dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */
367 
368     dstdisp_h  = sf->roffset + sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
369     dstdisp_d  = sf->roffset_d;
370     dstranks_d = bas->iranks_d; /* my (remote) leaf ranks */
371 
372     dstsig       = link->leafRecvSig;
373     dstsigdisp_d = bas->leafsigdisp_d;
374   } else { /* src is leaf, dst is root; we will move data from src to dst */
375     nsrcranks = bas->nRemoteLeafRanks;
376     src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */
377 
378     srcdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
379     srcdisp_d  = bas->leafbufdisp_d;
380     srcranks_h = bas->iranks + bas->ndiranks; /* my (remote) leaf ranks */
381     srcranks_d = bas->iranks_d;
382 
383     ndstranks = sf->nRemoteRootRanks;
384     dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */
385 
386     dstdisp_h  = bas->ioffset + bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
387     dstdisp_d  = bas->ioffset_d;
388     dstranks_d = sf->ranks_d; /* my (remote) root ranks */
389 
390     dstsig       = link->rootRecvSig;
391     dstsigdisp_d = sf->rootsigdisp_d;
392   }
393 
394   /* After Pack operation -- src tells dst ranks that they are allowed to get data */
395   if (ndstranks) {
396     NvshmemSendSignals<<<(ndstranks + 255) / 256, 256, 0, link->remoteCommStream>>>(ndstranks, dstsig, dstsigdisp_d, dstranks_d, 1); /* set signals to 1 */
397     PetscCallCUDA(cudaGetLastError());
398   }
399 
400   /* dst waits for signals (permissions) from src ranks to start getting data */
401   if (nsrcranks) {
402     NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, dstsig, 1, 0); /* wait the signals to be 1, then set them to 0 */
403     PetscCallCUDA(cudaGetLastError());
404   }
405 
406   /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */
407 
408   /* Count number of locally accessible src ranks, which should be a small number */
409   for (int i = 0; i < nsrcranks; i++) {
410     if (nvshmem_ptr(src, srcranks_h[i])) nLocallyAccessible++;
411   }
412 
413   /* Get data from remotely accessible PEs */
414   if (nLocallyAccessible < nsrcranks) {
415     GetDataFromRemotelyAccessible<<<nsrcranks, 1, 0, link->remoteCommStream>>>(nsrcranks, srcranks_d, src, srcdisp_d, dst, dstdisp_d, link->unitbytes);
416     PetscCallCUDA(cudaGetLastError());
417   }
418 
419   /* Get data from locally accessible PEs */
420   if (nLocallyAccessible) {
421     for (int i = 0; i < nsrcranks; i++) {
422       int pe = srcranks_h[i];
423       if (nvshmem_ptr(src, pe)) {
424         size_t nelems = (dstdisp_h[i + 1] - dstdisp_h[i]) * link->unitbytes;
425         nvshmemx_getmem_nbi_on_stream(dst + (dstdisp_h[i] - dstdisp_h[0]) * link->unitbytes, src + srcdisp_h[i] * link->unitbytes, nelems, pe, link->remoteCommStream);
426       }
427     }
428   }
429   PetscFunctionReturn(PETSC_SUCCESS);
430 }
431 
432 /* Finish the communication (can be done before Unpack)
433    Receiver tells its senders that they are allowed to reuse their send buffer (since receiver has got data from their send buffer)
434 */
435 static PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
436 {
438   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
439   uint64_t      *srcsig;
440   PetscInt       nsrcranks, *srcsigdisp;
441   PetscMPIInt   *srcranks;
442 
443   PetscFunctionBegin;
444   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
445     nsrcranks  = sf->nRemoteRootRanks;
446     srcsig     = link->rootSendSig; /* I want to set their root signal */
447     srcsigdisp = sf->rootsigdisp_d; /* offset of each root signal */
448     srcranks   = sf->ranks_d;       /* ranks of the n root ranks */
449   } else {                          /* LEAF2ROOT, root ranks are getting data */
450     nsrcranks  = bas->nRemoteLeafRanks;
451     srcsig     = link->leafSendSig;
452     srcsigdisp = bas->leafsigdisp_d;
453     srcranks   = bas->iranks_d;
454   }
455 
456   if (nsrcranks) {
457     nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
458     PetscCallCUDA(cudaGetLastError());
459     NvshmemSendSignals<<<(nsrcranks + 511) / 512, 512, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp, srcranks, 0); /* set signals to 0 */
460     PetscCallCUDA(cudaGetLastError());
461   }
462   PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
463   PetscFunctionReturn(PETSC_SUCCESS);
464 }
465 
466 /* ===========================================================================================================
467 
468    A set of routines to support sender initiated communication using the put-based method (the default)
469 
470     The putting protocol is:
471 
472     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
473     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
474     is in nvshmem space.
475 
476     Sender:                                 |  Receiver:
477                                             |
478   1.  Pack data into sbuf                   |
479   2.  Wait ssig to be 0, then set it to 1   |
480   3.  Put data to remote stand-alone rbuf   |
481   4.  Fence // make sure 5 happens after 3  |
482   5.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it to 0
483                                             |   2. Unpack data from local rbuf
484                                             |   3. Put 0 to sender's ssig
485    ===========================================================================================================*/
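/* Roughly, the steps above map onto the link callbacks installed in PetscSFLinkCreate_NVSHMEM() (put-based
   protocol, the default) as follows:
     Sender step 1       -> the ordinary SF pack on link->stream
     Sender steps 2-4    -> StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM()
     Sender step 5 and
     Receiver step 1     -> FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM() (the PutDataEnd kernel)
     Receiver step 2     -> the ordinary SF unpack on link->stream
     Receiver step 3     -> PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM()
*/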
486 
487 /* n thread blocks. Each takes in charge one remote rank */
488 __global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, char *dst, PetscInt *dstdisp, const char *src, PetscInt *srcdisp, uint64_t *srcsig, PetscInt unitbytes)
489 {
490   int         bid = blockIdx.x;
491   PetscMPIInt pe  = dstranks[bid];
492 
493   if (!nvshmem_ptr(dst, pe)) {
494     PetscInt nelems = (srcdisp[bid + 1] - srcdisp[bid]) * unitbytes;
495     nvshmem_uint64_wait_until(srcsig + bid, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
496     srcsig[bid] = 1;
497     nvshmem_putmem_nbi(dst + dstdisp[bid] * unitbytes, src + (srcdisp[bid] - srcdisp[0]) * unitbytes, nelems, pe);
498   }
499 }
500 
501 /* A one-thread kernel, which takes in charge all locally accessible PEs */
502 __global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *srcsig, const char *dst)
503 {
504   for (int i = 0; i < ndstranks; i++) {
505     int pe = dstranks[i];
506     if (nvshmem_ptr(dst, pe)) {
507       nvshmem_uint64_wait_until(srcsig + i, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
508       srcsig[i] = 1;
509     }
510   }
511 }
512 
513 /* Put data in the given direction  */
514 static PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
515 {
517   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
518   PetscInt       ndstranks, nLocallyAccessible = 0;
519   char          *src, *dst;
520   PetscInt      *srcdisp_h, *dstdisp_h;
521   PetscInt      *srcdisp_d, *dstdisp_d;
522   PetscMPIInt   *dstranks_h;
523   PetscMPIInt   *dstranks_d;
524   uint64_t      *srcsig;
525 
526   PetscFunctionBegin;
527   PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
528   if (direction == PETSCSF_ROOT2LEAF) {                              /* put data in rootbuf to leafbuf  */
529     ndstranks = bas->nRemoteLeafRanks;                               /* number of (remote) leaf ranks */
530     src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
531     dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
532 
533     srcdisp_h = bas->ioffset + bas->ndiranks; /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
534     srcdisp_d = bas->ioffset_d;
535     srcsig    = link->rootSendSig;
536 
537     dstdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
538     dstdisp_d  = bas->leafbufdisp_d;
539     dstranks_h = bas->iranks + bas->ndiranks; /* remote leaf ranks */
540     dstranks_d = bas->iranks_d;
541   } else { /* put data in leafbuf to rootbuf */
542     ndstranks = sf->nRemoteRootRanks;
543     src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
544     dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
545 
546     srcdisp_h = sf->roffset + sf->ndranks; /* offsets of leafbuf */
547     srcdisp_d = sf->roffset_d;
548     srcsig    = link->leafSendSig;
549 
550     dstdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
551     dstdisp_d  = sf->rootbufdisp_d;
552     dstranks_h = sf->ranks + sf->ndranks; /* remote root ranks */
553     dstranks_d = sf->ranks_d;
554   }
555 
556   /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */
557 
558   /* Count number of locally accessible neighbors, which should be a small number */
559   for (int i = 0; i < ndstranks; i++) {
560     if (nvshmem_ptr(dst, dstranks_h[i])) nLocallyAccessible++;
561   }
562 
563   /* For remotely accessible PEs, send data to them in one kernel call */
564   if (nLocallyAccessible < ndstranks) {
565     WaitAndPutDataToRemotelyAccessible<<<ndstranks, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, dst, dstdisp_d, src, srcdisp_d, srcsig, link->unitbytes);
566     PetscCallCUDA(cudaGetLastError());
567   }
568 
569   /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
570   if (nLocallyAccessible) {
571     WaitSignalsFromLocallyAccessible<<<1, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, srcsig, dst);
    PetscCallCUDA(cudaGetLastError());
572     for (int i = 0; i < ndstranks; i++) {
573       int pe = dstranks_h[i];
574       if (nvshmem_ptr(dst, pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
575         size_t nelems = (srcdisp_h[i + 1] - srcdisp_h[i]) * link->unitbytes;
576         /* Initiate the nonblocking communication */
577         nvshmemx_putmem_nbi_on_stream(dst + dstdisp_h[i] * link->unitbytes, src + (srcdisp_h[i] - srcdisp_h[0]) * link->unitbytes, nelems, pe, link->remoteCommStream);
578       }
579     }
580   }
581 
582   if (nLocallyAccessible) nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
583   PetscFunctionReturn(PETSC_SUCCESS);
584 }
585 
586 /* A one-thread kernel. The thread takes in charge all remote PEs */
587 __global__ static void PutDataEnd(PetscInt nsrcranks, PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *dstsig, PetscInt *dstsigdisp)
588 {
589   /* TODO: Shall we finish the non-blocking remote puts? */
590 
591   /* 1. Send a signal to each dst rank */
592 
593   /* According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
594      For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
595   */
596   for (int i = 0; i < ndstranks; i++) nvshmemx_uint64_signal(dstsig + dstsigdisp[i], 1, dstranks[i]); /* set sig to 1 */
597 
598   /* 2. Wait for signals from src ranks (if any) */
599   if (nsrcranks) {
600     nvshmem_uint64_wait_until_all(dstsig, nsrcranks, NULL /*no mask*/, NVSHMEM_CMP_EQ, 1); /* wait sigs to be 1, then set them to 0 */
601     for (int i = 0; i < nsrcranks; i++) dstsig[i] = 0;
602   }
603 }
604 
605 /* Finish the communication -- A receiver waits until it can access its receive buffer */
606 static PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
607 {
609   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
610   PetscMPIInt   *dstranks;
611   uint64_t      *dstsig;
612   PetscInt       nsrcranks, ndstranks, *dstsigdisp;
613 
614   PetscFunctionBegin;
615   if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
616     nsrcranks = sf->nRemoteRootRanks;
617 
618     ndstranks  = bas->nRemoteLeafRanks;
619     dstranks   = bas->iranks_d;      /* leaf ranks */
620     dstsig     = link->leafRecvSig;  /* I will set my leaf ranks's RecvSig */
621     dstsigdisp = bas->leafsigdisp_d; /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
622   } else {                           /* LEAF2ROOT */
623     nsrcranks = bas->nRemoteLeafRanks;
624 
625     ndstranks  = sf->nRemoteRootRanks;
626     dstranks   = sf->ranks_d;
627     dstsig     = link->rootRecvSig;
628     dstsigdisp = sf->rootsigdisp_d;
629   }
630 
631   if (nsrcranks || ndstranks) {
632     PutDataEnd<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, ndstranks, dstranks, dstsig, dstsigdisp);
633     PetscCallCUDA(cudaGetLastError());
634   }
635   PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
636   PetscFunctionReturn(PETSC_SUCCESS);
637 }
638 
639 /* PostUnpack operation -- A receiver tells its senders that they are allowed to put data here (i.e., its recv buf is free to take new data) */
640 static PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
641 {
642   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
643   uint64_t      *srcsig;
644   PetscInt       nsrcranks, *srcsigdisp_d;
645   PetscMPIInt   *srcranks_d;
646 
647   PetscFunctionBegin;
648   if (direction == PETSCSF_ROOT2LEAF) { /* I allow my root ranks to put data to me */
649     nsrcranks    = sf->nRemoteRootRanks;
650     srcsig       = link->rootSendSig; /* I want to set their send signals */
651     srcsigdisp_d = sf->rootsigdisp_d; /* offset of each root signal */
652     srcranks_d   = sf->ranks_d;       /* ranks of the n root ranks */
653   } else {                            /* LEAF2ROOT */
654     nsrcranks    = bas->nRemoteLeafRanks;
655     srcsig       = link->leafSendSig;
656     srcsigdisp_d = bas->leafsigdisp_d;
657     srcranks_d   = bas->iranks_d;
658   }
659 
660   if (nsrcranks) {
661     NvshmemSendSignals<<<(nsrcranks + 255) / 256, 256, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp_d, srcranks_d, 0); /* Set remote signals to 0 */
662     PetscCallCUDA(cudaGetLastError());
663   }
664   PetscFunctionReturn(PETSC_SUCCESS);
665 }
666 
667 /* Destructor when the link uses nvshmem for communication */
668 static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf, PetscSFLink link)
669 {
672   PetscFunctionBegin;
673   PetscCallCUDA(cudaEventDestroy(link->dataReady));
674   PetscCallCUDA(cudaEventDestroy(link->endRemoteComm));
675   PetscCallCUDA(cudaStreamDestroy(link->remoteCommStream));
676 
677   /* nvshmem does not need buffers on host, which should be NULL */
678   PetscCall(PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
679   PetscCall(PetscNvshmemFree(link->leafSendSig));
680   PetscCall(PetscNvshmemFree(link->leafRecvSig));
681   PetscCall(PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
682   PetscCall(PetscNvshmemFree(link->rootSendSig));
683   PetscCall(PetscNvshmemFree(link->rootRecvSig));
684   PetscFunctionReturn(PETSC_SUCCESS);
685 }
686 
687 PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf, MPI_Datatype unit, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, MPI_Op op, PetscSFOperation sfop, PetscSFLink *mylink)
688 {
690   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
691   PetscSFLink   *p, link;
692   PetscBool      match, rootdirect[2], leafdirect[2];
693   int            greatestPriority;
694 
695   PetscFunctionBegin;
696   /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
697      We only care root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers in local communication with NVSHMEM.
698      We only care about root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers in local communication with NVSHMEM.
699   if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
700     if (sf->use_nvshmem_get) {
701       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
702       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
703     } else {
704       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
705       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
706     }
707   } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
708     if (sf->use_nvshmem_get) {
709       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
710       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
711     } else {
712       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
713       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
714     }
715   } else {                                    /* PETSCSF_FETCH */
716     rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always needs a separate rootbuf */
717     leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share MPI requests */
718   }
719 
720   /* Look for free nvshmem links in cache */
721   for (p = &bas->avail; (link = *p); p = &link->next) {
722     if (link->use_nvshmem) {
723       PetscCall(MPIPetsc_Type_compare(unit, link->unit, &match));
724       if (match) {
725         *p = link->next; /* Remove from available list */
726         goto found;
727       }
728     }
729   }
730   PetscCall(PetscNew(&link));
731   PetscCall(PetscSFLinkSetUp_Host(sf, link, unit));                                          /* Compute link->unitbytes, dup link->unit etc. */
732   if (sf->backend == PETSCSF_BACKEND_CUDA) PetscCall(PetscSFLinkSetUp_CUDA(sf, link, unit)); /* Setup pack routines, streams etc */
733 #if defined(PETSC_HAVE_KOKKOS)
734   else if (sf->backend == PETSCSF_BACKEND_KOKKOS) PetscCall(PetscSFLinkSetUp_Kokkos(sf, link, unit));
735 #endif
736 
737   link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE; /* For the local part we directly use root/leafdata */
738   link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;
739 
740   /* Init signals to zero */
741   if (!link->rootSendSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootSendSig));
742   if (!link->rootRecvSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootRecvSig));
743   if (!link->leafSendSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafSendSig));
744   if (!link->leafRecvSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafRecvSig));
745 
746   link->use_nvshmem = PETSC_TRUE;
747   link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
748   link->leafmtype   = PETSC_MEMTYPE_DEVICE;
749   /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
750   link->Destroy = PetscSFLinkDestroy_NVSHMEM;
751   if (sf->use_nvshmem_get) { /* get-based protocol */
752     link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
753     link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
754     link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
755   } else { /* put-based protocol */
756     link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
757     link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
758     link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
759   }
760 
761   PetscCallCUDA(cudaDeviceGetStreamPriorityRange(NULL, &greatestPriority));
762   PetscCallCUDA(cudaStreamCreateWithPriority(&link->remoteCommStream, cudaStreamNonBlocking, greatestPriority));
763 
764   PetscCallCUDA(cudaEventCreateWithFlags(&link->dataReady, cudaEventDisableTiming));
765   PetscCallCUDA(cudaEventCreateWithFlags(&link->endRemoteComm, cudaEventDisableTiming));
766 
767 found:
768   if (rootdirect[PETSCSF_REMOTE]) {
769     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)rootdata + bas->rootstart[PETSCSF_REMOTE] * link->unitbytes;
770   } else {
771     if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(bas->rootbuflen_rmax * link->unitbytes, (void **)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
772     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
773   }
774 
775   if (leafdirect[PETSCSF_REMOTE]) {
776     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)leafdata + sf->leafstart[PETSCSF_REMOTE] * link->unitbytes;
777   } else {
778     if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(sf->leafbuflen_rmax * link->unitbytes, (void **)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
779     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
780   }
781 
782   link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
783   link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
784   link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
785   link->leafdata                   = leafdata;
786   link->next                       = bas->inuse;
787   bas->inuse                       = link;
788   *mylink                          = link;
789   PetscFunctionReturn(PETSC_SUCCESS);
790 }
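/* End-to-end sketch (illustrative, put-based protocol) of how an SF broadcast uses the link created above,
   assuming PetscSFLinkNvshmemCheck() returned use_nvshmem = PETSC_TRUE:

     PetscSFBcastBegin: PetscSFLinkCreate_NVSHMEM()        // pick/allocate symmetric buffers, streams, events
                        pack rootdata into link->rootbuf   // on link->stream
                        link->StartCommunication()         // PetscSFLinkPutDataBegin_NVSHMEM
     PetscSFBcastEnd:   link->FinishCommunication()        // PetscSFLinkPutDataEnd_NVSHMEM
                        unpack link->leafbuf into leafdata // on link->stream
                        link->PostUnpack()                 // PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM

   The reduce path is analogous with the roles of rootbuf and leafbuf swapped. */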
791 
792 #if defined(PETSC_USE_REAL_SINGLE)
793 PetscErrorCode PetscNvshmemSum(PetscInt count, float *dst, const float *src)
794 {
795   PetscMPIInt num; /* Assume nvshmem's int is MPI's int */
796 
797   PetscFunctionBegin;
798   PetscCall(PetscMPIIntCast(count, &num));
799   nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
800   PetscFunctionReturn(PETSC_SUCCESS);
801 }
802 
803 PetscErrorCode PetscNvshmemMax(PetscInt count, float *dst, const float *src)
804 {
805   PetscMPIInt num;
806 
807   PetscFunctionBegin;
808   PetscCall(PetscMPIIntCast(count, &num));
809   nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
810   PetscFunctionReturn(PETSC_SUCCESS);
811 }
812 #elif defined(PETSC_USE_REAL_DOUBLE)
813 PetscErrorCode PetscNvshmemSum(PetscInt count, double *dst, const double *src)
814 {
815   PetscMPIInt num;
816 
817   PetscFunctionBegin;
818   PetscCall(PetscMPIIntCast(count, &num));
819   nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
820   PetscFunctionReturn(PETSC_SUCCESS);
821 }
822 
823 PetscErrorCode PetscNvshmemMax(PetscInt count, double *dst, const double *src)
824 {
825   PetscMPIInt num;
826 
827   PetscFunctionBegin;
828   PetscCall(PetscMPIIntCast(count, &num));
829   nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
830   PetscFunctionReturn(PETSC_SUCCESS);
831 }
832 #endif
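/* Hedged usage sketch for the stream-ordered reductions above: they reduce <count> entries element-wise
   across all PEs (NVSHMEM_TEAM_WORLD) on PetscDefaultCudaStream, and both dst and src are expected to live
   on the symmetric heap (names below illustrative):

     PetscReal *dst, *src;
     PetscCall(PetscNvshmemMalloc(n * sizeof(PetscReal), (void **)&dst));
     PetscCall(PetscNvshmemMalloc(n * sizeof(PetscReal), (void **)&src));
     // ... fill src on the device ...
     PetscCall(PetscNvshmemSum(n, dst, src)); // dst[i] = sum over all PEs of src[i]
*/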
833