1 #include <../src/vec/is/sf/impls/basic/gatherv/sfgatherv.h>
2 #include <../src/vec/is/sf/impls/basic/allgather/sfallgather.h>
3
/* Gather reuses the Allgatherv private data type, so it can also share Allgatherv's Reset/Destroy/GetGraph/...
   implementations (installed in PetscSFCreate_Gather()). Some fields (e.g., displs, recvcounts) are simply
   left unused by Gather, which costs little. */
typedef PetscSF_Allgatherv PetscSF_Gather;
6
PetscSFLinkStartCommunication_Gather(PetscSF sf,PetscSFLink link,PetscSFDirection direction)7 static PetscErrorCode PetscSFLinkStartCommunication_Gather(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
8 {
9 MPI_Comm comm = MPI_COMM_NULL;
10 void *rootbuf = NULL, *leafbuf = NULL;
11 MPI_Request *req = NULL;
12 PetscMPIInt count;
13 MPI_Datatype unit = link->unit;
14
15 PetscFunctionBegin;
16 if (direction == PETSCSF_ROOT2LEAF) {
17 PetscCall(PetscSFLinkCopyRootBufferInCaseNotUseGpuAwareMPI(sf, link, PETSC_TRUE /* device2host before sending */));
18 } else {
19 PetscCall(PetscSFLinkCopyLeafBufferInCaseNotUseGpuAwareMPI(sf, link, PETSC_TRUE /* device2host */));
20 }
21 PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
22 PetscCall(PetscMPIIntCast(sf->nroots, &count));
23 PetscCall(PetscSFLinkGetMPIBuffersAndRequests(sf, link, direction, &rootbuf, &leafbuf, &req, NULL));
24 PetscCall(PetscSFLinkSyncStreamBeforeCallMPI(sf, link));
25
26 if (direction == PETSCSF_ROOT2LEAF) {
27 PetscCallMPI(MPIU_Igather(rootbuf == leafbuf ? MPI_IN_PLACE : rootbuf, count, unit, leafbuf, count, unit, 0 /*rank 0*/, comm, req));
28 } else {
29 PetscCallMPI(MPIU_Iscatter(leafbuf, count, unit, rootbuf == leafbuf ? MPI_IN_PLACE : rootbuf, count, unit, 0 /*rank 0*/, comm, req));
30 }
31 PetscFunctionReturn(PETSC_SUCCESS);
32 }
33
PetscSFSetCommunicationOps_Gather(PetscSF sf,PetscSFLink link)34 static PetscErrorCode PetscSFSetCommunicationOps_Gather(PetscSF sf, PetscSFLink link)
35 {
36 PetscFunctionBegin;
37 link->StartCommunication = PetscSFLinkStartCommunication_Gather;
38 PetscFunctionReturn(PETSC_SUCCESS);
39 }
40
PetscSFCreate_Gather(PetscSF sf)41 PETSC_INTERN PetscErrorCode PetscSFCreate_Gather(PetscSF sf)
42 {
43 PetscSF_Gather *dat = (PetscSF_Gather *)sf->data;
44
45 PetscFunctionBegin;
46 sf->ops->BcastBegin = PetscSFBcastBegin_Basic;
47 sf->ops->BcastEnd = PetscSFBcastEnd_Basic;
48 sf->ops->ReduceBegin = PetscSFReduceBegin_Basic;
49 sf->ops->ReduceEnd = PetscSFReduceEnd_Basic;
50
51 /* Inherit from Allgatherv */
52 sf->ops->Reset = PetscSFReset_Allgatherv;
53 sf->ops->Destroy = PetscSFDestroy_Allgatherv;
54 sf->ops->GetGraph = PetscSFGetGraph_Allgatherv;
55 sf->ops->GetRootRanks = PetscSFGetRootRanks_Allgatherv;
56 sf->ops->GetLeafRanks = PetscSFGetLeafRanks_Allgatherv;
57 sf->ops->FetchAndOpEnd = PetscSFFetchAndOpEnd_Allgatherv;
58 sf->ops->CreateLocalSF = PetscSFCreateLocalSF_Allgatherv;
59
60 /* Inherit from Allgather */
61 sf->ops->SetUp = PetscSFSetUp_Allgather;
62
63 /* Inherit from Gatherv */
64 sf->ops->FetchAndOpBegin = PetscSFFetchAndOpBegin_Gatherv;
65
66 sf->ops->SetCommunicationOps = PetscSFSetCommunicationOps_Gather;
67
68 sf->collective = PETSC_TRUE;
69
70 PetscCall(PetscNew(&dat));
71 sf->data = (void *)dat;
72 PetscFunctionReturn(PETSC_SUCCESS);
73 }
74