xref: /petsc/src/vec/is/sf/impls/basic/nvshmem/sfnvshmem.cu (revision 58d68138c660dfb4e9f5b03334792cd4f2ffd7cc)
1 #include <petsc/private/cudavecimpl.h>
2 #include <../src/vec/is/sf/impls/basic/sfpack.h>
3 #include <mpi.h>
4 #include <nvshmem.h>
5 #include <nvshmemx.h>
6 
7 PetscErrorCode PetscNvshmemInitializeCheck(void) {
8   PetscFunctionBegin;
9   if (!PetscNvshmemInitialized) { /* Note NVSHMEM does not provide a routine to check whether it is initialized */
10     nvshmemx_init_attr_t attr;
11     attr.mpi_comm = &PETSC_COMM_WORLD;
12     PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
13     PetscCall(nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr));
14     PetscNvshmemInitialized = PETSC_TRUE;
15     PetscBeganNvshmem       = PETSC_TRUE;
16   }
17   PetscFunctionReturn(0);
18 }
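/* For reference, a minimal standalone sketch (not PETSc code) of the same MPI-bootstrapped initialization,
   assuming MPI is already initialized and one GPU has been selected per rank:

     MPI_Comm comm = MPI_COMM_WORLD;
     nvshmemx_init_attr_t attr;
     attr.mpi_comm = &comm;                                   // NVSHMEM keeps a pointer to the communicator
     nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr);  // collective over comm
     ...                                                      // use nvshmem_malloc()/put/get/collectives
     nvshmem_finalize();                                      // must be called before MPI_Finalize()
*/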
19 
20 PetscErrorCode PetscNvshmemMalloc(size_t size, void **ptr) {
21   PetscFunctionBegin;
22   PetscCall(PetscNvshmemInitializeCheck());
23   *ptr = nvshmem_malloc(size);
24   PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_MEM, "nvshmem_malloc() failed to allocate %zu bytes", size);
25   PetscFunctionReturn(0);
26 }
27 
28 PetscErrorCode PetscNvshmemCalloc(size_t size, void **ptr) {
29   PetscFunctionBegin;
30   PetscCall(PetscNvshmemInitializeCheck());
31   *ptr = nvshmem_calloc(size, 1);
32   PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_MEM, "nvshmem_calloc() failed to allocate %zu bytes", size);
33   PetscFunctionReturn(0);
34 }
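/* Usage sketch (hypothetical caller, not from this file): nvshmem_malloc()/nvshmem_calloc() are collective
   over all PEs and return memory on the symmetric heap, so every rank must call these wrappers with the same
   size at the same point in the program, e.g.

     PetscScalar *buf;
     PetscCall(PetscNvshmemMalloc(n * sizeof(PetscScalar), (void **)&buf)); // collective; n assumed equal on all ranks
     ...
     PetscCall(PetscNvshmemFree(buf));                                      // also collective
*/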
35 
36 PetscErrorCode PetscNvshmemFree_Private(void *ptr) {
37   PetscFunctionBegin;
38   nvshmem_free(ptr);
39   PetscFunctionReturn(0);
40 }
41 
42 PetscErrorCode PetscNvshmemFinalize(void) {
43   PetscFunctionBegin;
44   nvshmem_finalize();
45   PetscFunctionReturn(0);
46 }
47 
48 /* Free the NVSHMEM-related fields in the SF */
49 PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf) {
50   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
51 
52   PetscFunctionBegin;
53   PetscCall(PetscFree2(bas->leafsigdisp, bas->leafbufdisp));
54   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafbufdisp_d));
55   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafsigdisp_d));
56   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->iranks_d));
57   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->ioffset_d));
58 
59   PetscCall(PetscFree2(sf->rootsigdisp, sf->rootbufdisp));
60   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootbufdisp_d));
61   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootsigdisp_d));
62   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->ranks_d));
63   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->roffset_d));
64   PetscFunctionReturn(0);
65 }
66 
67 /* Set up NVSHMEM-related fields for an SF of type SFBASIC (called only after PetscSFSetUp_Basic() has already set up the fields they depend on) */
68 static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf) {
70   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
71   PetscInt       i, nRemoteRootRanks, nRemoteLeafRanks;
72   PetscMPIInt    tag;
73   MPI_Comm       comm;
74   MPI_Request   *rootreqs, *leafreqs;
75   PetscInt       tmp, stmp[4], rtmp[4]; /* tmps for send/recv buffers */
76 
77   PetscFunctionBegin;
78   PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
79   PetscCall(PetscObjectGetNewTag((PetscObject)sf, &tag));
80 
81   nRemoteRootRanks      = sf->nranks - sf->ndranks;
82   nRemoteLeafRanks      = bas->niranks - bas->ndiranks;
83   sf->nRemoteRootRanks  = nRemoteRootRanks;
84   bas->nRemoteLeafRanks = nRemoteLeafRanks;
85 
86   PetscCall(PetscMalloc2(nRemoteLeafRanks, &rootreqs, nRemoteRootRanks, &leafreqs));
87 
88   stmp[0] = nRemoteRootRanks;
89   stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
90   stmp[2] = nRemoteLeafRanks;
91   stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];
92 
93   PetscCall(MPIU_Allreduce(stmp, rtmp, 4, MPIU_INT, MPI_MAX, comm));
94 
95   sf->nRemoteRootRanksMax  = rtmp[0];
96   sf->leafbuflen_rmax      = rtmp[1];
97   bas->nRemoteLeafRanksMax = rtmp[2];
98   bas->rootbuflen_rmax     = rtmp[3];
99 
100   /* Four rounds of MPI communication in total to set up the NVSHMEM fields */
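  /* Specifically (summarizing the code below): rounds 1-2 go root->leaf, so each leaf rank learns, for its i-th
     remote root rank, where its signal (rootsigdisp[i]) and its data chunk (rootbufdisp[i]) live on that root;
     rounds 3-4 are the mirror-image leaf->root exchanges that fill leafsigdisp[] and leafbufdisp[]. */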
101 
102   /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
103   PetscCall(PetscMalloc2(nRemoteRootRanks, &sf->rootsigdisp, nRemoteRootRanks, &sf->rootbufdisp));
104   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPI_Irecv(&sf->rootsigdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
105   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm));                             /* Roots send. Note i changes, so we use MPI_Send. */
106   PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));
107 
108   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPI_Irecv(&sf->rootbufdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
109   for (i = 0; i < nRemoteLeafRanks; i++) {
110     tmp = bas->ioffset[i + bas->ndiranks] - bas->ioffset[bas->ndiranks];
111     PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm)); /* Roots send. Note tmp changes, so we use MPI_Send. */
112   }
113   PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));
114 
115   PetscCallCUDA(cudaMalloc((void **)&sf->rootbufdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
116   PetscCallCUDA(cudaMalloc((void **)&sf->rootsigdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
117   PetscCallCUDA(cudaMalloc((void **)&sf->ranks_d, nRemoteRootRanks * sizeof(PetscMPIInt)));
118   PetscCallCUDA(cudaMalloc((void **)&sf->roffset_d, (nRemoteRootRanks + 1) * sizeof(PetscInt)));
119 
120   PetscCallCUDA(cudaMemcpyAsync(sf->rootbufdisp_d, sf->rootbufdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
121   PetscCallCUDA(cudaMemcpyAsync(sf->rootsigdisp_d, sf->rootsigdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
122   PetscCallCUDA(cudaMemcpyAsync(sf->ranks_d, sf->ranks + sf->ndranks, nRemoteRootRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
123   PetscCallCUDA(cudaMemcpyAsync(sf->roffset_d, sf->roffset + sf->ndranks, (nRemoteRootRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
124 
125   /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
126   PetscCall(PetscMalloc2(nRemoteLeafRanks, &bas->leafsigdisp, nRemoteLeafRanks, &bas->leafbufdisp));
127   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPI_Irecv(&bas->leafsigdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
128   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
129   PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));
130 
131   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPI_Irecv(&bas->leafbufdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
132   for (i = 0; i < nRemoteRootRanks; i++) {
133     tmp = sf->roffset[i + sf->ndranks] - sf->roffset[sf->ndranks];
134     PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
135   }
136   PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));
137 
138   PetscCallCUDA(cudaMalloc((void **)&bas->leafbufdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
139   PetscCallCUDA(cudaMalloc((void **)&bas->leafsigdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
140   PetscCallCUDA(cudaMalloc((void **)&bas->iranks_d, nRemoteLeafRanks * sizeof(PetscMPIInt)));
141   PetscCallCUDA(cudaMalloc((void **)&bas->ioffset_d, (nRemoteLeafRanks + 1) * sizeof(PetscInt)));
142 
143   PetscCallCUDA(cudaMemcpyAsync(bas->leafbufdisp_d, bas->leafbufdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
144   PetscCallCUDA(cudaMemcpyAsync(bas->leafsigdisp_d, bas->leafsigdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
145   PetscCallCUDA(cudaMemcpyAsync(bas->iranks_d, bas->iranks + bas->ndiranks, nRemoteLeafRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
146   PetscCallCUDA(cudaMemcpyAsync(bas->ioffset_d, bas->ioffset + bas->ndiranks, (nRemoteLeafRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
147 
148   PetscCall(PetscFree2(rootreqs, leafreqs));
149   PetscFunctionReturn(0);
150 }
151 
152 PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, PetscBool *use_nvshmem) {
153   MPI_Comm    comm;
154   PetscBool   isBasic;
155   PetscMPIInt result = MPI_UNEQUAL;
156 
157   PetscFunctionBegin;
158   PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
159   /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
160      Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
161   */
163   if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
164     /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
165     PetscCall(PetscObjectTypeCompare((PetscObject)sf, PETSCSFBASIC, &isBasic));
166     if (isBasic) PetscCallMPI(MPI_Comm_compare(PETSC_COMM_WORLD, comm, &result));
167     if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */
168 
169     /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
170        and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
171        inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
172     */
173     if (sf->use_nvshmem) {
174       PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
175       PetscCallMPI(MPI_Allreduce(MPI_IN_PLACE, &hasNullRank, 1, MPIU_INT, MPI_LOR, comm));
176       if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
177     }
178     sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
179   }
180 
181   /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
182   if (sf->use_nvshmem) {
183     PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
184     PetscInt allCuda = oneCuda;                                                                                          /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
185 #if defined(PETSC_USE_DEBUG)                                                                                             /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
186     PetscCallMPI(MPI_Allreduce(&oneCuda, &allCuda, 1, MPIU_INT, MPI_LAND, comm));
187     PetscCheck(allCuda == oneCuda, comm, PETSC_ERR_SUP, "root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
188 #endif
189     if (allCuda) {
190       PetscCall(PetscNvshmemInitializeCheck());
191       if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
192         PetscCall(PetscSFSetUp_Basic_NVSHMEM(sf));
193         sf->setup_nvshmem = PETSC_TRUE;
194       }
195       *use_nvshmem = PETSC_TRUE;
196     } else {
197       *use_nvshmem = PETSC_FALSE;
198     }
199   } else {
200     *use_nvshmem = PETSC_FALSE;
201   }
202   PetscFunctionReturn(0);
203 }
204 
205 /* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication */
206 static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf, PetscSFLink link, PetscSFDirection direction) {
208   PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
209   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE];
210 
211   PetscFunctionBegin;
212   if (buflen) {
213     PetscCallCUDA(cudaEventRecord(link->dataReady, link->stream));
214     PetscCallCUDA(cudaStreamWaitEvent(link->remoteCommStream, link->dataReady, 0));
215   }
216   PetscFunctionReturn(0);
217 }
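/* This is the standard CUDA two-stream ordering idiom: record an event on the stream that produced the data
   (link->stream) and make the communication stream wait on that event, so work later enqueued on
   link->remoteCommStream is guaranteed to see the packed buffers. A generic sketch of the pattern:

     cudaEventRecord(event, producerStream);          // mark the point where the data is ready
     cudaStreamWaitEvent(consumerStream, event, 0);   // consumer work enqueued afterwards waits for it
*/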
218 
219 /* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication */
220 static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf, PetscSFLink link, PetscSFDirection direction) {
222   PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
223   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE];
224 
225   PetscFunctionBegin;
226   /* If unpacking to a non-null device buffer, build the endRemoteComm dependence */
227   if (buflen) {
228     PetscCallCUDA(cudaEventRecord(link->endRemoteComm, link->remoteCommStream));
229     PetscCallCUDA(cudaStreamWaitEvent(link->stream, link->endRemoteComm, 0));
230   }
231   PetscFunctionReturn(0);
232 }
233 
234 /* Send/Put signals to remote ranks
235 
236  Input parameters:
237   + n        - Number of remote ranks
238   . sig      - Signal address in symmetric heap
239   . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
240   . ranks    - remote ranks
241   - newval   - Set signals to this value
242 */
243 __global__ static void NvshmemSendSignals(PetscInt n, uint64_t *sig, PetscInt *sigdisp, PetscMPIInt *ranks, uint64_t newval) {
244   int i = blockIdx.x * blockDim.x + threadIdx.x;
245 
246   /* Each thread puts one remote signal */
247   if (i < n) nvshmemx_uint64_signal(sig + sigdisp[i], newval, ranks[i]);
248 }
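/* The call sites below launch this kernel with one thread per destination rank, e.g. (sketch, names as in
   those call sites):

     NvshmemSendSignals<<<(n + 255) / 256, 256, 0, link->remoteCommStream>>>(n, sig, sigdisp, ranks, 1);

   n is the number of remote neighbors and is typically tiny, so a single block normally suffices. */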
249 
250 /* Wait until local signals equal the expected value, then set them to a new value
251 
252  Input parameters:
253   + n        - Number of signals
254   . sig      - Local signal address
255   . expval   - expected value
256   - newval   - Set signals to this new value
257 */
258 __global__ static void NvshmemWaitSignals(PetscInt n, uint64_t *sig, uint64_t expval, uint64_t newval) {
259 #if 0
260   /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
261   int i = blockIdx.x*blockDim.x + threadIdx.x;
262   if (i < n) {
263     nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
264     sig[i] = newval;
265   }
266 #else
267   nvshmem_uint64_wait_until_all(sig, n, NULL /*no mask*/, NVSHMEM_CMP_EQ, expval);
268   for (int i = 0; i < n; i++) sig[i] = newval;
269 #endif
270 }
271 
272 /* ===========================================================================================================
273 
274    A set of routines to support receiver-initiated communication using the get method
275 
276     The getting protocol is:
277 
278     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
279     All signal variables have an initial value 0.
280 
281     Sender:                                 |  Receiver:
282   1.  Wait for ssig to be 0, set it to 1    |
283   2.  Pack data into stand-alone sbuf       |
284   3.  Put 1 to receiver's rsig              |   1. Wait for rsig to be 1, then set it to 0
285                                             |   2. Get data from remote sbuf to local rbuf
286                                             |   3. Put 0 to sender's ssig
287                                             |   4. Unpack data from local rbuf
288    ===========================================================================================================*/
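/* In terms of the PetscSFLink callbacks set in PetscSFLinkCreate_NVSHMEM() below, the get protocol maps to
   (a rough summary of the code in this file, not an independent specification):

     PrePack             -> PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM   sender: wait ssig == 0, then set it to 1
     StartCommunication  -> PetscSFLinkGetDataBegin_NVSHMEM                           sender puts 1 to rsig; receiver waits rsig == 1, sets it to 0, then nvshmem_getmem_nbi()
     FinishCommunication -> PetscSFLinkGetDataEnd_NVSHMEM                             receiver: quiet, then put 0 to sender's ssig
*/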
289 /* PrePack operation -- since the sender will overwrite the send buffer, which receivers might still be getting data from,
290    the sender waits for signals (from receivers) indicating that they have finished getting data
291 */
292 PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction) {
293   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
294   uint64_t      *sig;
295   PetscInt       n;
296 
297   PetscFunctionBegin;
298   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
299     sig = link->rootSendSig;            /* leaf ranks set my rootSendSig */
300     n   = bas->nRemoteLeafRanks;
301   } else { /* LEAF2ROOT */
302     sig = link->leafSendSig;
303     n   = sf->nRemoteRootRanks;
304   }
305 
306   if (n) {
307     NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(n, sig, 0, 1); /* wait for the signals to be 0, then set them to 1 */
308     PetscCallCUDA(cudaGetLastError());
309   }
310   PetscFunctionReturn(0);
311 }
312 
313 /* n thread blocks; each one is in charge of one remote rank */
314 __global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks, PetscMPIInt *srcranks, const char *src, PetscInt *srcdisp, char *dst, PetscInt *dstdisp, PetscInt unitbytes) {
315   int         bid = blockIdx.x;
316   PetscMPIInt pe  = srcranks[bid];
317 
318   if (!nvshmem_ptr(src, pe)) {
319     PetscInt nelems = (dstdisp[bid + 1] - dstdisp[bid]) * unitbytes;
320     nvshmem_getmem_nbi(dst + (dstdisp[bid] - dstdisp[0]) * unitbytes, src + srcdisp[bid] * unitbytes, nelems, pe);
321   }
322 }
323 
324 /* Start communication -- Get data in the given direction */
325 PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction) {
327   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
328 
329   PetscInt nsrcranks, ndstranks, nLocallyAccessible = 0;
330 
331   char        *src, *dst;
332   PetscInt    *srcdisp_h, *dstdisp_h;
333   PetscInt    *srcdisp_d, *dstdisp_d;
334   PetscMPIInt *srcranks_h;
335   PetscMPIInt *srcranks_d, *dstranks_d;
336   uint64_t    *dstsig;
337   PetscInt    *dstsigdisp_d;
338 
339   PetscFunctionBegin;
340   PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
341   if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
342     nsrcranks = sf->nRemoteRootRanks;
343     src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */
344 
345     srcdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
346     srcdisp_d  = sf->rootbufdisp_d;
347     srcranks_h = sf->ranks + sf->ndranks; /* my (remote) root ranks */
348     srcranks_d = sf->ranks_d;
349 
350     ndstranks = bas->nRemoteLeafRanks;
351     dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */
352 
353     dstdisp_h  = sf->roffset + sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
354     dstdisp_d  = sf->roffset_d;
355     dstranks_d = bas->iranks_d; /* my (remote) leaf ranks */
356 
357     dstsig       = link->leafRecvSig;
358     dstsigdisp_d = bas->leafsigdisp_d;
359   } else { /* src is leaf, dst is root; we will move data from src to dst */
360     nsrcranks = bas->nRemoteLeafRanks;
361     src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */
362 
363     srcdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
364     srcdisp_d  = bas->leafbufdisp_d;
365     srcranks_h = bas->iranks + bas->ndiranks; /* my (remote) leaf ranks */
366     srcranks_d = bas->iranks_d;
367 
368     ndstranks = sf->nRemoteRootRanks;
369     dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */
370 
371     dstdisp_h  = bas->ioffset + bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
372     dstdisp_d  = bas->ioffset_d;
373     dstranks_d = sf->ranks_d; /* my (remote) root ranks */
374 
375     dstsig       = link->rootRecvSig;
376     dstsigdisp_d = sf->rootsigdisp_d;
377   }
378 
379   /* After Pack operation -- src tells dst ranks that they are allowed to get data */
380   if (ndstranks) {
381     NvshmemSendSignals<<<(ndstranks + 255) / 256, 256, 0, link->remoteCommStream>>>(ndstranks, dstsig, dstsigdisp_d, dstranks_d, 1); /* set signals to 1 */
382     PetscCallCUDA(cudaGetLastError());
383   }
384 
385   /* dst waits for signals (permissions) from src ranks to start getting data */
386   if (nsrcranks) {
387     NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, dstsig, 1, 0); /* wait for the signals to be 1, then set them to 0 */
388     PetscCallCUDA(cudaGetLastError());
389   }
390 
391   /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */
392 
393   /* Count number of locally accessible src ranks, which should be a small number */
394   for (int i = 0; i < nsrcranks; i++) {
395     if (nvshmem_ptr(src, srcranks_h[i])) nLocallyAccessible++;
396   }
397 
398   /* Get data from remotely accessible PEs */
399   if (nLocallyAccessible < nsrcranks) {
400     GetDataFromRemotelyAccessible<<<nsrcranks, 1, 0, link->remoteCommStream>>>(nsrcranks, srcranks_d, src, srcdisp_d, dst, dstdisp_d, link->unitbytes);
401     PetscCallCUDA(cudaGetLastError());
402   }
403 
404   /* Get data from locally accessible PEs */
405   if (nLocallyAccessible) {
406     for (int i = 0; i < nsrcranks; i++) {
407       int pe = srcranks_h[i];
408       if (nvshmem_ptr(src, pe)) {
409         size_t nelems = (dstdisp_h[i + 1] - dstdisp_h[i]) * link->unitbytes;
410         nvshmemx_getmem_nbi_on_stream(dst + (dstdisp_h[i] - dstdisp_h[0]) * link->unitbytes, src + srcdisp_h[i] * link->unitbytes, nelems, pe, link->remoteCommStream);
411       }
412     }
413   }
414   PetscFunctionReturn(0);
415 }
416 
417 /* Finish the communication (can be done before Unpack)
418    Receiver tells its senders that they are allowed to reuse their send buffers (since the receiver has already gotten the data from them)
419 */
420 PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction) {
422   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
423   uint64_t      *srcsig;
424   PetscInt       nsrcranks, *srcsigdisp;
425   PetscMPIInt   *srcranks;
426 
427   PetscFunctionBegin;
428   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
429     nsrcranks  = sf->nRemoteRootRanks;
430     srcsig     = link->rootSendSig; /* I want to set their root signal */
431     srcsigdisp = sf->rootsigdisp_d; /* offset of each root signal */
432     srcranks   = sf->ranks_d;       /* ranks of the n root ranks */
433   } else {                          /* LEAF2ROOT, root ranks are getting data */
434     nsrcranks  = bas->nRemoteLeafRanks;
435     srcsig     = link->leafSendSig;
436     srcsigdisp = bas->leafsigdisp_d;
437     srcranks   = bas->iranks_d;
438   }
439 
440   if (nsrcranks) {
441     nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
442     PetscCallCUDA(cudaGetLastError());
443     NvshmemSendSignals<<<(nsrcranks + 511) / 512, 512, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp, srcranks, 0); /* set signals to 0 */
444     PetscCallCUDA(cudaGetLastError());
445   }
446   PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
447   PetscFunctionReturn(0);
448 }
449 
450 /* ===========================================================================================================
451 
452    A set of routines to support sender-initiated communication using the put-based method (the default)
453 
454     The putting protocol is:
455 
456     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
457     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
458     is in nvshmem space.
459 
460     Sender:                                 |  Receiver:
461                                             |
462   1.  Pack data into sbuf                   |
463   2.  Wait for ssig to be 0, set it to 1    |
464   3.  Put data to remote stand-alone rbuf   |
465   4.  Fence // make sure 5 happens after 3  |
466   5.  Put 1 to receiver's rsig              |   1. Wait for rsig to be 1, then set it to 0
467                                             |   2. Unpack data from local rbuf
468                                             |   3. Put 0 to sender's ssig
469    ===========================================================================================================*/
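/* In terms of the PetscSFLink callbacks set in PetscSFLinkCreate_NVSHMEM() below, the put protocol maps to
   (again just a summary of the code in this file):

     StartCommunication  -> PetscSFLinkPutDataBegin_NVSHMEM                      sender: wait ssig == 0, set it to 1, then nvshmem_putmem_nbi() into the remote rbuf
     FinishCommunication -> PetscSFLinkPutDataEnd_NVSHMEM                        sender puts 1 to rsig; receiver waits rsig == 1, then sets it to 0
     PostUnpack          -> PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM     receiver: put 0 to sender's ssig
*/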
470 
471 /* n thread blocks; each one is in charge of one remote rank */
472 __global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, char *dst, PetscInt *dstdisp, const char *src, PetscInt *srcdisp, uint64_t *srcsig, PetscInt unitbytes) {
473   int         bid = blockIdx.x;
474   PetscMPIInt pe  = dstranks[bid];
475 
476   if (!nvshmem_ptr(dst, pe)) {
477     PetscInt nelems = (srcdisp[bid + 1] - srcdisp[bid]) * unitbytes;
478     nvshmem_uint64_wait_until(srcsig + bid, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
479     srcsig[bid] = 1;
480     nvshmem_putmem_nbi(dst + dstdisp[bid] * unitbytes, src + (srcdisp[bid] - srcdisp[0]) * unitbytes, nelems, pe);
481   }
482 }
483 
484 /* A one-thread kernel that handles all locally accessible PEs */
485 __global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *srcsig, const char *dst) {
486   for (int i = 0; i < ndstranks; i++) {
487     int pe = dstranks[i];
488     if (nvshmem_ptr(dst, pe)) {
489       nvshmem_uint64_wait_until(srcsig + i, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
490       srcsig[i] = 1;
491     }
492   }
493 }
494 
495 /* Put data in the given direction  */
496 PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction) {
498   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
499   PetscInt       ndstranks, nLocallyAccessible = 0;
500   char          *src, *dst;
501   PetscInt      *srcdisp_h, *dstdisp_h;
502   PetscInt      *srcdisp_d, *dstdisp_d;
503   PetscMPIInt   *dstranks_h;
504   PetscMPIInt   *dstranks_d;
505   uint64_t      *srcsig;
506 
507   PetscFunctionBegin;
508   PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
509   if (direction == PETSCSF_ROOT2LEAF) {                              /* put data in rootbuf to leafbuf  */
510     ndstranks = bas->nRemoteLeafRanks;                               /* number of (remote) leaf ranks */
511     src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
512     dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
513 
514     srcdisp_h = bas->ioffset + bas->ndiranks; /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
515     srcdisp_d = bas->ioffset_d;
516     srcsig    = link->rootSendSig;
517 
518     dstdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
519     dstdisp_d  = bas->leafbufdisp_d;
520     dstranks_h = bas->iranks + bas->ndiranks; /* remote leaf ranks */
521     dstranks_d = bas->iranks_d;
522   } else { /* put data in leafbuf to rootbuf */
523     ndstranks = sf->nRemoteRootRanks;
524     src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
525     dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
526 
527     srcdisp_h = sf->roffset + sf->ndranks; /* offsets of leafbuf */
528     srcdisp_d = sf->roffset_d;
529     srcsig    = link->leafSendSig;
530 
531     dstdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
532     dstdisp_d  = sf->rootbufdisp_d;
533     dstranks_h = sf->ranks + sf->ndranks; /* remote root ranks */
534     dstranks_d = sf->ranks_d;
535   }
536 
537   /* Wait for signals and then put data to dst ranks using non-blocking nvshmem puts, which are completed in PetscSFLinkPutDataEnd_NVSHMEM() */
538 
539   /* Count number of locally accessible neighbors, which should be a small number */
540   for (int i = 0; i < ndstranks; i++) {
541     if (nvshmem_ptr(dst, dstranks_h[i])) nLocallyAccessible++;
542   }
543 
544   /* For remotely accessible PEs, send data to them in one kernel call */
545   if (nLocallyAccessible < ndstranks) {
546     WaitAndPutDataToRemotelyAccessible<<<ndstranks, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, dst, dstdisp_d, src, srcdisp_d, srcsig, link->unitbytes);
547     PetscCallCUDA(cudaGetLastError());
548   }
549 
550   /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
551   if (nLocallyAccessible) {
552     WaitSignalsFromLocallyAccessible<<<1, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, srcsig, dst);
553     for (int i = 0; i < ndstranks; i++) {
554       int pe = dstranks_h[i];
555       if (nvshmem_ptr(dst, pe)) { /* If nvshmem_ptr() returns a non-NULL pointer, then <pe> is locally accessible */
556         size_t nelems = (srcdisp_h[i + 1] - srcdisp_h[i]) * link->unitbytes;
557         /* Initiate the nonblocking communication */
558         nvshmemx_putmem_nbi_on_stream(dst + dstdisp_h[i] * link->unitbytes, src + (srcdisp_h[i] - srcdisp_h[0]) * link->unitbytes, nelems, pe, link->remoteCommStream);
559       }
560     }
561   }
562 
563   if (nLocallyAccessible) { nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */ }
564   PetscFunctionReturn(0);
565 }
566 
567 /* A one-thread kernel; the thread handles all remote PEs */
568 __global__ static void PutDataEnd(PetscInt nsrcranks, PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *dstsig, PetscInt *dstsigdisp) {
569   /* TODO: Shall we finish the non-blocking remote puts here? */
570 
571   /* 1. Send a signal to each dst rank */
572 
573   /* According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
574      For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
575   */
576   for (int i = 0; i < ndstranks; i++) { nvshmemx_uint64_signal(dstsig + dstsigdisp[i], 1, dstranks[i]); } /* set sig to 1 */
577 
578   /* 2. Wait for signals from src ranks (if any) */
579   if (nsrcranks) {
580     nvshmem_uint64_wait_until_all(dstsig, nsrcranks, NULL /*no mask*/, NVSHMEM_CMP_EQ, 1); /* wait sigs to be 1, then set them to 0 */
581     for (int i = 0; i < nsrcranks; i++) dstsig[i] = 0;
582   }
583 }
584 
585 /* Finish the communication -- A receiver waits until it can access its receive buffer */
586 PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction) {
588   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
589   PetscMPIInt   *dstranks;
590   uint64_t      *dstsig;
591   PetscInt       nsrcranks, ndstranks, *dstsigdisp;
592 
593   PetscFunctionBegin;
594   if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
595     nsrcranks = sf->nRemoteRootRanks;
596 
597     ndstranks  = bas->nRemoteLeafRanks;
598     dstranks   = bas->iranks_d;      /* leaf ranks */
599     dstsig     = link->leafRecvSig;  /* I will set my leaf ranks' RecvSig */
600     dstsigdisp = bas->leafsigdisp_d; /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
601   } else {                           /* LEAF2ROOT */
602     nsrcranks = bas->nRemoteLeafRanks;
603 
604     ndstranks  = sf->nRemoteRootRanks;
605     dstranks   = sf->ranks_d;
606     dstsig     = link->rootRecvSig;
607     dstsigdisp = sf->rootsigdisp_d;
608   }
609 
610   if (nsrcranks || ndstranks) {
611     PutDataEnd<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, ndstranks, dstranks, dstsig, dstsigdisp);
612     PetscCallCUDA(cudaGetLastError());
613   }
614   PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
615   PetscFunctionReturn(0);
616 }
617 
618 /* PostUnpack operation -- A receiver tells its senders that they are allowed to put data here (i.e., the recv buf is free to take new data) */
619 PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction) {
620   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
621   uint64_t      *srcsig;
622   PetscInt       nsrcranks, *srcsigdisp_d;
623   PetscMPIInt   *srcranks_d;
624 
625   PetscFunctionBegin;
626   if (direction == PETSCSF_ROOT2LEAF) { /* I allow my root ranks to put data to me */
627     nsrcranks    = sf->nRemoteRootRanks;
628     srcsig       = link->rootSendSig; /* I want to set their send signals */
629     srcsigdisp_d = sf->rootsigdisp_d; /* offset of each root signal */
630     srcranks_d   = sf->ranks_d;       /* ranks of the n root ranks */
631   } else {                            /* LEAF2ROOT */
632     nsrcranks    = bas->nRemoteLeafRanks;
633     srcsig       = link->leafSendSig;
634     srcsigdisp_d = bas->leafsigdisp_d;
635     srcranks_d   = bas->iranks_d;
636   }
637 
638   if (nsrcranks) {
639     NvshmemSendSignals<<<(nsrcranks + 255) / 256, 256, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp_d, srcranks_d, 0); /* Set remote signals to 0 */
640     PetscCallCUDA(cudaGetLastError());
641   }
642   PetscFunctionReturn(0);
643 }
644 
645 /* Destructor when the link uses nvshmem for communication */
646 static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf, PetscSFLink link) {
648 
649   PetscFunctionBegin;
650   PetscCallCUDA(cudaEventDestroy(link->dataReady));
651   PetscCallCUDA(cudaEventDestroy(link->endRemoteComm));
652   PetscCallCUDA(cudaStreamDestroy(link->remoteCommStream));
653 
654   /* NVSHMEM does not need host buffers, so those should be NULL; only the device-side buffers and signals are freed here */
655   PetscCall(PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
656   PetscCall(PetscNvshmemFree(link->leafSendSig));
657   PetscCall(PetscNvshmemFree(link->leafRecvSig));
658   PetscCall(PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
659   PetscCall(PetscNvshmemFree(link->rootSendSig));
660   PetscCall(PetscNvshmemFree(link->rootRecvSig));
661   PetscFunctionReturn(0);
662 }
663 
664 PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf, MPI_Datatype unit, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, MPI_Op op, PetscSFOperation sfop, PetscSFLink *mylink) {
666   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
667   PetscSFLink   *p, link;
668   PetscBool      match, rootdirect[2], leafdirect[2];
669   int            greatestPriority;
670 
671   PetscFunctionBegin;
672   /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
673      We only care about root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers for local communication with NVSHMEM.
674   */
675   if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
676     if (sf->use_nvshmem_get) {
677       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
678       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
679     } else {
680       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
681       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
682     }
683   } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
684     if (sf->use_nvshmem_get) {
685       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
686       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
687     } else {
688       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
689       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
690     }
691   } else {                                    /* PETSCSF_FETCH */
692     rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always needs a separate rootbuf */
693     leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share MPI requests */
694   }
695 
696   /* Look for free nvshmem links in cache */
697   for (p = &bas->avail; (link = *p); p = &link->next) {
698     if (link->use_nvshmem) {
699       PetscCall(MPIPetsc_Type_compare(unit, link->unit, &match));
700       if (match) {
701         *p = link->next; /* Remove from available list */
702         goto found;
703       }
704     }
705   }
706   PetscCall(PetscNew(&link));
707   PetscCall(PetscSFLinkSetUp_Host(sf, link, unit));                                          /* Compute link->unitbytes, dup link->unit etc. */
708   if (sf->backend == PETSCSF_BACKEND_CUDA) PetscCall(PetscSFLinkSetUp_CUDA(sf, link, unit)); /* Setup pack routines, streams etc */
709 #if defined(PETSC_HAVE_KOKKOS)
710   else if (sf->backend == PETSCSF_BACKEND_KOKKOS) PetscCall(PetscSFLinkSetUp_Kokkos(sf, link, unit));
711 #endif
712 
713   link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE; /* For the local part we directly use root/leafdata */
714   link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;
715 
716   /* Init signals to zero */
717   if (!link->rootSendSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootSendSig));
718   if (!link->rootRecvSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootRecvSig));
719   if (!link->leafSendSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafSendSig));
720   if (!link->leafRecvSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafRecvSig));
721 
722   link->use_nvshmem = PETSC_TRUE;
723   link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
724   link->leafmtype   = PETSC_MEMTYPE_DEVICE;
725   /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
726   link->Destroy     = PetscSFLinkDestroy_NVSHMEM;
727   if (sf->use_nvshmem_get) { /* get-based protocol */
728     link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
729     link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
730     link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
731   } else { /* put-based protocol */
732     link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
733     link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
734     link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
735   }
736 
737   PetscCallCUDA(cudaDeviceGetStreamPriorityRange(NULL, &greatestPriority));
738   PetscCallCUDA(cudaStreamCreateWithPriority(&link->remoteCommStream, cudaStreamNonBlocking, greatestPriority));
739 
740   PetscCallCUDA(cudaEventCreateWithFlags(&link->dataReady, cudaEventDisableTiming));
741   PetscCallCUDA(cudaEventCreateWithFlags(&link->endRemoteComm, cudaEventDisableTiming));
742 
743 found:
744   if (rootdirect[PETSCSF_REMOTE]) {
745     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)rootdata + bas->rootstart[PETSCSF_REMOTE] * link->unitbytes;
746   } else {
747     if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) { PetscCall(PetscNvshmemMalloc(bas->rootbuflen_rmax * link->unitbytes, (void **)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE])); }
748     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
749   }
750 
751   if (leafdirect[PETSCSF_REMOTE]) {
752     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)leafdata + sf->leafstart[PETSCSF_REMOTE] * link->unitbytes;
753   } else {
754     if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) { PetscCall(PetscNvshmemMalloc(sf->leafbuflen_rmax * link->unitbytes, (void **)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE])); }
755     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
756   }
757 
758   link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
759   link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
760   link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
761   link->leafdata                   = leafdata;
762   link->next                       = bas->inuse;
763   bas->inuse                       = link;
764   *mylink                          = link;
765   PetscFunctionReturn(0);
766 }
767 
768 #if defined(PETSC_USE_REAL_SINGLE)
769 PetscErrorCode PetscNvshmemSum(PetscInt count, float *dst, const float *src) {
770   PetscMPIInt num; /* Assume nvshmem's int is MPI's int */
771 
772   PetscFunctionBegin;
773   PetscCall(PetscMPIIntCast(count, &num));
774   nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
775   PetscFunctionReturn(0);
776 }
777 
778 PetscErrorCode PetscNvshmemMax(PetscInt count, float *dst, const float *src) {
779   PetscMPIInt num;
780 
781   PetscFunctionBegin;
782   PetscCall(PetscMPIIntCast(count, &num));
783   nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
784   PetscFunctionReturn(0);
785 }
786 #elif defined(PETSC_USE_REAL_DOUBLE)
787 PetscErrorCode PetscNvshmemSum(PetscInt count, double *dst, const double *src) {
788   PetscMPIInt num;
789 
790   PetscFunctionBegin;
791   PetscCall(PetscMPIIntCast(count, &num));
792   nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
793   PetscFunctionReturn(0);
794 }
795 
796 PetscErrorCode PetscNvshmemMax(PetscInt count, double *dst, const double *src) {
797   PetscMPIInt num;
798 
799   PetscFunctionBegin;
800   PetscCall(PetscMPIIntCast(count, &num));
801   nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
802   PetscFunctionReturn(0);
803 }
804 #endif
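/* Usage sketch (hypothetical, not from this file): dst and src must be PetscReal arrays on the NVSHMEM symmetric
   heap; the reduction is collective over NVSHMEM_TEAM_WORLD and is enqueued on PetscDefaultCudaStream, so the
   caller must synchronize that stream before reading dst:

     PetscReal *src, *dst;
     PetscCall(PetscNvshmemMalloc(n * sizeof(PetscReal), (void **)&src));
     PetscCall(PetscNvshmemMalloc(n * sizeof(PetscReal), (void **)&dst));
     // ... fill src on the device ...
     PetscCall(PetscNvshmemSum(n, dst, src));                       // elementwise sum over all PEs
     PetscCallCUDA(cudaStreamSynchronize(PetscDefaultCudaStream));
*/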
805