xref: /petsc/src/sys/utils/mpimesg.c (revision 4dfa11a44d5adf2389f1d3acbc8f3c1116dc6c3a)
1 
2 #include <petscsys.h> /*I  "petscsys.h"  I*/
3 #include <petsc/private/mpiutils.h>
4 
5 /*@C
6   PetscGatherNumberOfMessages -  Computes the number of messages an MPI rank expects to receive during a neighbor communication
7 
8   Collective
9 
10   Input Parameters:
11 + comm     - Communicator
12 . iflags   - an array of integers of length sizeof(comm). A '1' in ilengths[i] represent a
13              message from current node to ith node. Optionally NULL
14 - ilengths - Non zero ilengths[i] represent a message to i of length ilengths[i].
15              Optionally NULL.
16 
17   Output Parameters:
18 . nrecvs    - number of messages received
19 
20   Level: developer
21 
22   Notes:
23   With this info, the correct message lengths can be determined using
24   `PetscGatherMessageLengths()`
25 
26   Either iflags or ilengths should be provided.  If iflags is not
27   provided (NULL) it can be computed from ilengths. If iflags is
28   provided, ilengths is not required.
29 
30 .seealso: `PetscGatherMessageLengths()`, `PetscGatherMessageLengths2()`, `PetscCommBuildTwoSided()`
31 @*/
32 PetscErrorCode PetscGatherNumberOfMessages(MPI_Comm comm, const PetscMPIInt iflags[], const PetscMPIInt ilengths[], PetscMPIInt *nrecvs) {
33   PetscMPIInt size, rank, *recv_buf, i, *iflags_local = NULL, *iflags_localm = NULL;
34 
35   PetscFunctionBegin;
36   PetscCallMPI(MPI_Comm_size(comm, &size));
37   PetscCallMPI(MPI_Comm_rank(comm, &rank));
38 
39   PetscCall(PetscMalloc2(size, &recv_buf, size, &iflags_localm));
40 
41   /* If iflags not provided, compute iflags from ilengths */
42   if (!iflags) {
43     PetscCheck(ilengths, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Either iflags or ilengths should be provided");
44     iflags_local = iflags_localm;
45     for (i = 0; i < size; i++) {
46       if (ilengths[i]) iflags_local[i] = 1;
47       else iflags_local[i] = 0;
48     }
49   } else iflags_local = (PetscMPIInt *)iflags;
50 
51   /* Post an allreduce to determine the numer of messages the current node will receive */
52   PetscCall(MPIU_Allreduce(iflags_local, recv_buf, size, MPI_INT, MPI_SUM, comm));
53   *nrecvs = recv_buf[rank];
54 
55   PetscCall(PetscFree2(recv_buf, iflags_localm));
56   PetscFunctionReturn(0);
57 }
58 
59 /*@C
60   PetscGatherMessageLengths - Computes infomation about messages that an MPI rank will receive,
61   including (from-id,length) pairs for each message.
62 
63   Collective
64 
65   Input Parameters:
66 + comm      - Communicator
67 . nsends    - number of messages that are to be sent.
68 . nrecvs    - number of messages being received
69 - ilengths  - an array of integers of length sizeof(comm)
70               a non zero ilengths[i] represent a message to i of length ilengths[i]
71 
72   Output Parameters:
73 + onodes    - list of node-ids from which messages are expected
74 - olengths  - corresponding message lengths
75 
76   Level: developer
77 
78   Notes:
79   With this info, the correct `MPI_Irecv()` can be posted with the correct
80   from-id, with a buffer with the right amount of memory required.
81 
82   The calling function deallocates the memory in onodes and olengths
83 
84   To determine nrecvs, one can use `PetscGatherNumberOfMessages()`
85 
86 .seealso: `PetscGatherNumberOfMessages()`, `PetscGatherMessageLengths2()`, `PetscCommBuildTwoSided()`
87 @*/
88 PetscErrorCode PetscGatherMessageLengths(MPI_Comm comm, PetscMPIInt nsends, PetscMPIInt nrecvs, const PetscMPIInt ilengths[], PetscMPIInt **onodes, PetscMPIInt **olengths) {
89   PetscMPIInt  size, rank, tag, i, j;
90   MPI_Request *s_waits = NULL, *r_waits = NULL;
91   MPI_Status  *w_status = NULL;
92 
93   PetscFunctionBegin;
94   PetscCallMPI(MPI_Comm_size(comm, &size));
95   PetscCallMPI(MPI_Comm_rank(comm, &rank));
96   PetscCall(PetscCommGetNewTag(comm, &tag));
97 
98   /* cannot use PetscMalloc3() here because in the call to MPI_Waitall() they MUST be contiguous */
99   PetscCall(PetscMalloc2(nrecvs + nsends, &r_waits, nrecvs + nsends, &w_status));
100   s_waits = r_waits + nrecvs;
101 
102   /* Post the Irecv to get the message length-info */
103   PetscCall(PetscMalloc1(nrecvs, olengths));
104   for (i = 0; i < nrecvs; i++) PetscCallMPI(MPI_Irecv((*olengths) + i, 1, MPI_INT, MPI_ANY_SOURCE, tag, comm, r_waits + i));
105 
106   /* Post the Isends with the message length-info */
107   for (i = 0, j = 0; i < size; ++i) {
108     if (ilengths[i]) {
109       PetscCallMPI(MPI_Isend((void *)(ilengths + i), 1, MPI_INT, i, tag, comm, s_waits + j));
110       j++;
111     }
112   }
113 
114   /* Post waits on sends and receives */
115   if (nrecvs + nsends) PetscCallMPI(MPI_Waitall(nrecvs + nsends, r_waits, w_status));
116 
117   /* Pack up the received data */
118   PetscCall(PetscMalloc1(nrecvs, onodes));
119   for (i = 0; i < nrecvs; ++i) {
120     (*onodes)[i] = w_status[i].MPI_SOURCE;
121 #if defined(PETSC_HAVE_OMPI_MAJOR_VERSION)
122     /* This line is a workaround for a bug in OpenMPI-2.1.1 distributed by Ubuntu-18.04.2 LTS.
123        It happens in self-to-self MPI_Send/Recv using MPI_ANY_SOURCE for message matching. OpenMPI
124        does not put correct value in recv buffer. See also
125        https://lists.mcs.anl.gov/pipermail/petsc-dev/2019-July/024803.html
126        https://www.mail-archive.com/users@lists.open-mpi.org//msg33383.html
127      */
128     if (w_status[i].MPI_SOURCE == rank) (*olengths)[i] = ilengths[rank];
129 #endif
130   }
131   PetscCall(PetscFree2(r_waits, w_status));
132   PetscFunctionReturn(0);
133 }
134 
135 /* Same as PetscGatherNumberOfMessages(), except using PetscInt for ilengths[] */
136 PetscErrorCode PetscGatherNumberOfMessages_Private(MPI_Comm comm, const PetscMPIInt iflags[], const PetscInt ilengths[], PetscMPIInt *nrecvs) {
137   PetscMPIInt size, rank, *recv_buf, i, *iflags_local = NULL, *iflags_localm = NULL;
138 
139   PetscFunctionBegin;
140   PetscCallMPI(MPI_Comm_size(comm, &size));
141   PetscCallMPI(MPI_Comm_rank(comm, &rank));
142 
143   PetscCall(PetscMalloc2(size, &recv_buf, size, &iflags_localm));
144 
145   /* If iflags not provided, compute iflags from ilengths */
146   if (!iflags) {
147     PetscCheck(ilengths, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Either iflags or ilengths should be provided");
148     iflags_local = iflags_localm;
149     for (i = 0; i < size; i++) {
150       if (ilengths[i]) iflags_local[i] = 1;
151       else iflags_local[i] = 0;
152     }
153   } else iflags_local = (PetscMPIInt *)iflags;
154 
155   /* Post an allreduce to determine the numer of messages the current node will receive */
156   PetscCall(MPIU_Allreduce(iflags_local, recv_buf, size, MPI_INT, MPI_SUM, comm));
157   *nrecvs = recv_buf[rank];
158 
159   PetscCall(PetscFree2(recv_buf, iflags_localm));
160   PetscFunctionReturn(0);
161 }
162 
163 /* Same as PetscGatherMessageLengths(), except using PetscInt for message lengths */
164 PetscErrorCode PetscGatherMessageLengths_Private(MPI_Comm comm, PetscMPIInt nsends, PetscMPIInt nrecvs, const PetscInt ilengths[], PetscMPIInt **onodes, PetscInt **olengths) {
165   PetscMPIInt  size, rank, tag, i, j;
166   MPI_Request *s_waits = NULL, *r_waits = NULL;
167   MPI_Status  *w_status = NULL;
168 
169   PetscFunctionBegin;
170   PetscCallMPI(MPI_Comm_size(comm, &size));
171   PetscCallMPI(MPI_Comm_rank(comm, &rank));
172   PetscCall(PetscCommGetNewTag(comm, &tag));
173 
174   /* cannot use PetscMalloc3() here because in the call to MPI_Waitall() they MUST be contiguous */
175   PetscCall(PetscMalloc2(nrecvs + nsends, &r_waits, nrecvs + nsends, &w_status));
176   s_waits = r_waits + nrecvs;
177 
178   /* Post the Irecv to get the message length-info */
179   PetscCall(PetscMalloc1(nrecvs, olengths));
180   for (i = 0; i < nrecvs; i++) PetscCallMPI(MPI_Irecv((*olengths) + i, 1, MPIU_INT, MPI_ANY_SOURCE, tag, comm, r_waits + i));
181 
182   /* Post the Isends with the message length-info */
183   for (i = 0, j = 0; i < size; ++i) {
184     if (ilengths[i]) {
185       PetscCallMPI(MPI_Isend((void *)(ilengths + i), 1, MPIU_INT, i, tag, comm, s_waits + j));
186       j++;
187     }
188   }
189 
190   /* Post waits on sends and receives */
191   if (nrecvs + nsends) PetscCallMPI(MPI_Waitall(nrecvs + nsends, r_waits, w_status));
192 
193   /* Pack up the received data */
194   PetscCall(PetscMalloc1(nrecvs, onodes));
195   for (i = 0; i < nrecvs; ++i) {
196     (*onodes)[i] = w_status[i].MPI_SOURCE;
197     if (w_status[i].MPI_SOURCE == rank) (*olengths)[i] = ilengths[rank]; /* See comments in PetscGatherMessageLengths */
198   }
199   PetscCall(PetscFree2(r_waits, w_status));
200   PetscFunctionReturn(0);
201 }
202 
203 /*@C
204   PetscGatherMessageLengths2 - Computes info about messages that a MPI rank will receive,
205   including (from-id,length) pairs for each message. Same functionality as `PetscGatherMessageLengths()`
206   except it takes TWO ilenths and output TWO olengths.
207 
208   Collective
209 
210   Input Parameters:
211 + comm      - Communicator
212 . nsends    - number of messages that are to be sent.
213 . nrecvs    - number of messages being received
214 . ilengths1 - first array of integers of length sizeof(comm)
215 - ilengths2 - second array of integers of length sizeof(comm)
216 
217   Output Parameters:
218 + onodes    - list of node-ids from which messages are expected
219 . olengths1 - first corresponding message lengths
220 - olengths2 - second  message lengths
221 
222   Level: developer
223 
224   Notes:
225   With this info, the correct `MPI_Irecv()` can be posted with the correct
226   from-id, with a buffer with the right amount of memory required.
227 
228   The calling function deallocates the memory in onodes and olengths
229 
230   To determine nrecvs, one can use PetscGatherNumberOfMessages()
231 
232 .seealso: `PetscGatherMessageLengths()`, `PetscGatherNumberOfMessages()`, `PetscCommBuildTwoSided()`
233 @*/
234 PetscErrorCode PetscGatherMessageLengths2(MPI_Comm comm, PetscMPIInt nsends, PetscMPIInt nrecvs, const PetscMPIInt ilengths1[], const PetscMPIInt ilengths2[], PetscMPIInt **onodes, PetscMPIInt **olengths1, PetscMPIInt **olengths2) {
235   PetscMPIInt  size, tag, i, j, *buf_s = NULL, *buf_r = NULL, *buf_j = NULL;
236   MPI_Request *s_waits = NULL, *r_waits = NULL;
237   MPI_Status  *w_status = NULL;
238 
239   PetscFunctionBegin;
240   PetscCallMPI(MPI_Comm_size(comm, &size));
241   PetscCall(PetscCommGetNewTag(comm, &tag));
242 
243   /* cannot use PetscMalloc5() because r_waits and s_waits must be contiguous for the call to MPI_Waitall() */
244   PetscCall(PetscMalloc4(nrecvs + nsends, &r_waits, 2 * nrecvs, &buf_r, 2 * nsends, &buf_s, nrecvs + nsends, &w_status));
245   s_waits = r_waits + nrecvs;
246 
247   /* Post the Irecv to get the message length-info */
248   PetscCall(PetscMalloc1(nrecvs + 1, olengths1));
249   PetscCall(PetscMalloc1(nrecvs + 1, olengths2));
250   for (i = 0; i < nrecvs; i++) {
251     buf_j = buf_r + (2 * i);
252     PetscCallMPI(MPI_Irecv(buf_j, 2, MPI_INT, MPI_ANY_SOURCE, tag, comm, r_waits + i));
253   }
254 
255   /* Post the Isends with the message length-info */
256   for (i = 0, j = 0; i < size; ++i) {
257     if (ilengths1[i]) {
258       buf_j    = buf_s + (2 * j);
259       buf_j[0] = *(ilengths1 + i);
260       buf_j[1] = *(ilengths2 + i);
261       PetscCallMPI(MPI_Isend(buf_j, 2, MPI_INT, i, tag, comm, s_waits + j));
262       j++;
263     }
264   }
265   PetscCheck(j == nsends, PETSC_COMM_SELF, PETSC_ERR_PLIB, "j %d not equal to expected number of sends %d", j, nsends);
266 
267   /* Post waits on sends and receives */
268   if (nrecvs + nsends) PetscCallMPI(MPI_Waitall(nrecvs + nsends, r_waits, w_status));
269 
270   /* Pack up the received data */
271   PetscCall(PetscMalloc1(nrecvs + 1, onodes));
272   for (i = 0; i < nrecvs; ++i) {
273     (*onodes)[i]    = w_status[i].MPI_SOURCE;
274     buf_j           = buf_r + (2 * i);
275     (*olengths1)[i] = buf_j[0];
276     (*olengths2)[i] = buf_j[1];
277   }
278 
279   PetscCall(PetscFree4(r_waits, buf_r, buf_s, w_status));
280   PetscFunctionReturn(0);
281 }
282 
283 /*
284   Allocate a buffer sufficient to hold messages of size specified in olengths.
285   And post Irecvs on these buffers using node info from onodes
286  */
287 PetscErrorCode PetscPostIrecvInt(MPI_Comm comm, PetscMPIInt tag, PetscMPIInt nrecvs, const PetscMPIInt onodes[], const PetscMPIInt olengths[], PetscInt ***rbuf, MPI_Request **r_waits) {
288   PetscInt   **rbuf_t, i, len = 0;
289   MPI_Request *r_waits_t;
290 
291   PetscFunctionBegin;
292   /* compute memory required for recv buffers */
293   for (i = 0; i < nrecvs; i++) len += olengths[i]; /* each message length */
294 
295   /* allocate memory for recv buffers */
296   PetscCall(PetscMalloc1(nrecvs + 1, &rbuf_t));
297   PetscCall(PetscMalloc1(len, &rbuf_t[0]));
298   for (i = 1; i < nrecvs; ++i) rbuf_t[i] = rbuf_t[i - 1] + olengths[i - 1];
299 
300   /* Post the receives */
301   PetscCall(PetscMalloc1(nrecvs, &r_waits_t));
302   for (i = 0; i < nrecvs; ++i) PetscCallMPI(MPI_Irecv(rbuf_t[i], olengths[i], MPIU_INT, onodes[i], tag, comm, r_waits_t + i));
303 
304   *rbuf    = rbuf_t;
305   *r_waits = r_waits_t;
306   PetscFunctionReturn(0);
307 }
308 
309 PetscErrorCode PetscPostIrecvScalar(MPI_Comm comm, PetscMPIInt tag, PetscMPIInt nrecvs, const PetscMPIInt onodes[], const PetscMPIInt olengths[], PetscScalar ***rbuf, MPI_Request **r_waits) {
310   PetscMPIInt   i;
311   PetscScalar **rbuf_t;
312   MPI_Request  *r_waits_t;
313   PetscInt      len = 0;
314 
315   PetscFunctionBegin;
316   /* compute memory required for recv buffers */
317   for (i = 0; i < nrecvs; i++) len += olengths[i]; /* each message length */
318 
319   /* allocate memory for recv buffers */
320   PetscCall(PetscMalloc1(nrecvs + 1, &rbuf_t));
321   PetscCall(PetscMalloc1(len, &rbuf_t[0]));
322   for (i = 1; i < nrecvs; ++i) rbuf_t[i] = rbuf_t[i - 1] + olengths[i - 1];
323 
324   /* Post the receives */
325   PetscCall(PetscMalloc1(nrecvs, &r_waits_t));
326   for (i = 0; i < nrecvs; ++i) PetscCallMPI(MPI_Irecv(rbuf_t[i], olengths[i], MPIU_SCALAR, onodes[i], tag, comm, r_waits_t + i));
327 
328   *rbuf    = rbuf_t;
329   *r_waits = r_waits_t;
330   PetscFunctionReturn(0);
331 }
332