xref: /petsc/src/vec/is/sf/impls/basic/nvshmem/sfnvshmem.cu (revision 834855d6effb0d027771461c8e947ee1ce5a1e17)
171438e86SJunchao Zhang #include <petsc/private/cudavecimpl.h>
271438e86SJunchao Zhang #include <../src/vec/is/sf/impls/basic/sfpack.h>
371438e86SJunchao Zhang #include <mpi.h>
471438e86SJunchao Zhang #include <nvshmem.h>
571438e86SJunchao Zhang #include <nvshmemx.h>
671438e86SJunchao Zhang 
/*
  PetscNvshmemInitializeCheck - Initialize NVSHMEM the first time it is needed

  Does nothing when PetscNvshmemInitialized is already set. Otherwise it initializes the PETSc CUDA
  device, starts NVSHMEM on top of PETSC_COMM_WORLD, and records both that NVSHMEM is initialized and
  that PETSc was the one that began it.
*/
PetscErrorCode PetscNvshmemInitializeCheck(void)
{
  nvshmemx_init_attr_t attr;

  PetscFunctionBegin;
  /* NVSHMEM does not provide a routine to check whether it is initialized, hence the PETSc-side flag */
  if (PetscNvshmemInitialized) PetscFunctionReturn(PETSC_SUCCESS);

  attr.mpi_comm = &PETSC_COMM_WORLD;
  PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
  PetscCall(nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr));
  PetscNvshmemInitialized = PETSC_TRUE;
  PetscBeganNvshmem       = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
2071438e86SJunchao Zhang 
/*
  PetscNvshmemMalloc - Allocate memory with nvshmem_malloc(), initializing NVSHMEM first if needed

  Input parameter:
  . size - number of bytes requested

  Output parameter:
  . ptr  - receives the address returned by nvshmem_malloc()

  Errors with PETSC_ERR_MEM when nvshmem_malloc() returns NULL.
*/
PetscErrorCode PetscNvshmemMalloc(size_t size, void **ptr)
{
  PetscFunctionBegin;
  PetscCall(PetscNvshmemInitializeCheck());
  *ptr = nvshmem_malloc(size);
  /* A NULL result is an allocation failure, so report it with the memory error class rather than as a bad argument */
  PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_MEM, "nvshmem_malloc() failed to allocate %zu bytes", size);
  PetscFunctionReturn(PETSC_SUCCESS);
}
2971438e86SJunchao Zhang 
/*
  PetscNvshmemCalloc - Allocate memory with nvshmem_calloc(), initializing NVSHMEM first if needed

  Input parameter:
  . size - number of bytes requested (passed to nvshmem_calloc() as <size> elements of 1 byte)

  Output parameter:
  . ptr  - receives the address returned by nvshmem_calloc()

  Errors with PETSC_ERR_MEM when nvshmem_calloc() returns NULL.
*/
PetscErrorCode PetscNvshmemCalloc(size_t size, void **ptr)
{
  PetscFunctionBegin;
  PetscCall(PetscNvshmemInitializeCheck());
  *ptr = nvshmem_calloc(size, 1);
  /* A NULL result is an allocation failure, so report it with the memory error class rather than as a bad argument */
  PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_MEM, "nvshmem_calloc() failed to allocate %zu bytes", size);
  PetscFunctionReturn(PETSC_SUCCESS);
}
3871438e86SJunchao Zhang 
/*
  PetscNvshmemFree_Private - Hand a pointer back to NVSHMEM via nvshmem_free()

  Input parameter:
  . ptr - the address to release; it is passed to nvshmem_free() unchanged

  Always returns PETSC_SUCCESS.
*/
PetscErrorCode PetscNvshmemFree_Private(void *ptr)
{
  PetscFunctionBegin;
  nvshmem_free(ptr); /* nvshmem_free() is called as a void routine, so there is no status to check */
  PetscFunctionReturn(PETSC_SUCCESS);
}
4571438e86SJunchao Zhang 
/*
  PetscNvshmemFinalize - Shut NVSHMEM down by calling nvshmem_finalize()

  Always returns PETSC_SUCCESS. This routine does not touch the PETSc-side initialization flags.
*/
PetscErrorCode PetscNvshmemFinalize(void)
{
  PetscFunctionBegin;
  nvshmem_finalize(); /* called as a void routine, so there is no status to check */
  PetscFunctionReturn(PETSC_SUCCESS);
}
5271438e86SJunchao Zhang 
/*
  PetscSFReset_Basic_NVSHMEM - Free the NVSHMEM related fields of an SF

  Releases the host displacement arrays (each sig/buf pair is freed with a single PetscFree2()) and
  their device mirrors, for both the leaf-side fields kept in the PetscSF_Basic data and the
  root-side fields kept in the SF itself.
*/
PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic *basic = (PetscSF_Basic *)sf->data;

  PetscFunctionBegin;
  /* Host arrays */
  PetscCall(PetscFree2(basic->leafsigdisp, basic->leafbufdisp));
  PetscCall(PetscFree2(sf->rootsigdisp, sf->rootbufdisp));

  /* Device arrays describing my remote leaf ranks */
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, basic->leafbufdisp_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, basic->leafsigdisp_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, basic->iranks_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, basic->ioffset_d));

  /* Device arrays describing my remote root ranks */
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootbufdisp_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootsigdisp_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->ranks_d));
  PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->roffset_d));
  PetscFunctionReturn(PETSC_SUCCESS);
}
7271438e86SJunchao Zhang 
/*
  PetscSFSetUp_Basic_NVSHMEM - Set up NVSHMEM related fields for an SF of type SFBASIC
  (only after PetscSFSetup_Basic() already set up dependent fields)

  What it does, in order:
   1. Records the number of remote root/leaf ranks and the comm-wide maxima of those counts and of the
      remote leaf/root buffer lengths.
   2. Roots tell each of their remote leaf ranks which signal slot (the index i) and which buffer
      offset (relative to the first remote rank) belong to it; leaves store these as rootsigdisp[]/rootbufdisp[].
   3. The same exchange in the opposite direction fills leafsigdisp[]/leafbufdisp[].
   4. Mirrors the displacement, rank and offset arrays of the remote ranks on the device.

  Blocking MPI_Send() is used because the value being sent lives in a variable that changes every iteration.
*/
static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic     *bas       = (PetscSF_Basic *)sf->data;
  const PetscInt     nroot     = sf->nranks - sf->ndranks;    /* number of remote root ranks */
  const PetscInt     nleaf     = bas->niranks - bas->ndiranks; /* number of remote leaf ranks */
  const PetscMPIInt *rootranks = sf->ranks + sf->ndranks;     /* remote root ranks */
  const PetscMPIInt *leafranks = bas->iranks + bas->ndiranks; /* remote leaf ranks */
  const PetscInt    *roffset   = sf->roffset + sf->ndranks;   /* offsets for remote root ranks; roffset[0] is not necessarily 0 */
  const PetscInt    *ioffset   = bas->ioffset + bas->ndiranks; /* offsets for remote leaf ranks; ioffset[0] is not necessarily 0 */
  PetscInt           k, disp, mine[4], maxes[4];
  PetscMPIInt        tag;
  MPI_Comm           comm;
  MPI_Request       *rootreqs, *leafreqs;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
  PetscCall(PetscObjectGetNewTag((PetscObject)sf, &tag));

  sf->nRemoteRootRanks  = nroot;
  bas->nRemoteLeafRanks = nleaf;

  /* rootreqs: one request per remote leaf rank; leafreqs: one request per remote root rank */
  PetscCall(PetscMalloc2(nleaf, &rootreqs, nroot, &leafreqs));

  /* Comm-wide maxima of the rank counts and remote buffer lengths */
  mine[0] = nroot;
  mine[1] = sf->leafbuflen[PETSCSF_REMOTE];
  mine[2] = nleaf;
  mine[3] = bas->rootbuflen[PETSCSF_REMOTE];
  PetscCallMPI(MPIU_Allreduce(mine, maxes, 4, MPIU_INT, MPI_MAX, comm));
  sf->nRemoteRootRanksMax  = maxes[0];
  sf->leafbuflen_rmax      = maxes[1];
  bas->nRemoteLeafRanksMax = maxes[2];
  bas->rootbuflen_rmax     = maxes[3];

  /* Total four rounds of MPI communications to set up the nvshmem fields */

  /* Rounds 1 and 2, root ranks to leaf ranks: rootsigdisp[] and rootbufdisp[] */
  PetscCall(PetscMalloc2(nroot, &sf->rootsigdisp, nroot, &sf->rootbufdisp));

  for (k = 0; k < nroot; k++) PetscCallMPI(MPIU_Irecv(&sf->rootsigdisp[k], 1, MPIU_INT, rootranks[k], tag, comm, &leafreqs[k])); /* leaves recv */
  for (k = 0; k < nleaf; k++) PetscCallMPI(MPI_Send(&k, 1, MPIU_INT, leafranks[k], tag, comm));                                  /* roots send the slot index */
  PetscCallMPI(MPI_Waitall(nroot, leafreqs, MPI_STATUSES_IGNORE));

  for (k = 0; k < nroot; k++) PetscCallMPI(MPIU_Irecv(&sf->rootbufdisp[k], 1, MPIU_INT, rootranks[k], tag, comm, &leafreqs[k])); /* leaves recv */
  for (k = 0; k < nleaf; k++) {
    disp = ioffset[k] - ioffset[0];
    PetscCallMPI(MPI_Send(&disp, 1, MPIU_INT, leafranks[k], tag, comm)); /* roots send the buffer offset */
  }
  PetscCallMPI(MPI_Waitall(nroot, leafreqs, MPI_STATUSES_IGNORE));

  /* Device mirrors of the root-side arrays */
  PetscCallCUDA(cudaMalloc((void **)&sf->rootbufdisp_d, nroot * sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void **)&sf->rootsigdisp_d, nroot * sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void **)&sf->ranks_d, nroot * sizeof(PetscMPIInt)));
  PetscCallCUDA(cudaMalloc((void **)&sf->roffset_d, (nroot + 1) * sizeof(PetscInt)));

  PetscCallCUDA(cudaMemcpyAsync(sf->rootbufdisp_d, sf->rootbufdisp, nroot * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->rootsigdisp_d, sf->rootsigdisp, nroot * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->ranks_d, rootranks, nroot * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(sf->roffset_d, roffset, (nroot + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));

  /* Rounds 3 and 4, leaf ranks to root ranks: leafsigdisp[] and leafbufdisp[] */
  PetscCall(PetscMalloc2(nleaf, &bas->leafsigdisp, nleaf, &bas->leafbufdisp));

  for (k = 0; k < nleaf; k++) PetscCallMPI(MPIU_Irecv(&bas->leafsigdisp[k], 1, MPIU_INT, leafranks[k], tag, comm, &rootreqs[k])); /* roots recv */
  for (k = 0; k < nroot; k++) PetscCallMPI(MPI_Send(&k, 1, MPIU_INT, rootranks[k], tag, comm));                                   /* leaves send the slot index */
  PetscCallMPI(MPI_Waitall(nleaf, rootreqs, MPI_STATUSES_IGNORE));

  for (k = 0; k < nleaf; k++) PetscCallMPI(MPIU_Irecv(&bas->leafbufdisp[k], 1, MPIU_INT, leafranks[k], tag, comm, &rootreqs[k])); /* roots recv */
  for (k = 0; k < nroot; k++) {
    disp = roffset[k] - roffset[0];
    PetscCallMPI(MPI_Send(&disp, 1, MPIU_INT, rootranks[k], tag, comm)); /* leaves send the buffer offset */
  }
  PetscCallMPI(MPI_Waitall(nleaf, rootreqs, MPI_STATUSES_IGNORE));

  /* Device mirrors of the leaf-side arrays */
  PetscCallCUDA(cudaMalloc((void **)&bas->leafbufdisp_d, nleaf * sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void **)&bas->leafsigdisp_d, nleaf * sizeof(PetscInt)));
  PetscCallCUDA(cudaMalloc((void **)&bas->iranks_d, nleaf * sizeof(PetscMPIInt)));
  PetscCallCUDA(cudaMalloc((void **)&bas->ioffset_d, (nleaf + 1) * sizeof(PetscInt)));

  PetscCallCUDA(cudaMemcpyAsync(bas->leafbufdisp_d, bas->leafbufdisp, nleaf * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->leafsigdisp_d, bas->leafsigdisp, nleaf * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->iranks_d, leafranks, nleaf * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
  PetscCallCUDA(cudaMemcpyAsync(bas->ioffset_d, ioffset, (nleaf + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));

  PetscCall(PetscFree2(rootreqs, leafreqs));
  PetscFunctionReturn(PETSC_SUCCESS);
}
15871438e86SJunchao Zhang 
/*
  PetscSFLinkNvshmemCheck - Decide whether a communication on this SF with the given root/leaf data should use NVSHMEM

  Input parameters:
  + sf        - the star forest
  . rootmtype - memory type of rootdata
  . rootdata  - root data pointer (may be NULL)
  . leafmtype - memory type of leafdata
  - leafdata  - leaf data pointer (may be NULL)

  Output parameter:
  . use_nvshmem - PETSC_TRUE if NVSHMEM is to be used, PETSC_FALSE otherwise

  Notes:
  The first call on an SF that has sf->use_nvshmem set runs a one-time eligibility check and clears
  sf->use_nvshmem if the SF is not eligible. That check contains an MPIU_Allreduce() over the SF's
  communicator, so it is collective. When NVSHMEM is selected, NVSHMEM itself and the SF's NVSHMEM
  fields are set up on demand.
*/
PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, PetscBool *use_nvshmem)
{
  MPI_Comm    comm;
  PetscBool   isBasic;
  PetscMPIInt result = MPI_UNEQUAL;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
  /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
     Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
     The checked flag must only be set after the check has run; setting it beforehand would make the check unreachable.
  */
  if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
    /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
    PetscCall(PetscObjectTypeCompare((PetscObject)sf, PETSCSFBASIC, &isBasic));
    if (isBasic) PetscCallMPI(MPI_Comm_compare(PETSC_COMM_WORLD, comm, &result));
    if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

    /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
       and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
       inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
    */
    if (sf->use_nvshmem) {
      PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
      PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &hasNullRank, 1, MPIU_INT, MPI_LOR, comm));
      if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
    }
    sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
  }

  /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
  if (sf->use_nvshmem) {
    PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
    PetscInt allCuda = oneCuda;                                                                                          /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
#if defined(PETSC_USE_DEBUG)                                                                                             /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
    PetscCallMPI(MPIU_Allreduce(&oneCuda, &allCuda, 1, MPIU_INT, MPI_LAND, comm));
    PetscCheck(allCuda == oneCuda, comm, PETSC_ERR_SUP, "root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
#endif
    if (allCuda) {
      PetscCall(PetscNvshmemInitializeCheck());
      if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
        PetscCall(PetscSFSetUp_Basic_NVSHMEM(sf));
        sf->setup_nvshmem = PETSC_TRUE;
      }
      *use_nvshmem = PETSC_TRUE;
    } else {
      *use_nvshmem = PETSC_FALSE;
    }
  } else {
    *use_nvshmem = PETSC_FALSE;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
21271438e86SJunchao Zhang 
/* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication.

   When the remote buffer selected by <direction> (root buffer for ROOT2LEAF, leaf buffer otherwise) is
   non-empty, the dataReady event is recorded on link->stream and link->remoteCommStream is made to wait on it.
*/
static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       len;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) len = bas->rootbuflen[PETSCSF_REMOTE];
  else len = sf->leafbuflen[PETSCSF_REMOTE];

  /* Nothing to order when that buffer is empty */
  if (!len) PetscFunctionReturn(PETSC_SUCCESS);

  PetscCallCUDA(cudaEventRecord(link->dataReady, link->stream));
  PetscCallCUDA(cudaStreamWaitEvent(link->remoteCommStream, link->dataReady, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}
22771438e86SJunchao Zhang 
/* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication.

   When the remote buffer selected by <direction> (leaf buffer for ROOT2LEAF, root buffer otherwise) is
   non-empty, the endRemoteComm event is recorded on link->remoteCommStream and link->stream is made to wait on it.
*/
static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       len;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) len = sf->leafbuflen[PETSCSF_REMOTE];
  else len = bas->rootbuflen[PETSCSF_REMOTE];

  /* Only if we unpack to a non-null device buffer do we build the endRemoteComm dependence */
  if (!len) PetscFunctionReturn(PETSC_SUCCESS);

  PetscCallCUDA(cudaEventRecord(link->endRemoteComm, link->remoteCommStream));
  PetscCallCUDA(cudaStreamWaitEvent(link->stream, link->endRemoteComm, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}
24371438e86SJunchao Zhang 
/* Send/Put signals to remote ranks

 Input parameters:
  + n        - Number of remote ranks
  . sig      - Signal address in symmetric heap
  . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
  . ranks    - remote ranks
  - newval   - Set signals to this value

 Launch layout: 1D grid of 1D blocks; the thread with global index i handles the i-th remote rank,
 and threads with index >= n do nothing.
*/
__global__ static void NvshmemSendSignals(PetscInt n, uint64_t *sig, PetscInt *sigdisp, PetscMPIInt *ranks, uint64_t newval)
{
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  if (tid >= n) return; /* surplus threads in the last block */

  /* Each thread puts one remote signal */
  nvshmemx_uint64_signal(sig + sigdisp[tid], newval, ranks[tid]);
}
26071438e86SJunchao Zhang 
/* Wait until local signals equal to the expected value and then set them to a new value

 Input parameters:
  + n        - Number of signals
  . sig      - Local signal address
  . expval   - expected value
  - newval   - Set signals to this new value

 Notes:
  The kernel does not use thread indices: every launched thread waits on all n signals and then resets
  all of them. The callers in this file launch it with <<<1, 1>>>.

  A one-thread-per-signal variant (nvshmem_signal_wait_until() on sig+i followed by sig[i] = newval) was
  tried; Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better.
*/
__global__ static void NvshmemWaitSignals(PetscInt n, uint64_t *sig, uint64_t expval, uint64_t newval)
{
  int k = 0;

  nvshmem_uint64_wait_until_all(sig, n, NULL /*no mask*/, NVSHMEM_CMP_EQ, expval);
  while (k < n) {
    sig[k] = newval;
    k++;
  }
}
28371438e86SJunchao Zhang 
28471438e86SJunchao Zhang /* ===========================================================================================================
28571438e86SJunchao Zhang 
28671438e86SJunchao Zhang    A set of routines to support receiver initiated communication using the get method
28771438e86SJunchao Zhang 
28871438e86SJunchao Zhang     The getting protocol is:
28971438e86SJunchao Zhang 
29071438e86SJunchao Zhang     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
29171438e86SJunchao Zhang     All signal variables have an initial value 0.
29271438e86SJunchao Zhang 
29371438e86SJunchao Zhang     Sender:                                 |  Receiver:
29471438e86SJunchao Zhang   1.  Wait ssig be 0, then set it to 1
29571438e86SJunchao Zhang   2.  Pack data into stand alone sbuf       |
29671438e86SJunchao Zhang   3.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it 0
29771438e86SJunchao Zhang                                             |   2. Get data from remote sbuf to local rbuf
29871438e86SJunchao Zhang                                             |   3. Put 1 to sender's ssig
29971438e86SJunchao Zhang                                             |   4. Unpack data from local rbuf
30071438e86SJunchao Zhang    ===========================================================================================================*/
/* PrePack operation -- since sender will overwrite the send buffer which the receiver might be getting data from.
   Sender waits for signals (from receivers) indicating receivers have finished getting data.

   The wait is enqueued on link->remoteCommStream as a single-thread kernel: it waits for the send
   signals to be 0 and then sets them to 1. Nothing is launched when there are no remote receivers.
*/
static PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas       = (PetscSF_Basic *)sf->data;
  const PetscBool root2leaf = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  /* ROOT2LEAF: leaf ranks are getting data and they set my rootSendSig; LEAF2ROOT: the mirror case */
  uint64_t      *sendsig   = root2leaf ? link->rootSendSig : link->leafSendSig;
  PetscInt       nrecvers  = root2leaf ? bas->nRemoteLeafRanks : sf->nRemoteRootRanks;

  PetscFunctionBegin;
  if (!nrecvers) PetscFunctionReturn(PETSC_SUCCESS);

  NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nrecvers, sendsig, 0, 1); /* wait the signals to be 0, then set them to 1 */
  PetscCallCUDA(cudaGetLastError());
  PetscFunctionReturn(PETSC_SUCCESS);
}
32571438e86SJunchao Zhang 
/* n thread blocks. Each takes in charge one remote rank

 Input parameters:
  + nsrcranks - Number of source ranks (not read by the kernel; the grid size determines which ranks are handled)
  . srcranks  - Source ranks; block b handles srcranks[b]
  . src       - Source buffer address, passed to nvshmem_ptr() and nvshmem_getmem_nbi()
  . srcdisp   - In units of <unitbytes>, offset into the source buffer on rank srcranks[b]
  . dst       - Local destination buffer
  . dstdisp   - Offsets of the destination buffer, in units of <unitbytes>; needs nsrcranks+1 entries and dstdisp[0] need not be 0
  - unitbytes - Bytes per unit

 Ranks for which nvshmem_ptr() returns non-NULL are skipped here; every thread of a block issues the
 same non-blocking get, and the caller in this file launches one thread per block.
*/
__global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks, PetscMPIInt *srcranks, const char *src, PetscInt *srcdisp, char *dst, PetscInt *dstdisp, PetscInt unitbytes)
{
  const int         blk = blockIdx.x;
  const PetscMPIInt pe  = srcranks[blk];

  if (nvshmem_ptr(src, pe)) return; /* directly addressable from this PE; not this kernel's job */

  const PetscInt nbytes = (dstdisp[blk + 1] - dstdisp[blk]) * unitbytes;
  char          *to     = dst + (dstdisp[blk] - dstdisp[0]) * unitbytes;
  const char    *from   = src + srcdisp[blk] * unitbytes;

  nvshmem_getmem_nbi(to, from, nbytes, pe);
}
33771438e86SJunchao Zhang 
/* Start communication -- Get data in the given direction

   Steps, all enqueued on link->remoteCommStream after the entry dependence is built:
    1. As sender: signal my destination ranks (set their recv signal to 1) that they may get data.
    2. As receiver: wait for my recv signals to be 1, then reset them to 0.
    3. As receiver: issue non-blocking gets from every source rank. Ranks for which nvshmem_ptr() returns
       NULL are handled by the GetDataFromRemotelyAccessible kernel; the others are handled from the host
       with nvshmemx_getmem_nbi_on_stream().
   The gets are finished in PetscSFLinkGetDataEnd_NVSHMEM().
*/
static PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       nsrc, ndst, nlocal = 0;
  char          *sbuf, *rbuf;
  PetscInt      *sdisp_h, *sdisp_d; /* offsets into the remote send buffers (host/device copies) */
  PetscInt      *rdisp_h, *rdisp_d; /* offsets of the local recv buffer (host/device copies) */
  PetscMPIInt   *sranks_h, *sranks_d;
  PetscMPIInt   *dranks_d;
  uint64_t      *rsig;
  PetscInt      *rsigdisp_d;

  PetscFunctionBegin;
  PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));

  if (direction == PETSCSF_ROOT2LEAF) {
    /* src is root, dst is leaf; we will move data from src to dst */
    sbuf     = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */
    rbuf     = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */
    nsrc     = sf->nRemoteRootRanks;
    ndst     = bas->nRemoteLeafRanks;
    sranks_h = sf->ranks + sf->ndranks; /* my (remote) root ranks */
    sranks_d = sf->ranks_d;
    dranks_d = bas->iranks_d; /* my (remote) leaf ranks */
    sdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
    sdisp_d  = sf->rootbufdisp_d;
    rdisp_h  = sf->roffset + sf->ndranks; /* offsets of the local leaf buf. Note rdisp[0] is not necessarily 0 */
    rdisp_d  = sf->roffset_d;
    rsig       = link->leafRecvSig;
    rsigdisp_d = bas->leafsigdisp_d;
  } else {
    /* src is leaf, dst is root; we will move data from src to dst */
    sbuf     = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */
    rbuf     = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */
    nsrc     = bas->nRemoteLeafRanks;
    ndst     = sf->nRemoteRootRanks;
    sranks_h = bas->iranks + bas->ndiranks; /* my (remote) leaf ranks */
    sranks_d = bas->iranks_d;
    dranks_d = sf->ranks_d; /* my (remote) root ranks */
    sdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
    sdisp_d  = bas->leafbufdisp_d;
    rdisp_h  = bas->ioffset + bas->ndiranks; /* offsets of the local root buf. Note rdisp[0] is not necessarily 0 */
    rdisp_d  = bas->ioffset_d;
    rsig       = link->rootRecvSig;
    rsigdisp_d = sf->rootsigdisp_d;
  }

  /* After Pack operation -- src tells dst ranks that they are allowed to get data */
  if (ndst) {
    NvshmemSendSignals<<<(ndst + 255) / 256, 256, 0, link->remoteCommStream>>>(ndst, rsig, rsigdisp_d, dranks_d, 1); /* set signals to 1 */
    PetscCallCUDA(cudaGetLastError());
  }

  /* dst waits for signals (permissions) from src ranks to start getting data */
  if (nsrc) {
    NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nsrc, rsig, 1, 0); /* wait the signals to be 1, then set them to 0 */
    PetscCallCUDA(cudaGetLastError());
  }

  /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

  /* Count number of locally accessible src ranks, which should be a small number */
  for (int r = 0; r < nsrc; r++) {
    if (nvshmem_ptr(sbuf, sranks_h[r])) nlocal++;
  }

  /* Get data from remotely accessible PEs: one single-thread block per source rank; the kernel skips the local ones */
  if (nlocal < nsrc) {
    GetDataFromRemotelyAccessible<<<nsrc, 1, 0, link->remoteCommStream>>>(nsrc, sranks_d, sbuf, sdisp_d, rbuf, rdisp_d, link->unitbytes);
    PetscCallCUDA(cudaGetLastError());
  }

  /* Get data from locally accessible PEs, issued from the host on the same stream */
  if (nlocal) {
    for (int r = 0; r < nsrc; r++) {
      int pe = sranks_h[r];

      if (!nvshmem_ptr(sbuf, pe)) continue;
      size_t nbytes = (rdisp_h[r + 1] - rdisp_h[r]) * link->unitbytes;
      nvshmemx_getmem_nbi_on_stream(rbuf + (rdisp_h[r] - rdisp_h[0]) * link->unitbytes, sbuf + sdisp_h[r] * link->unitbytes, nbytes, pe, link->remoteCommStream);
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
43171438e86SJunchao Zhang 
/* Finish the get-based communication (can be done before Unpack).

   The receiver first completes its nonblocking gets on link->remoteCommStream, then tells its senders that they
   are allowed to reuse their send buffers (since the receiver has got the data out of those buffers) by setting
   their send signals to 0.

   Input Parameters:
+  sf        - the star forest
.  link      - the NVSHMEM link holding the signals and the remote communication stream
-  direction - PETSCSF_ROOT2LEAF (leaf ranks are getting data) or PETSCSF_LEAF2ROOT (root ranks are getting data)
*/
static PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  uint64_t      *srcsig;
  PetscInt       nsrcranks, *srcsigdisp;
  PetscMPIInt   *srcranks;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
    nsrcranks  = sf->nRemoteRootRanks;
    srcsig     = link->rootSendSig; /* I want to set their root signal */
    srcsigdisp = sf->rootsigdisp_d; /* offset of each root signal */
    srcranks   = sf->ranks_d;       /* ranks of the n root ranks */
  } else {                          /* LEAF2ROOT, root ranks are getting data */
    nsrcranks  = bas->nRemoteLeafRanks;
    srcsig     = link->leafSendSig;
    srcsigdisp = bas->leafsigdisp_d;
    srcranks   = bas->iranks_d;
  }

  if (nsrcranks) {
    nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
    PetscCallCUDA(cudaGetLastError());
    /* One thread per source rank; each thread resets one remote send signal */
    NvshmemSendSignals<<<(nsrcranks + 511) / 512, 512, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp, srcranks, 0); /* set signals to 0 */
    PetscCallCUDA(cudaGetLastError());
  }
  PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
  PetscFunctionReturn(PETSC_SUCCESS);
}
46571438e86SJunchao Zhang 
/* ===========================================================================================================

   A set of routines to support sender-initiated communication using the put-based method (the default)

    The putting protocol is:

    The sender has a send buf (sbuf) and a send signal var (ssig); the receiver has a stand-alone recv buf (rbuf)
    and a recv signal var (rsig). All signal variables have an initial value of 0. rbuf is allocated by SF and
    is in nvshmem space.

    Sender:                                      |  Receiver:
                                                 |
  1.  Pack data into sbuf                        |
  2.  Wait for ssig to be 0, then set it to 1    |
  3.  Put data to the remote stand-alone rbuf    |
  4.  Fence // make sure 5 happens after 3       |
  5.  Put 1 to the receiver's rsig               |   1. Wait for rsig to be 1, then set it to 0
                                                 |   2. Unpack data from the local rbuf
                                                 |   3. Put 0 to the sender's ssig
   ===========================================================================================================*/
48671438e86SJunchao Zhang 
/* Launch with n thread blocks of one thread each; block i takes charge of the remote rank dstranks[i].

   Only destinations that are not locally accessible (nvshmem_ptr() returns NULL for them) are handled here.
   For such a destination the block waits until its send signal is 0, sets it to 1, and then starts a
   nonblocking put of its slice of src into dst on that PE. Blocks with an index >= ndstranks do nothing.

   Input Parameters:
+  ndstranks - number of destination ranks
.  dstranks  - [ndstranks] destination PEs
.  dst       - receive buffer address; the put targets dst + dstdisp[i]*unitbytes on PE dstranks[i]
.  dstdisp   - [ndstranks] displacements (in units) into dst on each destination
.  src       - send buffer
.  srcdisp   - offsets (in units) of each destination's slice of src; indices bid and bid+1 are read, and
               srcdisp[0] is subtracted since it is not necessarily zero
.  srcsig    - [ndstranks] send signals, one per destination
-  unitbytes - number of bytes per unit
*/
__global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, char *dst, PetscInt *dstdisp, const char *src, PetscInt *srcdisp, uint64_t *srcsig, PetscInt unitbytes)
{
  const int bid = blockIdx.x;

  if (bid >= ndstranks) return; /* never index the per-rank arrays beyond ndstranks, whatever the grid size */

  const PetscMPIInt pe = dstranks[bid];

  if (!nvshmem_ptr(dst, pe)) {
    PetscInt nelems = (srcdisp[bid + 1] - srcdisp[bid]) * unitbytes;
    nvshmem_uint64_wait_until(srcsig + bid, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
    srcsig[bid] = 1;
    nvshmem_putmem_nbi(dst + dstdisp[bid] * unitbytes, src + (srcdisp[bid] - srcdisp[0]) * unitbytes, nelems, pe);
  }
}
50071438e86SJunchao Zhang 
/* Single-thread kernel that walks over all destination ranks in order. For every rank whose dst buffer is
   locally accessible (nvshmem_ptr() returns non-NULL), it waits for the matching send signal to become 0
   and then sets that signal to 1. Ranks that are not locally accessible are left untouched. */
__global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *srcsig, const char *dst)
{
  int r = 0;

  while (r < ndstranks) {
    uint64_t *sig = srcsig + r;

    if (nvshmem_ptr(dst, dstranks[r])) {
      nvshmem_uint64_wait_until(sig, NVSHMEM_CMP_EQ, 0); /* block until the signal reads 0 */
      *sig = 1;
    }
    r++;
  }
}
51271438e86SJunchao Zhang 
/* Put data in the given direction.

   For every destination rank, wait until the corresponding send signal is 0, set it to 1, and start a nonblocking
   put of the packed send buffer into the destination's receive buffer. Remotely accessible PEs are served by one
   kernel using the device API; locally accessible PEs are served with the host on-stream API. All work is enqueued
   on link->remoteCommStream.

   Input Parameters:
+  sf        - the star forest
.  link      - the NVSHMEM link holding the buffers, the send signals and the remote communication stream
-  direction - PETSCSF_ROOT2LEAF (put rootbuf to leafbuf) or PETSCSF_LEAF2ROOT (put leafbuf to rootbuf)
*/
static PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       ndstranks, nLocallyAccessible = 0;
  char          *src, *dst;
  PetscInt      *srcdisp_h, *dstdisp_h;
  PetscInt      *srcdisp_d, *dstdisp_d;
  PetscMPIInt   *dstranks_h;
  PetscMPIInt   *dstranks_d;
  uint64_t      *srcsig;

  PetscFunctionBegin;
  PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
  if (direction == PETSCSF_ROOT2LEAF) {                              /* put data in rootbuf to leafbuf  */
    ndstranks = bas->nRemoteLeafRanks;                               /* number of (remote) leaf ranks */
    src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
    dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h = bas->ioffset + bas->ndiranks; /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
    srcdisp_d = bas->ioffset_d;
    srcsig    = link->rootSendSig;

    dstdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
    dstdisp_d  = bas->leafbufdisp_d;
    dstranks_h = bas->iranks + bas->ndiranks; /* remote leaf ranks */
    dstranks_d = bas->iranks_d;
  } else { /* put data in leafbuf to rootbuf */
    ndstranks = sf->nRemoteRootRanks;
    src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h = sf->roffset + sf->ndranks; /* offsets of leafbuf */
    srcdisp_d = sf->roffset_d;
    srcsig    = link->leafSendSig;

    dstdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
    dstdisp_d  = sf->rootbufdisp_d;
    dstranks_h = sf->ranks + sf->ndranks; /* remote root ranks */
    dstranks_d = sf->ranks_d;
  }

  /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */

  /* Count number of locally accessible neighbors, which should be a small number */
  for (int i = 0; i < ndstranks; i++) {
    if (nvshmem_ptr(dst, dstranks_h[i])) nLocallyAccessible++;
  }

  /* For remotely accessible PEs, send data to them in one kernel call (one single-thread block per dst rank) */
  if (nLocallyAccessible < ndstranks) {
    WaitAndPutDataToRemotelyAccessible<<<ndstranks, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, dst, dstdisp_d, src, srcdisp_d, srcsig, link->unitbytes);
    PetscCallCUDA(cudaGetLastError());
  }

  /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
  if (nLocallyAccessible) {
    WaitSignalsFromLocallyAccessible<<<1, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, srcsig, dst);
    PetscCallCUDA(cudaGetLastError()); /* catch launch failures here, as done for every other kernel launch in this file */
    for (int i = 0; i < ndstranks; i++) {
      int pe = dstranks_h[i];
      if (nvshmem_ptr(dst, pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
        size_t nelems = (srcdisp_h[i + 1] - srcdisp_h[i]) * link->unitbytes;
        /* Initiate the nonblocking communication */
        nvshmemx_putmem_nbi_on_stream(dst + dstdisp_h[i] * link->unitbytes, src + (srcdisp_h[i] - srcdisp_h[0]) * link->unitbytes, nelems, pe, link->remoteCommStream);
      }
    }
    /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
    nvshmemx_quiet_on_stream(link->remoteCommStream);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
58571438e86SJunchao Zhang 
/* A one-thread kernel: the single thread serves all remote PEs, first signalling every destination rank and
   then waiting for (and clearing) the signals coming from every source rank. */
__global__ static void PutDataEnd(PetscInt nsrcranks, PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *dstsig, PetscInt *dstsigdisp)
{
  /* TODO: Shall we finish the non-blocking remote puts? */

  /* Step 1: notify each dst rank.

     According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
     For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
  */
  for (int d = 0; d < ndstranks; d++) {
    uint64_t *remotesig = dstsig + dstsigdisp[d];

    nvshmemx_uint64_signal(remotesig, 1, dstranks[d]); /* set sig to 1 */
  }

  /* Step 2: wait for signals from src ranks (if any) */
  if (!nsrcranks) return;

  nvshmem_uint64_wait_until_all(dstsig, nsrcranks, NULL /*no mask*/, NVSHMEM_CMP_EQ, 1); /* wait sigs to be 1 ... */
  for (int s = 0; s < nsrcranks; s++) dstsig[s] = 0;                                     /* ... then set them to 0 */
}
60471438e86SJunchao Zhang 
/* Finish the put-based communication -- a receiver waits until it can access its receive buffer.

   Launches the one-thread PutDataEnd kernel on link->remoteCommStream, which sets the receive signal on every
   destination rank and waits for the receive signals from every source rank.

   Input Parameters:
+  sf        - the star forest
.  link      - the NVSHMEM link holding the receive signals and the remote communication stream
-  direction - PETSCSF_ROOT2LEAF (put root data to leaf) or PETSCSF_LEAF2ROOT
*/
static PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscMPIInt   *dstranks;
  uint64_t      *dstsig;
  PetscInt       nsrcranks, ndstranks, *dstsigdisp;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
    nsrcranks = sf->nRemoteRootRanks;

    ndstranks  = bas->nRemoteLeafRanks;
    dstranks   = bas->iranks_d;      /* leaf ranks */
    dstsig     = link->leafRecvSig;  /* I will set my leaf ranks's RecvSig */
    dstsigdisp = bas->leafsigdisp_d; /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
  } else {                           /* LEAF2ROOT */
    nsrcranks = bas->nRemoteLeafRanks;

    ndstranks  = sf->nRemoteRootRanks;
    dstranks   = sf->ranks_d;
    dstsig     = link->rootRecvSig;
    dstsigdisp = sf->rootsigdisp_d;
  }

  /* Nothing to signal and nothing to wait for when this rank has neither sources nor destinations */
  if (nsrcranks || ndstranks) {
    PutDataEnd<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, ndstranks, dstranks, dstsig, dstsigdisp);
    PetscCallCUDA(cudaGetLastError());
  }
  PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
  PetscFunctionReturn(PETSC_SUCCESS);
}
63871438e86SJunchao Zhang 
/* PostUnpack operation of the put-based protocol: the receiver resets the send signals of all its senders to 0,
   which tells them they may put data here again (i.e., the recv buf is free to take new data). */
static PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic  *bas       = (PetscSF_Basic *)sf->data;
  const PetscBool root2leaf = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  /* ROOT2LEAF: I allow my root ranks to put data to me; LEAF2ROOT: I allow my leaf ranks to do so */
  PetscInt     nsenders  = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  uint64_t    *sendsig   = root2leaf ? link->rootSendSig : link->leafSendSig;  /* the senders' send signals I want to set */
  PetscInt    *sigdisp_d = root2leaf ? sf->rootsigdisp_d : bas->leafsigdisp_d; /* offset of each sender's signal */
  PetscMPIInt *senders_d = root2leaf ? sf->ranks_d : bas->iranks_d;            /* ranks of the senders */

  PetscFunctionBegin;
  if (nsenders) {
    const int nthreads = 256;

    /* Set remote signals to 0 */
    NvshmemSendSignals<<<(nsenders + nthreads - 1) / nthreads, nthreads, 0, link->remoteCommStream>>>(nsenders, sendsig, sigdisp_d, senders_d, 0);
    PetscCallCUDA(cudaGetLastError());
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
66671438e86SJunchao Zhang 
/* Destructor when the link uses nvshmem for communication.

   Destroys the two CUDA events and the remote communication stream owned by the link, then frees the
   nvshmem-allocated remote root/leaf buffers and the four signal arrays.

   Input Parameters:
+  sf   - the star forest (not used here)
-  link - the NVSHMEM link to tear down
*/
static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf, PetscSFLink link)
{
  PetscFunctionBegin;
  PetscCallCUDA(cudaEventDestroy(link->dataReady));
  PetscCallCUDA(cudaEventDestroy(link->endRemoteComm));
  PetscCallCUDA(cudaStreamDestroy(link->remoteCommStream));

  /* nvshmem does not need buffers on host, which should be NULL */
  PetscCall(PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
  PetscCall(PetscNvshmemFree(link->leafSendSig));
  PetscCall(PetscNvshmemFree(link->leafRecvSig));
  PetscCall(PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
  PetscCall(PetscNvshmemFree(link->rootSendSig));
  PetscCall(PetscNvshmemFree(link->rootRecvSig));
  PetscFunctionReturn(PETSC_SUCCESS);
}
68671438e86SJunchao Zhang 
/* Create (or reuse from the cache) a link that uses NVSHMEM for the remote communication.

   The routine
   1. decides whether root/leafdata can be used directly as the remote communication buffers;
   2. looks in bas->avail for a free nvshmem link with a matching unit, and otherwise builds a new one
      (pack routines, signals, function pointers, stream and events);
   3. binds the remote root/leaf buffers, lazily allocating nvshmem buffers when the data cannot be used directly;
   4. pushes the link onto bas->inuse and returns it.

   Input Parameters:
+  sf        - the star forest
.  unit      - MPI datatype of one unit of data
.  rootmtype - memory type of rootdata
.  rootdata  - the root data
.  leafmtype - memory type of leafdata
.  leafdata  - the leaf data
.  op        - the reduction operation
-  sfop      - the SF operation (PETSCSF_BCAST, PETSCSF_REDUCE, or otherwise treated as PETSCSF_FETCH)

   Output Parameter:
.  mylink - the link
*/
PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf, MPI_Datatype unit, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, MPI_Op op, PetscSFOperation sfop, PetscSFLink *mylink)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscSFLink   *p, link;
  PetscBool      match, rootdirect[2], leafdirect[2];
  int            greatestPriority;

  PetscFunctionBegin;
  /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
     We only care root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers in local communication with NVSHMEM.
  */
  if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
    }
  } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
    }
  } else {                                    /* PETSCSF_FETCH */
    rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always need a separate rootbuf */
    leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share mpi requests */
  }

  /* Look for free nvshmem links in cache */
  for (p = &bas->avail; (link = *p); p = &link->next) {
    if (link->use_nvshmem) {
      PetscCall(MPIPetsc_Type_compare(unit, link->unit, &match));
      if (match) {
        *p = link->next; /* Remove from available list */
        goto found;
      }
    }
  }

  /* No reusable link in the cache: build a new one */
  PetscCall(PetscNew(&link));
  PetscCall(PetscSFLinkSetUp_Host(sf, link, unit));                                          /* Compute link->unitbytes, dup link->unit etc. */
  if (sf->backend == PETSCSF_BACKEND_CUDA) PetscCall(PetscSFLinkSetUp_CUDA(sf, link, unit)); /* Setup pack routines, streams etc */
#if defined(PETSC_HAVE_KOKKOS)
  else if (sf->backend == PETSCSF_BACKEND_KOKKOS) PetscCall(PetscSFLinkSetUp_Kokkos(sf, link, unit));
#endif

  link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE; /* For the local part we directly use root/leafdata */
  link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;

  /* Init signals to zero */
  if (!link->rootSendSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootSendSig));
  if (!link->rootRecvSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootRecvSig));
  if (!link->leafSendSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafSendSig));
  if (!link->leafRecvSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafRecvSig));

  link->use_nvshmem = PETSC_TRUE;
  link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
  link->leafmtype   = PETSC_MEMTYPE_DEVICE;
  /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
  link->Destroy = PetscSFLinkDestroy_NVSHMEM;
  if (sf->use_nvshmem_get) { /* get-based protocol */
    link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
    link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
    link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
  } else { /* put-based protocol */
    link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
    link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
    link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
  }

  /* Remote communication runs on its own non-blocking stream, created with the greatest priority */
  PetscCallCUDA(cudaDeviceGetStreamPriorityRange(NULL, &greatestPriority));
  PetscCallCUDA(cudaStreamCreateWithPriority(&link->remoteCommStream, cudaStreamNonBlocking, greatestPriority));

  PetscCallCUDA(cudaEventCreateWithFlags(&link->dataReady, cudaEventDisableTiming));
  PetscCallCUDA(cudaEventCreateWithFlags(&link->endRemoteComm, cudaEventDisableTiming));

found:
  /* Bind the remote root buffer: either rootdata itself, or a (lazily allocated) nvshmem buffer */
  if (rootdirect[PETSCSF_REMOTE]) {
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)rootdata + bas->rootstart[PETSCSF_REMOTE] * link->unitbytes;
  } else {
    if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(bas->rootbuflen_rmax * link->unitbytes, (void **)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  /* Same for the remote leaf buffer */
  if (leafdirect[PETSCSF_REMOTE]) {
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)leafdata + sf->leafstart[PETSCSF_REMOTE] * link->unitbytes;
  } else {
    if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(sf->leafbuflen_rmax * link->unitbytes, (void **)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
  link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
  link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
  link->leafdata                   = leafdata;
  link->next                       = bas->inuse;
  bas->inuse                       = link;
  *mylink                          = link;
  PetscFunctionReturn(PETSC_SUCCESS);
}
79171438e86SJunchao Zhang 
79271438e86SJunchao Zhang #if defined(PETSC_USE_REAL_SINGLE)
/* Enqueue on PetscDefaultCudaStream a sum-reduction of count floats from src into dst over NVSHMEM_TEAM_WORLD.

   Input Parameters:
+  count - number of entries to reduce; must fit in a PetscMPIInt
-  src   - source buffer

   Output Parameter:
.  dst - destination buffer

   NOTE(review): dst and src are presumably required to be NVSHMEM symmetric memory -- confirm against callers.
*/
PetscErrorCode PetscNvshmemSum(PetscInt count, float *dst, const float *src)
{
  PetscMPIInt num; /* Assume nvshmem's int is MPI's int */
  int         err;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &num));
  err = nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
  PetscCheck(!err, PETSC_COMM_SELF, PETSC_ERR_LIB, "nvshmemx_float_sum_reduce_on_stream() failed with error code %d", err);
  PetscFunctionReturn(PETSC_SUCCESS);
}
80271438e86SJunchao Zhang 
/* Enqueue on PetscDefaultCudaStream a max-reduction of count floats from src into dst over NVSHMEM_TEAM_WORLD.

   Input Parameters:
+  count - number of entries to reduce; must fit in a PetscMPIInt
-  src   - source buffer

   Output Parameter:
.  dst - destination buffer

   NOTE(review): dst and src are presumably required to be NVSHMEM symmetric memory -- confirm against callers.
*/
PetscErrorCode PetscNvshmemMax(PetscInt count, float *dst, const float *src)
{
  PetscMPIInt num;
  int         err;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &num));
  err = nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
  PetscCheck(!err, PETSC_COMM_SELF, PETSC_ERR_LIB, "nvshmemx_float_max_reduce_on_stream() failed with error code %d", err);
  PetscFunctionReturn(PETSC_SUCCESS);
}
81271438e86SJunchao Zhang #elif defined(PETSC_USE_REAL_DOUBLE)
/* Enqueue on PetscDefaultCudaStream a sum-reduction of count doubles from src into dst over NVSHMEM_TEAM_WORLD.

   Input Parameters:
+  count - number of entries to reduce; must fit in a PetscMPIInt
-  src   - source buffer

   Output Parameter:
.  dst - destination buffer

   NOTE(review): dst and src are presumably required to be NVSHMEM symmetric memory -- confirm against callers.
*/
PetscErrorCode PetscNvshmemSum(PetscInt count, double *dst, const double *src)
{
  PetscMPIInt num;
  int         err;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &num));
  err = nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
  PetscCheck(!err, PETSC_COMM_SELF, PETSC_ERR_LIB, "nvshmemx_double_sum_reduce_on_stream() failed with error code %d", err);
  PetscFunctionReturn(PETSC_SUCCESS);
}
82271438e86SJunchao Zhang 
/* Enqueue on PetscDefaultCudaStream a max-reduction of count doubles from src into dst over NVSHMEM_TEAM_WORLD.

   Input Parameters:
+  count - number of entries to reduce; must fit in a PetscMPIInt
-  src   - source buffer

   Output Parameter:
.  dst - destination buffer

   NOTE(review): dst and src are presumably required to be NVSHMEM symmetric memory -- confirm against callers.
*/
PetscErrorCode PetscNvshmemMax(PetscInt count, double *dst, const double *src)
{
  PetscMPIInt num;
  int         err;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &num));
  err = nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
  PetscCheck(!err, PETSC_COMM_SELF, PETSC_ERR_LIB, "nvshmemx_double_max_reduce_on_stream() failed with error code %d", err);
  PetscFunctionReturn(PETSC_SUCCESS);
}
83271438e86SJunchao Zhang #endif
833