xref: /petsc/src/vec/is/sf/impls/basic/alltoall/sfalltoall.c (revision be37439ebbbdb2f81c3420c175a94aa72e59929c)
1 #include <../src/vec/is/sf/impls/basic/allgatherv/sfallgatherv.h>
2 #include <../src/vec/is/sf/impls/basic/allgather/sfallgather.h>
3 #include <../src/vec/is/sf/impls/basic/gatherv/sfgatherv.h>
4 
5 /* Reuse the type. The difference is some fields (i.e., displs, recvcounts) are not used, which is not a big deal */
6 typedef PetscSF_Allgatherv PetscSF_Alltoall;
7 
PetscSFLinkStartCommunication_Alltoall(PetscSF sf,PetscSFLink link,PetscSFDirection direction)8 static PetscErrorCode PetscSFLinkStartCommunication_Alltoall(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
9 {
10   MPI_Comm     comm    = MPI_COMM_NULL;
11   void        *rootbuf = NULL, *leafbuf = NULL;
12   MPI_Request *req  = NULL;
13   MPI_Datatype unit = link->unit;
14 
15   PetscFunctionBegin;
16   if (direction == PETSCSF_ROOT2LEAF) {
17     PetscCall(PetscSFLinkCopyRootBufferInCaseNotUseGpuAwareMPI(sf, link, PETSC_TRUE /* device2host before sending */));
18   } else {
19     PetscCall(PetscSFLinkCopyLeafBufferInCaseNotUseGpuAwareMPI(sf, link, PETSC_TRUE /* device2host */));
20   }
21   PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
22   PetscCall(PetscSFLinkGetMPIBuffersAndRequests(sf, link, direction, &rootbuf, &leafbuf, &req, NULL));
23   PetscCall(PetscSFLinkSyncStreamBeforeCallMPI(sf, link));
24 
25   if (direction == PETSCSF_ROOT2LEAF) {
26     PetscCallMPI(MPIU_Ialltoall(rootbuf, 1, unit, leafbuf, 1, unit, comm, req));
27   } else {
28     PetscCallMPI(MPIU_Ialltoall(leafbuf, 1, unit, rootbuf, 1, unit, comm, req));
29   }
30   PetscFunctionReturn(PETSC_SUCCESS);
31 }
32 
PetscSFSetCommunicationOps_Alltoall(PetscSF sf,PetscSFLink link)33 static PetscErrorCode PetscSFSetCommunicationOps_Alltoall(PetscSF sf, PetscSFLink link)
34 {
35   PetscFunctionBegin;
36   link->StartCommunication = PetscSFLinkStartCommunication_Alltoall;
37   PetscFunctionReturn(PETSC_SUCCESS);
38 }
39 
40 /*===================================================================================*/
41 /*              Implementations of SF public APIs                                    */
42 /*===================================================================================*/
PetscSFGetGraph_Alltoall(PetscSF sf,PetscInt * nroots,PetscInt * nleaves,const PetscInt ** ilocal,const PetscSFNode ** iremote)43 static PetscErrorCode PetscSFGetGraph_Alltoall(PetscSF sf, PetscInt *nroots, PetscInt *nleaves, const PetscInt **ilocal, const PetscSFNode **iremote)
44 {
45   PetscInt i;
46 
47   PetscFunctionBegin;
48   if (nroots) *nroots = sf->nroots;
49   if (nleaves) *nleaves = sf->nleaves;
50   if (ilocal) *ilocal = NULL; /* Contiguous local indices */
51   if (iremote) {
52     if (!sf->remote) {
53       PetscCall(PetscMalloc1(sf->nleaves, &sf->remote));
54       sf->remote_alloc = sf->remote;
55       for (i = 0; i < sf->nleaves; i++) {
56         sf->remote[i].rank  = i; /* this is nonsense, cannot be larger than size */
57         sf->remote[i].index = i;
58       }
59     }
60     *iremote = sf->remote;
61   }
62   PetscFunctionReturn(PETSC_SUCCESS);
63 }
64 
PetscSFCreateLocalSF_Alltoall(PetscSF sf,PetscSF * out)65 static PetscErrorCode PetscSFCreateLocalSF_Alltoall(PetscSF sf, PetscSF *out)
66 {
67   PetscInt     nroots = 1, nleaves = 1, *ilocal;
68   PetscSFNode *iremote = NULL;
69   PetscSF      lsf;
70   PetscMPIInt  rank;
71 
72   PetscFunctionBegin;
73   nroots  = 1;
74   nleaves = 1;
75   PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)sf), &rank));
76   PetscCall(PetscMalloc1(nleaves, &ilocal));
77   PetscCall(PetscMalloc1(nleaves, &iremote));
78   ilocal[0]        = rank;
79   iremote[0].rank  = 0;    /* rank in PETSC_COMM_SELF */
80   iremote[0].index = rank; /* LocalSF is an embedded SF. Indices are not remapped */
81 
82   PetscCall(PetscSFCreate(PETSC_COMM_SELF, &lsf));
83   PetscCall(PetscSFSetGraph(lsf, nroots, nleaves, NULL /*contiguous leaves*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
84   PetscCall(PetscSFSetUp(lsf));
85   *out = lsf;
86   PetscFunctionReturn(PETSC_SUCCESS);
87 }
88 
/* Create an embedded root SF containing only the selected roots (and the leaves still
   connected to them), as a PETSCSFBASIC SF whose internal fields are filled by hand so
   that PetscSFSetUp need not be called. In the alltoall graph, root i on a rank is
   connected 1:1 to leaf <rank> on rank i, so root/leaf indices double as rank numbers. */
static PetscErrorCode PetscSFCreateEmbeddedRootSF_Alltoall(PetscSF sf, PetscInt nselected, const PetscInt *selected, PetscSF *newsf)
{
  PetscInt       i, *tmproots, *ilocal;
  PetscSFNode   *iremote;
  PetscMPIInt    nroots, *roots, nleaves, *leaves, rank, ndiranks, ndranks;
  MPI_Comm       comm;
  PetscSF_Basic *bas;
  PetscSF        esf;

  PetscFunctionBegin;
  PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
  PetscCallMPI(MPI_Comm_rank(comm, &rank));

  /* Uniq selected[] and store the result in roots[] */
  PetscCall(PetscMalloc1(nselected, &tmproots));
  PetscCall(PetscArraycpy(tmproots, selected, nselected));
  PetscCall(PetscSortRemoveDupsInt(&nselected, tmproots)); /* nselected might be changed */
  /* NOTE(review): if nselected is 0 this reads tmproots[-1]; presumably callers always select at least one root — confirm */
  PetscCheck(tmproots[0] >= 0 && tmproots[nselected - 1] < sf->nroots, comm, PETSC_ERR_ARG_OUTOFRANGE, "Min/Max root indices %" PetscInt_FMT "/%" PetscInt_FMT " are not in [0,%" PetscInt_FMT ")", tmproots[0], tmproots[nselected - 1], sf->nroots);
  PetscCall(PetscMPIIntCast(nselected, &nroots));
  PetscCall(PetscMalloc1(nselected, &roots));
  for (PetscMPIInt i = 0; i < nroots; i++) PetscCall(PetscMPIIntCast(tmproots[i], &roots[i]));
  PetscCall(PetscFree(tmproots));

  /* Find out which leaves are still connected to roots in the embedded sf. Expect PetscCommBuildTwoSided is more scalable than MPI_Alltoall */
  PetscCall(PetscCommBuildTwoSided(comm, 0 /*empty msg*/, MPI_INT /*fake*/, nroots, roots, NULL /*todata*/, &nleaves, &leaves, NULL /*fromdata*/));

  /* Move myself ahead if rank is in leaves[], since I am a distinguished rank */
  /* Temporarily negate my own entry so it sorts first, then restore it below */
  ndranks = 0;
  for (i = 0; i < nleaves; i++) {
    if (leaves[i] == rank) {
      leaves[i] = -rank;
      ndranks   = 1;
      break;
    }
  }
  PetscCall(PetscSortMPIInt(nleaves, leaves));
  if (nleaves && leaves[0] < 0) leaves[0] = rank;

  /* Build esf and fill its fields manually (without calling PetscSFSetUp) */
  PetscCall(PetscMalloc1(nleaves, &ilocal));
  PetscCall(PetscMalloc1(nleaves, &iremote));
  for (i = 0; i < nleaves; i++) { /* 1:1 map from roots to leaves */
    ilocal[i]        = leaves[i];
    iremote[i].rank  = leaves[i];
    iremote[i].index = leaves[i];
  }
  PetscCall(PetscSFCreate(comm, &esf));
  PetscCall(PetscSFSetType(esf, PETSCSFBASIC)); /* This optimized routine can only create a basic sf */
  PetscCall(PetscSFSetGraph(esf, sf->nleaves, nleaves, ilocal, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));

  /* As if we are calling PetscSFSetUpRanks(esf,self's group) */
  /* Each connected rank contributes exactly one leaf, so roffset is simply 0,1,2,... */
  PetscCall(PetscMalloc4(nleaves, &esf->ranks, nleaves + 1, &esf->roffset, nleaves, &esf->rmine, nleaves, &esf->rremote));
  esf->nranks     = nleaves;
  esf->ndranks    = ndranks;
  esf->roffset[0] = 0;
  for (i = 0; i < nleaves; i++) {
    esf->ranks[i]       = leaves[i];
    esf->roffset[i + 1] = i + 1;
    esf->rmine[i]       = leaves[i];
    esf->rremote[i]     = leaves[i];
  }

  /* Set up esf->data, the incoming communication (i.e., recv info), which is usually done by PetscSFSetUp_Basic */
  bas = (PetscSF_Basic *)esf->data;
  PetscCall(PetscMalloc2(nroots, &bas->iranks, nroots + 1, &bas->ioffset));
  PetscCall(PetscMalloc1(nroots, &bas->irootloc));
  /* Move myself ahead if rank is in roots[], since I am a distinguished irank */
  /* Same negate-sort-restore trick as for leaves above */
  ndiranks = 0;
  for (i = 0; i < nroots; i++) {
    if (roots[i] == rank) {
      roots[i] = -rank;
      ndiranks = 1;
      break;
    }
  }
  PetscCall(PetscSortMPIInt(nroots, roots));
  if (nroots && roots[0] < 0) roots[0] = rank;

  bas->niranks    = nroots;
  bas->ndiranks   = ndiranks;
  bas->ioffset[0] = 0;
  bas->itotal     = nroots;
  for (i = 0; i < nroots; i++) {
    bas->iranks[i]      = roots[i];
    bas->ioffset[i + 1] = i + 1;
    bas->irootloc[i]    = roots[i];
  }

  /* See PetscSFCreateEmbeddedRootSF_Basic */
  esf->nleafreqs  = esf->nranks - esf->ndranks;
  bas->nrootreqs  = bas->niranks - bas->ndiranks;
  esf->persistent = PETSC_TRUE;
  /* Setup packing related fields */
  PetscCall(PetscSFSetUpPackFields(esf));

  esf->setupcalled = PETSC_TRUE; /* We have done setup ourselves! */
  *newsf           = esf;
  PetscFunctionReturn(PETSC_SUCCESS);
}
188 
PetscSFCreate_Alltoall(PetscSF sf)189 PETSC_INTERN PetscErrorCode PetscSFCreate_Alltoall(PetscSF sf)
190 {
191   PetscSF_Alltoall *dat = (PetscSF_Alltoall *)sf->data;
192 
193   PetscFunctionBegin;
194   sf->ops->BcastBegin  = PetscSFBcastBegin_Basic;
195   sf->ops->BcastEnd    = PetscSFBcastEnd_Basic;
196   sf->ops->ReduceBegin = PetscSFReduceBegin_Basic;
197   sf->ops->ReduceEnd   = PetscSFReduceEnd_Basic;
198 
199   /* Inherit from Allgatherv. It is astonishing Alltoall can inherit so much from Allgather(v) */
200   sf->ops->Destroy       = PetscSFDestroy_Allgatherv;
201   sf->ops->Reset         = PetscSFReset_Allgatherv;
202   sf->ops->FetchAndOpEnd = PetscSFFetchAndOpEnd_Allgatherv;
203   sf->ops->GetRootRanks  = PetscSFGetRootRanks_Allgatherv;
204 
205   /* Inherit from Allgather. Every process gathers equal-sized data from others, which enables this inheritance. */
206   sf->ops->GetLeafRanks = PetscSFGetLeafRanks_Allgatherv;
207   sf->ops->SetUp        = PetscSFSetUp_Allgather;
208 
209   /* Inherit from Gatherv. Each root has only one leaf connected, which enables this inheritance */
210   sf->ops->FetchAndOpBegin = PetscSFFetchAndOpBegin_Gatherv;
211 
212   /* Alltoall stuff */
213   sf->ops->GetGraph             = PetscSFGetGraph_Alltoall;
214   sf->ops->CreateLocalSF        = PetscSFCreateLocalSF_Alltoall;
215   sf->ops->CreateEmbeddedRootSF = PetscSFCreateEmbeddedRootSF_Alltoall;
216 
217   sf->ops->SetCommunicationOps = PetscSFSetCommunicationOps_Alltoall;
218 
219   sf->collective = PETSC_TRUE;
220 
221   PetscCall(PetscNew(&dat));
222   sf->data = (void *)dat;
223   PetscFunctionReturn(PETSC_SUCCESS);
224 }
225