GADGET-4
myalltoall.cc
Go to the documentation of this file.
1 /*******************************************************************************
2  * \copyright This file is part of the GADGET4 N-body/SPH code developed
3  * \copyright by Volker Springel. Copyright (C) 2014-2020 by Volker Springel
4  * \copyright (vspringel@mpa-garching.mpg.de) and all contributing authors.
5  ******************************************************************************/
6 
12 #include "gadgetconfig.h"
13 
14 #include <math.h>
15 #include <mpi.h>
16 #include <stdio.h>
17 #include <stdlib.h>
18 #include <string.h>
19 
20 #include "../data/allvars.h"
21 #include "../data/dtypes.h"
22 #include "../data/mymalloc.h"
23 #include "../mpi_utils/mpi_utils.h"
24 
/* Byte-pointer cast used for displacement arithmetic on untyped buffers.
 * The argument is parenthesized so that PCHAR(p + 1) casts the whole
 * expression, not just p (macro-hygiene fix). */
#define PCHAR(a) ((char *)(a))
26 
27 /* This method prepares an Alltoallv computation.
28  sendcnt: must have as many entries as there are Tasks in comm
29  must be set
30  recvcnt: must have as many entries as there are Tasks in comm
31  will be set on return
32  rdispls: must have as many entries as there are Tasks in comm, or be NULL
33  if not NULL, will be set on return
34  method: use standard Alltoall() approach or one-sided approach
35  returns: number of entries needed in the recvbuf */
36 int myMPI_Alltoallv_new_prep(int *sendcnt, int *recvcnt, int *rdispls, MPI_Comm comm, int method)
37 {
38  int rank, nranks;
39  MPI_Comm_size(comm, &nranks);
40  MPI_Comm_rank(comm, &rank);
41 
42  if(method == 0 || method == 1)
43  MPI_Alltoall(sendcnt, 1, MPI_INT, recvcnt, 1, MPI_INT, comm);
44  else if(method == 10)
45  {
46  for(int i = 0; i < nranks; ++i)
47  recvcnt[i] = 0;
48  recvcnt[rank] = sendcnt[rank]; // local communication
49  MPI_Win win;
50  MPI_Win_create(recvcnt, nranks * sizeof(MPI_INT), sizeof(MPI_INT), MPI_INFO_NULL, comm, &win);
51  MPI_Win_fence(0, win);
52  for(int i = 1; i < nranks; ++i) // remote communication
53  {
54  int tgt = (rank + i) % nranks;
55  if(sendcnt[tgt] != 0)
56  MPI_Put(&sendcnt[tgt], 1, MPI_INT, tgt, rank, 1, MPI_INT, win);
57  }
58  MPI_Win_fence(0, win);
59  MPI_Win_free(&win);
60  }
61  else
62  Terminate("bad communication method");
63 
64  int total = 0;
65  for(int i = 0; i < nranks; ++i)
66  {
67  if(rdispls)
68  rdispls[i] = total;
69  total += recvcnt[i];
70  }
71  return total;
72 }
73 
74 void myMPI_Alltoallv_new(void *sendbuf, int *sendcnt, int *sdispls, MPI_Datatype sendtype, void *recvbuf, int *recvcnt, int *rdispls,
75  MPI_Datatype recvtype, MPI_Comm comm, int method)
76 {
77  int rank, nranks, itsz;
78  MPI_Comm_size(comm, &nranks);
79  MPI_Comm_rank(comm, &rank);
80  MPI_Type_size(sendtype, &itsz);
81  size_t tsz = itsz; // to enforce size_t data type in later computations
82 
83  if(method == 0) // standard Alltoallv
84  MPI_Alltoallv(sendbuf, sendcnt, sdispls, sendtype, recvbuf, recvcnt, rdispls, recvtype, comm);
85  else if(method == 1) // blocking sendrecv
86  {
87  if(sendtype != recvtype)
88  Terminate("bad MPI communication types");
89  int lptask = 1;
90  while(lptask < nranks)
91  lptask <<= 1;
92  int tag = 42;
93  MPI_Status status;
94 
95  if(recvcnt[rank] > 0) // local communication
96  memcpy(PCHAR(recvbuf) + tsz * rdispls[rank], PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
97 
98  for(int ngrp = 1; ngrp < lptask; ngrp++)
99  {
100  int otask = rank ^ ngrp;
101  if(otask < nranks)
102  if(sendcnt[otask] > 0 || recvcnt[otask] > 0)
103  MPI_Sendrecv(PCHAR(sendbuf) + tsz * sdispls[otask], sendcnt[otask], sendtype, otask, tag,
104  PCHAR(recvbuf) + tsz * rdispls[otask], recvcnt[otask], recvtype, otask, tag, comm, &status);
105  }
106  }
107  else if(method == 2) // asynchronous communication
108  {
109  if(sendtype != recvtype)
110  Terminate("bad MPI communication types");
111  int lptask = 1;
112  while(lptask < nranks)
113  lptask <<= 1;
114  int tag = 42;
115 
116  MPI_Request *requests = (MPI_Request *)Mem.mymalloc("requests", 2 * nranks * sizeof(MPI_Request));
117  int n_requests = 0;
118 
119  if(recvcnt[rank] > 0) // local communication
120  memcpy(PCHAR(recvbuf) + tsz * rdispls[rank], PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
121 
122  for(int ngrp = 1; ngrp < lptask; ngrp++)
123  {
124  int otask = rank ^ ngrp;
125  if(otask < nranks)
126  if(recvcnt[otask] > 0)
127  MPI_Irecv(PCHAR(recvbuf) + tsz * rdispls[otask], recvcnt[otask], recvtype, otask, tag, comm, &requests[n_requests++]);
128  }
129 
130  for(int ngrp = 1; ngrp < lptask; ngrp++)
131  {
132  int otask = rank ^ ngrp;
133  if(otask < nranks)
134  if(sendcnt[otask] > 0)
135  MPI_Issend(PCHAR(sendbuf) + tsz * sdispls[otask], sendcnt[otask], sendtype, otask, tag, comm, &requests[n_requests++]);
136  }
137 
138  MPI_Waitall(n_requests, requests, MPI_STATUSES_IGNORE);
139  Mem.myfree(requests);
140  }
141  else if(method == 10)
142  {
143  if(sendtype != recvtype)
144  Terminate("bad MPI communication types");
145  int *disp_at_sender = (int *)Mem.mymalloc("disp_at_sender", nranks * sizeof(int));
146  disp_at_sender[rank] = sdispls[rank];
147  MPI_Win win;
148  MPI_Win_create(sdispls, nranks * sizeof(MPI_INT), sizeof(MPI_INT), MPI_INFO_NULL, comm, &win);
149  MPI_Win_fence(0, win);
150  for(int i = 1; i < nranks; ++i)
151  {
152  int tgt = (rank + i) % nranks;
153  if(recvcnt[tgt] != 0)
154  MPI_Get(&disp_at_sender[tgt], 1, MPI_INT, tgt, rank, 1, MPI_INT, win);
155  }
156  MPI_Win_fence(0, win);
157  MPI_Win_free(&win);
158  if(recvcnt[rank] > 0) // first take care of local communication
159  memcpy(PCHAR(recvbuf) + tsz * rdispls[rank], PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
160  MPI_Win_create(sendbuf, (sdispls[nranks - 1] + sendcnt[nranks - 1]) * tsz, tsz, MPI_INFO_NULL, comm, &win);
161  MPI_Win_fence(0, win);
162  for(int i = 1; i < nranks; ++i) // now the rest, start with right neighbour
163  {
164  int tgt = (rank + i) % nranks;
165  if(recvcnt[tgt] != 0)
166  MPI_Get(PCHAR(recvbuf) + tsz * rdispls[tgt], recvcnt[tgt], sendtype, tgt, disp_at_sender[tgt], recvcnt[tgt], sendtype,
167  win);
168  }
169  MPI_Win_fence(0, win);
170  MPI_Win_free(&win);
171  Mem.myfree(disp_at_sender);
172  }
173  else
174  Terminate("bad communication method");
175 }
176 
177 void myMPI_Alltoallv(void *sendb, size_t *sendcounts, size_t *sdispls, void *recvb, size_t *recvcounts, size_t *rdispls, int len,
178  int big_flag, MPI_Comm comm)
179 {
180  char *sendbuf = (char *)sendb;
181  char *recvbuf = (char *)recvb;
182 
183  if(big_flag == 0)
184  {
185  int ntask;
186  MPI_Comm_size(comm, &ntask);
187 
188  int *scount = (int *)Mem.mymalloc("scount", ntask * sizeof(int));
189  int *rcount = (int *)Mem.mymalloc("rcount", ntask * sizeof(int));
190  int *soff = (int *)Mem.mymalloc("soff", ntask * sizeof(int));
191  int *roff = (int *)Mem.mymalloc("roff", ntask * sizeof(int));
192 
193  for(int i = 0; i < ntask; i++)
194  {
195  scount[i] = sendcounts[i] * len;
196  rcount[i] = recvcounts[i] * len;
197  soff[i] = sdispls[i] * len;
198  roff[i] = rdispls[i] * len;
199  }
200 
201  MPI_Alltoallv(sendbuf, scount, soff, MPI_BYTE, recvbuf, rcount, roff, MPI_BYTE, comm);
202 
203  Mem.myfree(roff);
204  Mem.myfree(soff);
205  Mem.myfree(rcount);
206  Mem.myfree(scount);
207  }
208  else
209  {
210  /* here we definitely have some large messages. We default to the
211  * pair-wise protocol, which should be most robust anyway.
212  */
213  int ntask, thistask, ptask;
214  MPI_Comm_size(comm, &ntask);
215  MPI_Comm_rank(comm, &thistask);
216 
217  for(ptask = 0; ntask > (1 << ptask); ptask++)
218  ;
219 
220  for(int ngrp = 0; ngrp < (1 << ptask); ngrp++)
221  {
222  int target = thistask ^ ngrp;
223 
224  if(target < ntask)
225  {
226  if(sendcounts[target] > 0 || recvcounts[target] > 0)
227  myMPI_Sendrecv(sendbuf + sdispls[target] * len, sendcounts[target] * len, MPI_BYTE, target, TAG_PDATA + ngrp,
228  recvbuf + rdispls[target] * len, recvcounts[target] * len, MPI_BYTE, target, TAG_PDATA + ngrp, comm,
229  MPI_STATUS_IGNORE);
230  }
231  }
232  }
233 }
234 
235 void my_int_MPI_Alltoallv(void *sendb, int *sendcounts, int *sdispls, void *recvb, int *recvcounts, int *rdispls, int len,
236  int big_flag, MPI_Comm comm)
237 {
238  char *sendbuf = (char *)sendb;
239  char *recvbuf = (char *)recvb;
240 
241  if(big_flag == 0)
242  {
243  int ntask;
244  MPI_Comm_size(comm, &ntask);
245 
246  int *scount = (int *)Mem.mymalloc("scount", ntask * sizeof(int));
247  int *rcount = (int *)Mem.mymalloc("rcount", ntask * sizeof(int));
248  int *soff = (int *)Mem.mymalloc("soff", ntask * sizeof(int));
249  int *roff = (int *)Mem.mymalloc("roff", ntask * sizeof(int));
250 
251  for(int i = 0; i < ntask; i++)
252  {
253  scount[i] = sendcounts[i] * len;
254  rcount[i] = recvcounts[i] * len;
255  soff[i] = sdispls[i] * len;
256  roff[i] = rdispls[i] * len;
257  }
258 
259  MPI_Alltoallv(sendbuf, scount, soff, MPI_BYTE, recvbuf, rcount, roff, MPI_BYTE, comm);
260 
261  Mem.myfree(roff);
262  Mem.myfree(soff);
263  Mem.myfree(rcount);
264  Mem.myfree(scount);
265  }
266  else
267  {
268  /* here we definitely have some large messages. We default to the
269  * pair-wise protocoll, which should be most robust anyway.
270  */
271  int ntask, thistask, ptask;
272  MPI_Comm_size(comm, &ntask);
273  MPI_Comm_rank(comm, &thistask);
274 
275  for(ptask = 0; ntask > (1 << ptask); ptask++)
276  ;
277 
278  for(int ngrp = 0; ngrp < (1 << ptask); ngrp++)
279  {
280  int target = thistask ^ ngrp;
281 
282  if(target < ntask)
283  {
284  if(sendcounts[target] > 0 || recvcounts[target] > 0)
285  myMPI_Sendrecv(sendbuf + sdispls[target] * len, sendcounts[target] * len, MPI_BYTE, target, TAG_PDATA + ngrp,
286  recvbuf + rdispls[target] * len, recvcounts[target] * len, MPI_BYTE, target, TAG_PDATA + ngrp, comm,
287  MPI_STATUS_IGNORE);
288  }
289  }
290  }
291 }
#define Terminate(...)
Definition: macros.h:19
#define TAG_PDATA
Definition: mpi_utils.h:27
int myMPI_Sendrecv(void *sendbuf, size_t sendcount, MPI_Datatype sendtype, int dest, int sendtag, void *recvbuf, size_t recvcount, MPI_Datatype recvtype, int source, int recvtag, MPI_Comm comm, MPI_Status *status)
void my_int_MPI_Alltoallv(void *sendb, int *sendcounts, int *sdispls, void *recvb, int *recvcounts, int *rdispls, int len, int big_flag, MPI_Comm comm)
Definition: myalltoall.cc:235
void myMPI_Alltoallv(void *sendb, size_t *sendcounts, size_t *sdispls, void *recvb, size_t *recvcounts, size_t *rdispls, int len, int big_flag, MPI_Comm comm)
Definition: myalltoall.cc:177
int myMPI_Alltoallv_new_prep(int *sendcnt, int *recvcnt, int *rdispls, MPI_Comm comm, int method)
Definition: myalltoall.cc:36
void myMPI_Alltoallv_new(void *sendbuf, int *sendcnt, int *sdispls, MPI_Datatype sendtype, void *recvbuf, int *recvcnt, int *rdispls, MPI_Datatype recvtype, MPI_Comm comm, int method)
Definition: myalltoall.cc:74
#define PCHAR(a)
Definition: myalltoall.cc:25
memory Mem
Definition: main.cc:44