#include <petsc/private/cudavecimpl.h>
#include <../src/vec/is/sf/impls/basic/sfpack.h>
#include <mpi.h>
#include <nvshmem.h>
#include <nvshmemx.h>
6
/* Initialize NVSHMEM on demand, bootstrapping it with MPI over PETSC_COMM_WORLD.
   Collective: nvshmemx_init_attr() must be called by all processes of PETSC_COMM_WORLD. */
PetscErrorCode PetscNvshmemInitializeCheck(void)
{
  PetscFunctionBegin;
  if (!PetscNvshmemInitialized) { /* Note NVSHMEM does not provide a routine to check whether it is initialized */
    nvshmemx_init_attr_t attr;
    attr.mpi_comm = &PETSC_COMM_WORLD; /* NVSHMEM takes the communicator by address */
    PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* the CUDA device must be initialized before NVSHMEM */
    PetscCall(nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr));
    PetscNvshmemInitialized = PETSC_TRUE;
    PetscBeganNvshmem = PETSC_TRUE; /* record that PETSc (not the user) started NVSHMEM — presumably used at finalize time; see PetscNvshmemFinalize() */
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
20
/* Allocate <size> bytes on the NVSHMEM symmetric heap.

   Input parameter:
.  size - number of bytes to allocate

   Output parameter:
.  ptr - the allocated symmetric memory

   Note: nvshmem_malloc() is collective over all PEs, and PetscNvshmemInitializeCheck()
   may initialize NVSHMEM on first use.
*/
PetscErrorCode PetscNvshmemMalloc(size_t size, void **ptr)
{
  PetscFunctionBegin;
  PetscCall(PetscNvshmemInitializeCheck());
  *ptr = nvshmem_malloc(size);
  /* A failed allocation is an out-of-memory condition, so report PETSC_ERR_MEM (not PETSC_ERR_ARG_WRONG) */
  PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_MEM, "nvshmem_malloc() failed to allocate %zu bytes", size);
  PetscFunctionReturn(PETSC_SUCCESS);
}
29
/* Allocate <size> bytes of zero-initialized memory on the NVSHMEM symmetric heap.

   Input parameter:
.  size - number of bytes to allocate

   Output parameter:
.  ptr - the allocated, zeroed symmetric memory

   Note: nvshmem_calloc() is collective over all PEs.
*/
PetscErrorCode PetscNvshmemCalloc(size_t size, void **ptr)
{
  PetscFunctionBegin;
  PetscCall(PetscNvshmemInitializeCheck());
  *ptr = nvshmem_calloc(size, 1); /* <size> elements of 1 byte each */
  /* A failed allocation is an out-of-memory condition, so report PETSC_ERR_MEM (not PETSC_ERR_ARG_WRONG) */
  PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_MEM, "nvshmem_calloc() failed to allocate %zu bytes", size);
  PetscFunctionReturn(PETSC_SUCCESS);
}
38
/* Free memory previously obtained from PetscNvshmemMalloc()/PetscNvshmemCalloc().
   NOTE(review): nvshmem_free() is collective across PEs — all PEs should call this together; confirm against callers. */
PetscErrorCode PetscNvshmemFree_Private(void *ptr)
{
  PetscFunctionBegin;
  nvshmem_free(ptr);
  PetscFunctionReturn(PETSC_SUCCESS);
}
45
/* Shut down the NVSHMEM library.
   NOTE(review): presumably invoked during PETSc finalization when PetscBeganNvshmem was set in
   PetscNvshmemInitializeCheck() — confirm with the caller, which is outside this file. */
PetscErrorCode PetscNvshmemFinalize(void)
{
  PetscFunctionBegin;
  nvshmem_finalize();
  PetscFunctionReturn(PETSC_SUCCESS);
}
52
53 /* Free nvshmem related fields in the SF */
PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

  PetscFunctionBegin;
  /* Leaf-side displacement tables: host copies, then their device mirrors (allocated in PetscSFSetUp_Basic_NVSHMEM()) */
  PetscCall(PetscFree2(bas->leafsigdisp, bas->leafbufdisp));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafbufdisp_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafsigdisp_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->iranks_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->ioffset_d));

  /* Root-side displacement tables, likewise host arrays followed by device mirrors */
  PetscCall(PetscFree2(sf->rootsigdisp, sf->rootbufdisp));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootbufdisp_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootsigdisp_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->ranks_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->roffset_d));
  PetscFunctionReturn(PETSC_SUCCESS);
}
72
73 /* Set up NVSHMEM related fields for an SF of type SFBASIC (only after PetscSFSetup_Basic() already set up dependent fields) */
static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       i, nRemoteRootRanks, nRemoteLeafRanks;
  PetscMPIInt    tag;
  MPI_Comm       comm;
  MPI_Request   *rootreqs, *leafreqs;
  PetscInt       tmp, stmp[4], rtmp[4]; /* tmps for send/recv buffers */

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
  PetscCall(PetscObjectGetNewTag((PetscObject)sf, &tag));

  /* "Remote" ranks exclude the distinguished (on-node/self) ranks counted by ndranks/ndiranks */
  nRemoteRootRanks      = sf->nranks - sf->ndranks;
  nRemoteLeafRanks      = bas->niranks - bas->ndiranks;
  sf->nRemoteRootRanks  = nRemoteRootRanks;
  bas->nRemoteLeafRanks = nRemoteLeafRanks;

  /* rootreqs are used for recvs posted on the root side (one per remote leaf rank), leafreqs on the leaf side */
  PetscCall(PetscMalloc2(nRemoteLeafRanks, &rootreqs, nRemoteRootRanks, &leafreqs));

  /* Compute global maxima of the four local sizes (needed so buffers sized by them are symmetric across PEs) */
  stmp[0] = nRemoteRootRanks;
  stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
  stmp[2] = nRemoteLeafRanks;
  stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];

  PetscCallMPI(MPIU_Allreduce(stmp, rtmp, 4, MPIU_INT, MPI_MAX, comm));

  sf->nRemoteRootRanksMax  = rtmp[0];
  sf->leafbuflen_rmax      = rtmp[1];
  bas->nRemoteLeafRanksMax = rtmp[2];
  bas->rootbuflen_rmax     = rtmp[3];

  /* Total four rounds of MPI communications to set up the nvshmem fields */

  /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
  PetscCall(PetscMalloc2(nRemoteRootRanks, &sf->rootsigdisp, nRemoteRootRanks, &sf->rootbufdisp));
  for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPIU_Irecv(&sf->rootsigdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
  for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm));                              /* Roots send. Note i changes, so we use (blocking) MPI_Send. */
  PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));

  for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPIU_Irecv(&sf->rootbufdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
  for (i = 0; i < nRemoteLeafRanks; i++) {
    tmp = bas->ioffset[i + bas->ndiranks] - bas->ioffset[bas->ndiranks]; /* offset of this leaf rank's chunk within my remote root buffer */
    PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm)); /* Roots send. Note tmp changes, so we use MPI_Send. */
  }
  PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));

  /* Device mirrors of the root-side tables, for use inside kernels */
  PetscCallCUDA(cudaMalloc((void **)&sf->rootbufdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void **)&sf->rootsigdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void **)&sf->ranks_d, nRemoteRootRanks * sizeof(PetscMPIInt)));
  PetscCallCUDA(cudaMalloc((void **)&sf->roffset_d, (nRemoteRootRanks + 1) * sizeof(PetscInt)));

  PetscCallCUDA(cudaMemcpyAsync(sf->rootbufdisp_d, sf->rootbufdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->rootsigdisp_d, sf->rootsigdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->ranks_d, sf->ranks + sf->ndranks, nRemoteRootRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->roffset_d, sf->roffset + sf->ndranks, (nRemoteRootRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));

  /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
  PetscCall(PetscMalloc2(nRemoteLeafRanks, &bas->leafsigdisp, nRemoteLeafRanks, &bas->leafbufdisp));
  for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPIU_Irecv(&bas->leafsigdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
  for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
  PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));

  for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPIU_Irecv(&bas->leafbufdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
  for (i = 0; i < nRemoteRootRanks; i++) {
    tmp = sf->roffset[i + sf->ndranks] - sf->roffset[sf->ndranks]; /* offset of this root rank's chunk within my remote leaf buffer */
    PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
  }
  PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));

  /* Device mirrors of the leaf-side tables */
  PetscCallCUDA(cudaMalloc((void **)&bas->leafbufdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void **)&bas->leafsigdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void **)&bas->iranks_d, nRemoteLeafRanks * sizeof(PetscMPIInt)));
  PetscCallCUDA(cudaMalloc((void **)&bas->ioffset_d, (nRemoteLeafRanks + 1) * sizeof(PetscInt)));

  PetscCallCUDA(cudaMemcpyAsync(bas->leafbufdisp_d, bas->leafbufdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->leafsigdisp_d, bas->leafsigdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->iranks_d, bas->iranks + bas->ndiranks, nRemoteLeafRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->ioffset_d, bas->ioffset + bas->ndiranks, (nRemoteLeafRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));

  PetscCall(PetscFree2(rootreqs, leafreqs));
  PetscFunctionReturn(PETSC_SUCCESS);
}
158
/* Decide (collectively) whether this SF operation should use NVSHMEM for the given root/leaf data.

   Output parameter:
.  use_nvshmem - PETSC_TRUE if both root and leaf data are on CUDA and the SF is NVSHMEM-eligible

   Note: the eligibility check is performed only once per SF and cached via sf->checked_nvshmem_eligibility.
*/
PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, PetscBool *use_nvshmem)
{
  MPI_Comm    comm;
  PetscBool   isBasic;
  PetscMPIInt result = MPI_UNEQUAL;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
  /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
     Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.

     Bug fix: the flag sf->checked_nvshmem_eligibility must NOT be set before this test — doing so made the
     branch below unreachable and skipped the eligibility check entirely. It is set only at the end of the branch.
  */
  if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
    /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD */
    PetscCall(PetscObjectTypeCompare((PetscObject)sf, PETSCSFBASIC, &isBasic));
    if (isBasic) PetscCallMPI(MPI_Comm_compare(PETSC_COMM_WORLD, comm, &result));
    if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

    /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
       and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
       inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
    */
    if (sf->use_nvshmem) {
      PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
      PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &hasNullRank, 1, MPIU_INT, MPI_LOR, comm));
      if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
    }
    sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
  }

  /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
  if (sf->use_nvshmem) {
    PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
    PetscInt allCuda = oneCuda;                                    /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
#if defined(PETSC_USE_DEBUG)                                       /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
    PetscCallMPI(MPIU_Allreduce(&oneCuda, &allCuda, 1, MPIU_INT, MPI_LAND, comm));
    PetscCheck(allCuda == oneCuda, comm, PETSC_ERR_SUP, "root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
#endif
    if (allCuda) {
      PetscCall(PetscNvshmemInitializeCheck());
      if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
        PetscCall(PetscSFSetUp_Basic_NVSHMEM(sf));
        sf->setup_nvshmem = PETSC_TRUE;
      }
      *use_nvshmem = PETSC_TRUE;
    } else {
      *use_nvshmem = PETSC_FALSE;
    }
  } else {
    *use_nvshmem = PETSC_FALSE;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
212
213 /* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication */
static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
  PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE]; /* size of the send-side remote buffer */

  PetscFunctionBegin;
  /* Only needed when there is remote data to send; note the unused cudaError_t variable was removed */
  if (buflen) {
    /* Record the data-ready point on the compute stream and make the communication stream wait on it */
    PetscCallCUDA(cudaEventRecord(link->dataReady, link->stream));
    PetscCallCUDA(cudaStreamWaitEvent(link->remoteCommStream, link->dataReady, 0));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
227
228 /* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication */
static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
  PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE]; /* size of the recv-side remote buffer */

  PetscFunctionBegin;
  /* If unpack to non-null device buffer, build the endRemoteComm dependence; unused cudaError_t variable removed */
  if (buflen) {
    /* Record completion of communication and make the compute stream wait on it before unpacking */
    PetscCallCUDA(cudaEventRecord(link->endRemoteComm, link->remoteCommStream));
    PetscCallCUDA(cudaStreamWaitEvent(link->stream, link->endRemoteComm, 0));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
243
244 /* Send/Put signals to remote ranks
245
246 Input parameters:
247 + n - Number of remote ranks
248 . sig - Signal address in symmetric heap
249 . sigdisp - To i-th rank, use its signal at offset sigdisp[i]
250 . ranks - remote ranks
251 - newval - Set signals to this value
252 */
__global__ static void NvshmemSendSignals(PetscInt n, uint64_t *sig, PetscInt *sigdisp, PetscMPIInt *ranks, uint64_t newval)
{
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  if (tid >= n) return; /* threads beyond the number of remote ranks have nothing to do */
  /* Each in-range thread writes <newval> into one remote rank's signal at its displacement */
  nvshmemx_uint64_signal(sig + sigdisp[tid], newval, ranks[tid]);
}
260
261 /* Wait until local signals equal to the expected value and then set them to a new value
262
263 Input parameters:
264 + n - Number of signals
265 . sig - Local signal address
266 . expval - expected value
267 - newval - Set signals to this new value
268 */
__global__ static void NvshmemWaitSignals(PetscInt n, uint64_t *sig, uint64_t expval, uint64_t newval)
{
  /* Callers in this file launch this kernel with a single thread (<<<1,1>>>) */
#if 0
  /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if (i < n) {
    nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
    sig[i] = newval;
  }
#else
  /* Block until all n local signals equal <expval>, then reset every one of them to <newval> */
  nvshmem_uint64_wait_until_all(sig, n, NULL /*no mask*/, NVSHMEM_CMP_EQ, expval);
  for (int i = 0; i < n; i++) sig[i] = newval;
#endif
}
283
284 /* ===========================================================================================================
285
286 A set of routines to support receiver initiated communication using the get method
287
288 The getting protocol is:
289
290 Sender has a send buf (sbuf) and a signal variable (ssig); Receiver has a recv buf (rbuf) and a signal variable (rsig);
291 All signal variables have an initial value 0.
292
293 Sender: | Receiver:
294 1. Wait ssig be 0, then set it to 1
295 2. Pack data into stand alone sbuf |
296 3. Put 1 to receiver's rsig | 1. Wait rsig to be 1, then set it 0
297 | 2. Get data from remote sbuf to local rbuf
298 | 3. Put 1 to sender's ssig
299 | 4. Unpack data from local rbuf
300 ===========================================================================================================*/
301 /* PrePack operation -- since sender will overwrite the send buffer which the receiver might be getting data from.
302 Sender waits for signals (from receivers) indicating receivers have finished getting data
303 */
PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)304 static PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
305 {
306 PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
307 uint64_t *sig;
308 PetscInt n;
309
310 PetscFunctionBegin;
311 if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
312 sig = link->rootSendSig; /* leaf ranks set my rootSendsig */
313 n = bas->nRemoteLeafRanks;
314 } else { /* LEAF2ROOT */
315 sig = link->leafSendSig;
316 n = sf->nRemoteRootRanks;
317 }
318
319 if (n) {
320 NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(n, sig, 0, 1); /* wait the signals to be 0, then set them to 1 */
321 PetscCallCUDA(cudaGetLastError());
322 }
323 PetscFunctionReturn(PETSC_SUCCESS);
324 }
325
326 /* n thread blocks. Each takes in charge one remote rank */
__global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks, PetscMPIInt *srcranks, const char *src, PetscInt *srcdisp, char *dst, PetscInt *dstdisp, PetscInt unitbytes)
{
  const int         bid = blockIdx.x; /* launched with one (single-thread) block per source rank */
  const PetscMPIInt pe  = srcranks[bid];

  /* Locally accessible PEs are served on the host via nvshmemx_getmem_nbi_on_stream(); skip them here */
  if (nvshmem_ptr(src, pe)) return;

  {
    const PetscInt nbytes = (dstdisp[bid + 1] - dstdisp[bid]) * unitbytes; /* size of this rank's chunk in bytes */
    nvshmem_getmem_nbi(dst + (dstdisp[bid] - dstdisp[0]) * unitbytes, src + srcdisp[bid] * unitbytes, nbytes, pe);
  }
}
337
338 /* Start communication -- Get data in the given direction */
static PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

  PetscInt nsrcranks, ndstranks, nLocallyAccessible = 0;

  char        *src, *dst;
  PetscInt    *srcdisp_h, *dstdisp_h; /* host displacement tables */
  PetscInt    *srcdisp_d, *dstdisp_d; /* their device mirrors, for use in kernels */
  PetscMPIInt *srcranks_h;
  PetscMPIInt *srcranks_d, *dstranks_d;
  uint64_t    *dstsig;
  PetscInt    *dstsigdisp_d;

  PetscFunctionBegin;
  PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
  if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
    nsrcranks = sf->nRemoteRootRanks;
    src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */

    srcdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
    srcdisp_d  = sf->rootbufdisp_d;
    srcranks_h = sf->ranks + sf->ndranks; /* my (remote) root ranks */
    srcranks_d = sf->ranks_d;

    ndstranks = bas->nRemoteLeafRanks;
    dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */

    dstdisp_h  = sf->roffset + sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_d  = sf->roffset_d;
    dstranks_d = bas->iranks_d; /* my (remote) leaf ranks */

    dstsig       = link->leafRecvSig;
    dstsigdisp_d = bas->leafsigdisp_d;
  } else { /* src is leaf, dst is root; we will move data from src to dst */
    nsrcranks = bas->nRemoteLeafRanks;
    src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */

    srcdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] (comment fixed: said "root rank"/"rootbufdisp") */
    srcdisp_d  = bas->leafbufdisp_d;
    srcranks_h = bas->iranks + bas->ndiranks; /* my (remote) leaf ranks (comment fixed: said "root ranks") */
    srcranks_d = bas->iranks_d;

    ndstranks = sf->nRemoteRootRanks;
    dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */

    dstdisp_h  = bas->ioffset + bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_d  = bas->ioffset_d;
    dstranks_d = sf->ranks_d; /* my (remote) root ranks */

    dstsig       = link->rootRecvSig;
    dstsigdisp_d = sf->rootsigdisp_d;
  }

  /* After Pack operation -- src tells dst ranks that they are allowed to get data */
  if (ndstranks) {
    NvshmemSendSignals<<<(ndstranks + 255) / 256, 256, 0, link->remoteCommStream>>>(ndstranks, dstsig, dstsigdisp_d, dstranks_d, 1); /* set signals to 1 */
    PetscCallCUDA(cudaGetLastError());
  }

  /* dst waits for signals (permissions) from src ranks to start getting data */
  if (nsrcranks) {
    NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, dstsig, 1, 0); /* wait the signals to be 1, then set them to 0 */
    PetscCallCUDA(cudaGetLastError());
  }

  /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

  /* Count number of locally accessible src ranks, which should be a small number */
  for (int i = 0; i < nsrcranks; i++) {
    if (nvshmem_ptr(src, srcranks_h[i])) nLocallyAccessible++;
  }

  /* Get data from remotely accessible PEs */
  if (nLocallyAccessible < nsrcranks) {
    GetDataFromRemotelyAccessible<<<nsrcranks, 1, 0, link->remoteCommStream>>>(nsrcranks, srcranks_d, src, srcdisp_d, dst, dstdisp_d, link->unitbytes);
    PetscCallCUDA(cudaGetLastError());
  }

  /* Get data from locally accessible PEs (host API uses CUDA copy engines) */
  if (nLocallyAccessible) {
    for (int i = 0; i < nsrcranks; i++) {
      int pe = srcranks_h[i];
      if (nvshmem_ptr(src, pe)) {
        size_t nelems = (dstdisp_h[i + 1] - dstdisp_h[i]) * link->unitbytes;
        nvshmemx_getmem_nbi_on_stream(dst + (dstdisp_h[i] - dstdisp_h[0]) * link->unitbytes, src + srcdisp_h[i] * link->unitbytes, nelems, pe, link->remoteCommStream);
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
431
432 /* Finish the communication (can be done before Unpack)
433 Receiver tells its senders that they are allowed to reuse their send buffer (since receiver has got data from their send buffer)
434 */
static PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  uint64_t      *srcsig;
  PetscInt       nsrcranks, *srcsigdisp;
  PetscMPIInt   *srcranks;

  PetscFunctionBegin;
  /* Note: the unused cudaError_t variable was removed */
  if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
    nsrcranks  = sf->nRemoteRootRanks;
    srcsig     = link->rootSendSig;  /* I want to set their root signal */
    srcsigdisp = sf->rootsigdisp_d;  /* offset of each root signal */
    srcranks   = sf->ranks_d;        /* ranks of the n root ranks */
  } else { /* LEAF2ROOT, root ranks are getting data */
    nsrcranks  = bas->nRemoteLeafRanks;
    srcsig     = link->leafSendSig;
    srcsigdisp = bas->leafsigdisp_d;
    srcranks   = bas->iranks_d;
  }

  if (nsrcranks) {
    nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
    PetscCallCUDA(cudaGetLastError());
    /* Tell the senders their send buffers are free to be overwritten again */
    NvshmemSendSignals<<<(nsrcranks + 511) / 512, 512, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp, srcranks, 0); /* set signals to 0 */
    PetscCallCUDA(cudaGetLastError());
  }
  PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
  PetscFunctionReturn(PETSC_SUCCESS);
}
465
466 /* ===========================================================================================================
467
468 A set of routines to support sender initiated communication using the put-based method (the default)
469
470 The putting protocol is:
471
472 Sender has a send buf (sbuf) and a send signal var (ssig); Receiver has a stand-alone recv buf (rbuf)
473 and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
474 is in nvshmem space.
475
476 Sender: | Receiver:
477 |
478 1. Pack data into sbuf |
479 2. Wait ssig be 0, then set it to 1 |
480 3. Put data to remote stand-alone rbuf |
481 4. Fence // make sure 5 happens after 3 |
482 5. Put 1 to receiver's rsig | 1. Wait rsig to be 1, then set it 0
483 | 2. Unpack data from local rbuf
484 | 3. Put 0 to sender's ssig
485 ===========================================================================================================*/
486
487 /* n thread blocks. Each takes in charge one remote rank */
__global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, char *dst, PetscInt *dstdisp, const char *src, PetscInt *srcdisp, uint64_t *srcsig, PetscInt unitbytes)
{
  int         bid = blockIdx.x; /* launched with <<<ndstranks,1>>>: one single-thread block per destination rank */
  PetscMPIInt pe  = dstranks[bid];

  if (!nvshmem_ptr(dst, pe)) { /* only remotely accessible PEs; locally accessible ones are handled on the host stream */
    PetscInt nelems = (srcdisp[bid + 1] - srcdisp[bid]) * unitbytes; /* byte count of this rank's chunk of the send buffer */
    nvshmem_uint64_wait_until(srcsig + bid, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0, i.e., receiver has finished with our previous data */
    srcsig[bid] = 1;                                            /* mark the send buffer busy before initiating the put */
    nvshmem_putmem_nbi(dst + dstdisp[bid] * unitbytes, src + (srcdisp[bid] - srcdisp[0]) * unitbytes, nelems, pe);
  }
}
500
501 /* one-thread kernel, which takes in charge all locally accessible */
/* one-thread kernel, which takes in charge all locally accessible PEs */
__global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *srcsig, const char *dst)
{
  for (int i = 0; i < ndstranks; i++) {
    const int pe = dstranks[i];

    /* Remotely accessible PEs are served by WaitAndPutDataToRemotelyAccessible(); skip them */
    if (!nvshmem_ptr(dst, pe)) continue;

    nvshmem_uint64_wait_until(srcsig + i, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
    srcsig[i] = 1; /* mark the associated send buffer busy */
  }
}
512
513 /* Put data in the given direction */
static PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       ndstranks, nLocallyAccessible = 0;
  char          *src, *dst;
  PetscInt      *srcdisp_h, *dstdisp_h; /* host displacement tables */
  PetscInt      *srcdisp_d, *dstdisp_d; /* their device mirrors */
  PetscMPIInt   *dstranks_h;
  PetscMPIInt   *dstranks_d;
  uint64_t      *srcsig;

  PetscFunctionBegin;
  /* Note: the unused cudaError_t variable was removed */
  PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
  if (direction == PETSCSF_ROOT2LEAF) {                          /* put data in rootbuf to leafbuf */
    ndstranks = bas->nRemoteLeafRanks;                           /* number of (remote) leaf ranks */
    src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
    dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h = bas->ioffset + bas->ndiranks; /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
    srcdisp_d = bas->ioffset_d;
    srcsig    = link->rootSendSig;

    dstdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
    dstdisp_d  = bas->leafbufdisp_d;
    dstranks_h = bas->iranks + bas->ndiranks; /* remote leaf ranks */
    dstranks_d = bas->iranks_d;
  } else { /* put data in leafbuf to rootbuf */
    ndstranks = sf->nRemoteRootRanks;
    src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h = sf->roffset + sf->ndranks; /* offsets of leafbuf */
    srcdisp_d = sf->roffset_d;
    srcsig    = link->leafSendSig;

    dstdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
    dstdisp_d  = sf->rootbufdisp_d;
    dstranks_h = sf->ranks + sf->ndranks; /* remote root ranks */
    dstranks_d = sf->ranks_d;
  }

  /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */

  /* Count number of locally accessible neighbors, which should be a small number */
  for (int i = 0; i < ndstranks; i++) {
    if (nvshmem_ptr(dst, dstranks_h[i])) nLocallyAccessible++;
  }

  /* For remotely accessible PEs, send data to them in one kernel call */
  if (nLocallyAccessible < ndstranks) {
    WaitAndPutDataToRemotelyAccessible<<<ndstranks, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, dst, dstdisp_d, src, srcdisp_d, srcsig, link->unitbytes);
    PetscCallCUDA(cudaGetLastError());
  }

  /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
  if (nLocallyAccessible) {
    WaitSignalsFromLocallyAccessible<<<1, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, srcsig, dst);
    for (int i = 0; i < ndstranks; i++) {
      int pe = dstranks_h[i];
      if (nvshmem_ptr(dst, pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
        size_t nelems = (srcdisp_h[i + 1] - srcdisp_h[i]) * link->unitbytes;
        /* Initiate the nonblocking communication */
        nvshmemx_putmem_nbi_on_stream(dst + dstdisp_h[i] * link->unitbytes, src + (srcdisp_h[i] - srcdisp_h[0]) * link->unitbytes, nelems, pe, link->remoteCommStream);
      }
    }
  }

  if (nLocallyAccessible) nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
  PetscFunctionReturn(PETSC_SUCCESS);
}
585
586 /* A one-thread kernel. The thread takes in charge all remote PEs */
__global__ static void PutDataEnd(PetscInt nsrcranks, PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *dstsig, PetscInt *dstsigdisp)
{
  /* TODO: Shall we finish the non-blocking remote puts first? */

  /* 1. Send a signal to each dst rank */

  /* According to Akhil@NVIDIA, IB is orderred, so no fence is needed for remote PEs.
     For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
  */
  for (int i = 0; i < ndstranks; i++) nvshmemx_uint64_signal(dstsig + dstsigdisp[i], 1, dstranks[i]); /* set sig to 1 */

  /* 2. Wait for signals from src ranks (if any) */
  if (nsrcranks) {
    /* Block until all nsrcranks local signals become 1 (i.e., all senders have delivered their data), then reset them to 0 */
    nvshmem_uint64_wait_until_all(dstsig, nsrcranks, NULL /*no mask*/, NVSHMEM_CMP_EQ, 1);
    for (int i = 0; i < nsrcranks; i++) dstsig[i] = 0;
  }
}
604
605 /* Finish the communication -- A receiver waits until it can access its receive buffer */
static PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscMPIInt   *dstranks;
  uint64_t      *dstsig;
  PetscInt       nsrcranks, ndstranks, *dstsigdisp;

  PetscFunctionBegin;
  /* Note: the unused cudaError_t variable was removed */
  if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
    nsrcranks = sf->nRemoteRootRanks;

    ndstranks  = bas->nRemoteLeafRanks;
    dstranks   = bas->iranks_d;       /* leaf ranks */
    dstsig     = link->leafRecvSig;   /* I will set my leaf ranks's RecvSig */
    dstsigdisp = bas->leafsigdisp_d;  /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
  } else { /* LEAF2ROOT */
    nsrcranks = bas->nRemoteLeafRanks;

    ndstranks  = sf->nRemoteRootRanks;
    dstranks   = sf->ranks_d;
    dstsig     = link->rootRecvSig;
    dstsigdisp = sf->rootsigdisp_d;
  }

  if (nsrcranks || ndstranks) {
    PutDataEnd<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, ndstranks, dstranks, dstsig, dstsigdisp);
    PetscCallCUDA(cudaGetLastError());
  }
  PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
  PetscFunctionReturn(PETSC_SUCCESS);
}
638
639 /* PostUnpack operation -- A receiver tells its senders that they are allowed to put data to here (it implies recv buf is free to take new data) */
/* PostUnpack operation -- a receiver notifies its senders that its receive buffer is free again,
   so they may start putting new data to it.

   The notification is done by resetting (to 0) each sender's send-signal on the sender's PE,
   via the NvshmemSendSignals kernel on link->remoteCommStream. Which side the "senders" are
   depends on the direction: for ROOT2LEAF the senders are my remote root ranks, for LEAF2ROOT
   my remote leaf ranks.
*/
static PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic  *bas       = (PetscSF_Basic *)sf->data;
  const PetscBool root2leaf = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  const PetscInt  nsenders  = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;

  PetscFunctionBegin;
  if (nsenders) {
    uint64_t    *sendsig = root2leaf ? link->rootSendSig : link->leafSendSig;  /* the senders' send signals I will reset */
    PetscInt    *sigdisp = root2leaf ? sf->rootsigdisp_d : bas->leafsigdisp_d; /* offset of each sender's signal on its PE */
    PetscMPIInt *senders = root2leaf ? sf->ranks_d : bas->iranks_d;            /* PE ids of the senders (device array) */
    const int    nt      = 256;                                                /* threads per block; one thread per sender */

    NvshmemSendSignals<<<(nsenders + nt - 1) / nt, nt, 0, link->remoteCommStream>>>(nsenders, sendsig, sigdisp, senders, 0); /* Set remote signals to 0 */
    PetscCallCUDA(cudaGetLastError());
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
666
667 /* Destructor when the link uses nvshmem for communication */
/* Destructor when the link uses nvshmem for communication

   Destroys the CUDA events and the dedicated communication stream created in
   PetscSFLinkCreate_NVSHMEM, then releases the nvshmem-allocated remote buffers and
   signal arrays. Host-side remote buffers are never allocated for nvshmem links, so
   there is nothing to free on the host.
*/
static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf, PetscSFLink link)
{
  PetscFunctionBegin;
  PetscCallCUDA(cudaEventDestroy(link->dataReady));
  PetscCallCUDA(cudaEventDestroy(link->endRemoteComm));
  PetscCallCUDA(cudaStreamDestroy(link->remoteCommStream));

  /* nvshmem does not need buffers on host, which should be NULL */
  PetscCall(PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
  PetscCall(PetscNvshmemFree(link->leafSendSig));
  PetscCall(PetscNvshmemFree(link->leafRecvSig));
  PetscCall(PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
  PetscCall(PetscNvshmemFree(link->rootSendSig));
  PetscCall(PetscNvshmemFree(link->rootRecvSig));
  PetscFunctionReturn(PETSC_SUCCESS);
}
686
/* Create (or fetch from cache) a communication link that uses NVSHMEM for the remote part

   First decides whether root/leafdata can be used in place for the remote communication
   (rootdirect/leafdirect) given the operation, the put/get protocol and whether the user
   buffers are nvshmem-allocated. Then reuses a cached link with a matching MPI datatype
   from bas->avail, or builds a new one: pack routines, a dedicated high-priority
   communication stream, sync events, zero-initialized signal arrays and, when direct use
   is not possible, nvshmem-allocated send/recv buffers.

   Input Parameters:
+  sf        - the star forest
.  unit      - MPI datatype of one data item
.  rootmtype - memory type of rootdata
.  rootdata  - user root buffer
.  leafmtype - memory type of leafdata
.  leafdata  - user leaf buffer
.  op        - the reduction op of the operation
-  sfop      - PETSCSF_BCAST, PETSCSF_REDUCE or PETSCSF_FETCH

   Output Parameter:
.  mylink    - the link, pushed on bas->inuse; root/leafdata serve as lookup keys in PetscSFXxxEnd()
*/
PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf, MPI_Datatype unit, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, MPI_Op op, PetscSFOperation sfop, PetscSFLink *mylink)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscSFLink   *p, link;
  PetscBool      match, rootdirect[2], leafdirect[2];
  int            greatestPriority;

  PetscFunctionBegin;
  /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
     We only care root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers in local communication with NVSHMEM.
  */
  if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
    }
  } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
    }
  } else { /* PETSCSF_FETCH */
    rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always need a separate rootbuf */
    leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share mpi requests */
  }

  /* Look for free nvshmem links in cache */
  for (p = &bas->avail; (link = *p); p = &link->next) {
    if (link->use_nvshmem) {
      PetscCall(MPIPetsc_Type_compare(unit, link->unit, &match));
      if (match) {
        *p = link->next; /* Remove from available list */
        goto found;
      }
    }
  }

  /* No cached link matched; build a fresh one */
  PetscCall(PetscNew(&link));
  PetscCall(PetscSFLinkSetUp_Host(sf, link, unit));                                        /* Compute link->unitbytes, dup link->unit etc. */
  if (sf->backend == PETSCSF_BACKEND_CUDA) PetscCall(PetscSFLinkSetUp_CUDA(sf, link, unit)); /* Setup pack routines, streams etc */
#if defined(PETSC_HAVE_KOKKOS)
  else if (sf->backend == PETSCSF_BACKEND_KOKKOS) PetscCall(PetscSFLinkSetUp_Kokkos(sf, link, unit));
#endif

  link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE; /* For the local part we directly use root/leafdata */
  link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;

  /* Init signals to zero */
  if (!link->rootSendSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootSendSig));
  if (!link->rootRecvSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootRecvSig));
  if (!link->leafSendSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafSendSig));
  if (!link->leafRecvSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafRecvSig));

  link->use_nvshmem = PETSC_TRUE;
  link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
  link->leafmtype   = PETSC_MEMTYPE_DEVICE;
  /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
  link->Destroy = PetscSFLinkDestroy_NVSHMEM;
  if (sf->use_nvshmem_get) { /* get-based protocol */
    link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
    link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
    link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
  } else { /* put-based protocol */
    link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
    link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
    link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
  }

  /* A dedicated non-blocking stream at the highest priority so remote communication can
     overlap with (and not be starved by) computation on other streams */
  PetscCallCUDA(cudaDeviceGetStreamPriorityRange(NULL, &greatestPriority));
  PetscCallCUDA(cudaStreamCreateWithPriority(&link->remoteCommStream, cudaStreamNonBlocking, greatestPriority));

  PetscCallCUDA(cudaEventCreateWithFlags(&link->dataReady, cudaEventDisableTiming));
  PetscCallCUDA(cudaEventCreateWithFlags(&link->endRemoteComm, cudaEventDisableTiming));

found:
  /* Point root/leaf buffers either directly at user data or at (lazily allocated) nvshmem buffers */
  if (rootdirect[PETSCSF_REMOTE]) {
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)rootdata + bas->rootstart[PETSCSF_REMOTE] * link->unitbytes;
  } else {
    if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(bas->rootbuflen_rmax * link->unitbytes, (void **)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  if (leafdirect[PETSCSF_REMOTE]) {
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)leafdata + sf->leafstart[PETSCSF_REMOTE] * link->unitbytes;
  } else {
    if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(sf->leafbuflen_rmax * link->unitbytes, (void **)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
  link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
  link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
  link->leafdata                   = leafdata;
  link->next                       = bas->inuse;
  bas->inuse                       = link;
  *mylink                          = link;
  PetscFunctionReturn(PETSC_SUCCESS);
}
791
792 #if defined(PETSC_USE_REAL_SINGLE)
/* Sum-reduce count floats from src into dst across all PEs of NVSHMEM_TEAM_WORLD,
   enqueued on PETSc's default CUDA stream (single-precision build). */
PetscErrorCode PetscNvshmemSum(PetscInt count, float *dst, const float *src)
{
  PetscMPIInt n; /* nvshmem takes an int count; assume nvshmem's int is MPI's int */

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &n)); /* guard against PetscInt overflowing int */
  nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, n, PetscDefaultCudaStream);
  PetscFunctionReturn(PETSC_SUCCESS);
}
802
/* Max-reduce count floats from src into dst across all PEs of NVSHMEM_TEAM_WORLD,
   enqueued on PETSc's default CUDA stream (single-precision build). */
PetscErrorCode PetscNvshmemMax(PetscInt count, float *dst, const float *src)
{
  PetscMPIInt n; /* nvshmem takes an int count */

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &n)); /* guard against PetscInt overflowing int */
  nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, n, PetscDefaultCudaStream);
  PetscFunctionReturn(PETSC_SUCCESS);
}
812 #elif defined(PETSC_USE_REAL_DOUBLE)
/* Sum-reduce count doubles from src into dst across all PEs of NVSHMEM_TEAM_WORLD,
   enqueued on PETSc's default CUDA stream (double-precision build). */
PetscErrorCode PetscNvshmemSum(PetscInt count, double *dst, const double *src)
{
  PetscMPIInt n; /* nvshmem takes an int count */

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &n)); /* guard against PetscInt overflowing int */
  nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, n, PetscDefaultCudaStream);
  PetscFunctionReturn(PETSC_SUCCESS);
}
822
/* Max-reduce count doubles from src into dst across all PEs of NVSHMEM_TEAM_WORLD,
   enqueued on PETSc's default CUDA stream (double-precision build). */
PetscErrorCode PetscNvshmemMax(PetscInt count, double *dst, const double *src)
{
  PetscMPIInt n; /* nvshmem takes an int count */

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &n)); /* guard against PetscInt overflowing int */
  nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, n, PetscDefaultCudaStream);
  PetscFunctionReturn(PETSC_SUCCESS);
}
832 #endif
833