GADGET-4
parallel_sort.h
/*******************************************************************************
 * \copyright   This file is part of the GADGET4 N-body/SPH code developed
 * \copyright   by Volker Springel. Copyright (C) 2014-2020 by Volker Springel
 * \copyright   (vspringel@mpa-garching.mpg.de) and all contributing authors.
 *******************************************************************************/

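/*! \file parallel_sort.h
 *
 *  \brief parallel sort routine based on MPI that leaves the number of
 *         elements stored on each processor unchanged
 *
 *  Note: enabling CHECK_LOCAL_RANK below activates an expensive brute-force
 *  verification of the binary search carried out in get_local_rank().
 */
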
#ifndef PARALLEL_SORT_H
#define PARALLEL_SORT_H

#include "cxxsort.h"

#include "../data/mymalloc.h"

//#define CHECK_LOCAL_RANK

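/*! \brief comparison functor that orders indices by comparing the elements they refer to
 *
 *  Used by buildIndex() to produce an indirect sort (an argsort) of a range
 *  without moving the elements themselves.
 */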
template <typename It, typename Comp>
class IdxComp__
{
 private:
  It begin;
  Comp comp;

 public:
  IdxComp__(It begin_, Comp comp_) : begin(begin_), comp(comp_) {}
  bool operator()(std::size_t a, std::size_t b) const { return comp(*(begin + a), *(begin + b)); }
};

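/*! \brief fills idx with the indices 0, 1, ..., (end - begin) - 1 and sorts them
 *         such that begin[idx[0]], begin[idx[1]], ... is ordered according to comp
 */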
template <typename It, typename T2, typename Comp>
inline void buildIndex(It begin, It end, T2 *idx, Comp comp)
{
  using namespace std;
  T2 num = end - begin;
  for(T2 i = 0; i < num; ++i)
    idx[i] = i;
  mycxxsort(idx, idx + num, IdxComp__<It, Comp>(begin, comp));
}

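/*! \brief determines the rank that the splitter (element, tie_breaking_rank) would
 *         have within the locally sorted data base[0 ... nmemb-1]
 *
 *  On return, *loc holds the number of local elements that order before the
 *  splitter. Ties are broken with the splitter's initial global rank, which makes
 *  all keys effectively unique. The binary search can be restricted to the window
 *  [left, right) if that range is already known from earlier iterations to
 *  bracket the splitter.
 */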
template <typename T, typename Comp>
void get_local_rank(const T &element, std::size_t tie_breaking_rank, const T *base, size_t nmemb, size_t noffs_thistask,
                    long long left, long long right, size_t *loc, Comp comp)
{
  if(right < left)
    Terminate("right < left");

  if(left == 0 && right == (long long)nmemb + 1)
    {
      if(comp(base[nmemb - 1], element))
        {
          *loc = nmemb;
          return;
        }
      else if(comp(element, base[0]))
        {
          *loc = 0;
          return;
        }
    }

  if(right == left) /* looks like we have already converged to the proper rank */
    {
      *loc = left;
    }
  else
    {
      if(comp(base[right - 1], element)) /* the last element is smaller, hence all elements are on the left */
        *loc = (right - 1) + 1;
      else if(comp(element, base[left])) /* the first element is already larger, hence no element is on the left */
        *loc = left;
      else
        {
          while(right > left)
            {
              long long mid = ((right - 1) + left) / 2;

              int cmp = comp(base[mid], element) ? -1 : (comp(element, base[mid]) ? +1 : 0);
              if(cmp == 0)
                {
                  if(mid + noffs_thistask < tie_breaking_rank)
                    cmp = -1;
                  else if(mid + noffs_thistask > tie_breaking_rank)
                    cmp = +1;
                }

              if(cmp == 0) /* the element has been found exactly */
                {
                  *loc = mid;
                  break;
                }

              if((right - 1) == left) /* the element is not on this CPU */
                {
                  if(cmp < 0)
                    *loc = mid + 1;
                  else
                    *loc = mid;
                  break;
                }

              if(cmp < 0)
                {
                  left = mid + 1;
                }
              else
                {
                  if((right - 1) == left + 1)
                    {
                      if(mid != left)
                        Terminate("Can't be: -->left=%lld  right=%lld\n", left, right);

                      *loc = left;
                      break;
                    }

                  right = mid;
                }
            }
        }
    }
}

#ifdef CHECK_LOCAL_RANK
template <typename T, typename Comp>
inline void check_local_rank(const T &element,         /* element of which we want the rank */
                             size_t tie_breaking_rank, /* the initial global rank of this element (needed for breaking ties) */
                             const T *base,            /* base address of the local data */
                             size_t nmemb,             /* number of elements in the local data */
                             size_t noffs_thistask,    /* cumulative number of elements on lower tasks */
                             long long left, long long right, /* range of local elements that may hold the element */
                             size_t loc, Comp comp)    /* user-specified comparison function */
{
  long long count = 0;

  for(size_t i = 0; i < nmemb; i++)
    {
      int cmp = comp(base[i], element) ? -1 : (comp(element, base[i]) ? +1 : 0);

      if(cmp == 0)
        {
          if(noffs_thistask + i < tie_breaking_rank)
            cmp = -1;
        }

      if(cmp < 0)
        count++;
    }

  if(count != (long long)loc)
    Terminate("Inconsistency: loc=%lld count=%lld left=%lld right=%lld nmemb=%lld\n", (long long)loc, count, left, right,
              (long long)nmemb);
}
#endif

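/*! \brief sorts the distributed data between begin and end in parallel across the
 *         communicator comm, keeping the number of elements per task unchanged
 *
 *  The routine works in three stages: (1) each task sorts its local data serially;
 *  (2) the Local_NTask-1 global splitter elements that partition the sorted
 *  sequence according to the original element counts are found iteratively, using
 *  median-of-medians guesses whose global ranks are refined with get_local_rank()
 *  until they match the desired ranks; (3) the elements are exchanged in a single
 *  MPI_Alltoallv() and each task does a final serial sort of what it received.
 *  The return value is the elapsed wall-clock time in seconds.
 */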
template <typename T, typename Comp>
inline double mycxxsort_parallel(T *begin, T *end, Comp comp, MPI_Comm comm)
{
  const int MAX_ITER_PARALLEL_SORT = 500;
  int ranks_not_found, Local_ThisTask, Local_NTask, Color, new_max_loc;
  size_t tie_breaking_rank, new_tie_breaking_rank, rank;
  MPI_Comm MPI_CommLocal;

  double ta    = Logs.second();
  size_t nmemb = end - begin;
  size_t size  = sizeof(T);

  /* do a serial sort of the local data up front */
  mycxxsort(begin, end, comp);

  /* we create a communicator that contains just those tasks with nmemb > 0. This makes
   * it easier to deal with CPUs that do not hold any data.
   */
  if(nmemb)
    Color = 1;
  else
    Color = 0;

  int thistask;
  MPI_Comm_rank(comm, &thistask);

  MPI_Comm_split(comm, Color, thistask, &MPI_CommLocal);
  MPI_Comm_rank(MPI_CommLocal, &Local_ThisTask);
  MPI_Comm_size(MPI_CommLocal, &Local_NTask);

  if(Local_NTask > 1 && Color == 1)
    {
      size_t *nlist = (size_t *)Mem.mymalloc("nlist", Local_NTask * sizeof(size_t));
      size_t *noffs = (size_t *)Mem.mymalloc("noffs", Local_NTask * sizeof(size_t));

      MPI_Allgather(&nmemb, sizeof(size_t), MPI_BYTE, nlist, sizeof(size_t), MPI_BYTE, MPI_CommLocal);

      noffs[0] = 0;
      for(int i = 1; i < Local_NTask; i++)
        noffs[i] = noffs[i - 1] + nlist[i - 1];

      T *element_guess                  = (T *)Mem.mymalloc("element_guess", Local_NTask * size);
      size_t *element_tie_breaking_rank = (size_t *)Mem.mymalloc("element_tie_breaking_rank", Local_NTask * sizeof(size_t));
      size_t *desired_glob_rank         = (size_t *)Mem.mymalloc("desired_glob_rank", Local_NTask * sizeof(size_t));
      size_t *current_glob_rank         = (size_t *)Mem.mymalloc("current_glob_rank", Local_NTask * sizeof(size_t));
      size_t *current_loc_rank          = (size_t *)Mem.mymalloc("current_loc_rank", Local_NTask * sizeof(size_t));
      long long *range_left             = (long long *)Mem.mymalloc("range_left", Local_NTask * sizeof(long long));
      long long *range_right            = (long long *)Mem.mymalloc("range_right", Local_NTask * sizeof(long long));
      int *max_loc                      = (int *)Mem.mymalloc("max_loc", Local_NTask * sizeof(int));

      size_t *list           = (size_t *)Mem.mymalloc("list", Local_NTask * sizeof(size_t));
      size_t *range_len_list = (size_t *)Mem.mymalloc("range_len_list", Local_NTask * sizeof(long long));
      T median_element;
      T *median_element_list         = (T *)Mem.mymalloc("median_element_list", Local_NTask * size);
      size_t *tie_breaking_rank_list = (size_t *)Mem.mymalloc("tie_breaking_rank_list", Local_NTask * sizeof(size_t));
      int *index_list                = (int *)Mem.mymalloc("index_list", Local_NTask * sizeof(int));
      int *max_loc_list              = (int *)Mem.mymalloc("max_loc_list", Local_NTask * sizeof(int));
      size_t *source_range_len_list  = (size_t *)Mem.mymalloc("source_range_len_list", Local_NTask * sizeof(long long));
      size_t *source_tie_breaking_rank_list =
          (size_t *)Mem.mymalloc("source_tie_breaking_rank_list", Local_NTask * sizeof(long long));
      T *source_median_element_list = (T *)Mem.mymalloc("source_median_element_list", Local_NTask * size);
      T new_element_guess;

      for(int i = 0; i < Local_NTask - 1; i++)
        {
          desired_glob_rank[i] = noffs[i + 1];
          current_glob_rank[i] = 0;
          range_left[i]        = 0;     /* first element that it can be */
          range_right[i]       = nmemb; /* first element that it can not be */
        }

      /* now we determine the first split element guess, which is the same for all divisions in the first iteration */

      /* find the median of each processor, and then take the median among those values.
       * This should work reasonably well even for extremely skewed distributions
       */
      long long range_len = range_right[0] - range_left[0];

      if(range_len >= 1)
        {
          long long mid     = (range_left[0] + range_right[0]) / 2;
          median_element    = begin[mid];
          tie_breaking_rank = mid + noffs[Local_ThisTask];
        }

      MPI_Gather(&range_len, sizeof(long long), MPI_BYTE, range_len_list, sizeof(long long), MPI_BYTE, 0, MPI_CommLocal);
      MPI_Gather(&median_element, size, MPI_BYTE, median_element_list, size, MPI_BYTE, 0, MPI_CommLocal);
      MPI_Gather(&tie_breaking_rank, sizeof(size_t), MPI_BYTE, tie_breaking_rank_list, sizeof(size_t), MPI_BYTE, 0, MPI_CommLocal);

      if(Local_ThisTask == 0)
        {
          for(int j = 0; j < Local_NTask; j++)
            max_loc_list[j] = j;

          /* eliminate the elements that are undefined because the corresponding CPU has zero range left */
          int nleft = Local_NTask;

          for(int j = 0; j < nleft; j++)
            {
              if(range_len_list[j] < 1)
                {
                  range_len_list[j] = range_len_list[nleft - 1];
                  if(range_len_list[nleft - 1] >= 1 && j != (nleft - 1))
                    {
                      median_element_list[j]    = median_element_list[nleft - 1];
                      tie_breaking_rank_list[j] = tie_breaking_rank_list[nleft - 1];
                      max_loc_list[j]           = max_loc_list[nleft - 1];
                    }

                  nleft--;
                  j--;
                }
            }

          /* do a serial sort of the remaining elements (indirectly, so that we also have the order of the tie-breaking list) */
          buildIndex(median_element_list, median_element_list + nleft, index_list, comp);

          /* now select the median of the medians */
          int mid                      = nleft / 2;
          element_guess[0]             = median_element_list[index_list[mid]];
          element_tie_breaking_rank[0] = tie_breaking_rank_list[index_list[mid]];
          max_loc[0]                   = max_loc_list[index_list[mid]];
        }

      MPI_Bcast(element_guess, size, MPI_BYTE, 0, MPI_CommLocal);
      MPI_Bcast(&element_tie_breaking_rank[0], sizeof(size_t), MPI_BYTE, 0, MPI_CommLocal);
      MPI_Bcast(&max_loc[0], 1, MPI_INT, 0, MPI_CommLocal);

      for(int i = 1; i < Local_NTask - 1; i++)
        {
          element_guess[i]             = element_guess[0];
          element_tie_breaking_rank[i] = element_tie_breaking_rank[0];
          max_loc[i]                   = max_loc[0];
        }

      int iter = 0;

      do
        {
          for(int i = 0; i < Local_NTask - 1; i++)
            {
              if(current_glob_rank[i] != desired_glob_rank[i])
                {
                  get_local_rank(element_guess[i], element_tie_breaking_rank[i], begin, nmemb, noffs[Local_ThisTask],
                                 range_left[i], range_right[i], &current_loc_rank[i], comp);

#ifdef CHECK_LOCAL_RANK
                  check_local_rank(element_guess[i], element_tie_breaking_rank[i], begin, nmemb, noffs[Local_ThisTask],
                                   range_left[i], range_right[i], current_loc_rank[i], comp);
#endif
                }
            }

          /* now compute the global ranks by summing the local ranks */
          /* Note: the last element in current_loc_rank is not defined. It will be summed by the last processor, and stored in
           * the last element of current_glob_rank */
          MPI_Alltoall(current_loc_rank, sizeof(size_t), MPI_BYTE, list, sizeof(size_t), MPI_BYTE, MPI_CommLocal);
          rank = 0;
          for(int j = 0; j < Local_NTask; j++)
            rank += list[j];
          MPI_Allgather(&rank, sizeof(size_t), MPI_BYTE, current_glob_rank, sizeof(size_t), MPI_BYTE, MPI_CommLocal);

          ranks_not_found = 0;
          for(int i = 0; i < Local_NTask - 1; i++)
            {
              if(current_glob_rank[i] != desired_glob_rank[i]) /* here we're not yet done */
                {
                  ranks_not_found++;

                  if(current_glob_rank[i] < desired_glob_rank[i])
                    {
                      range_left[i] = current_loc_rank[i];

                      if(Local_ThisTask == max_loc[i])
                        range_left[i]++;
                    }

                  if(current_glob_rank[i] > desired_glob_rank[i])
                    range_right[i] = current_loc_rank[i];
                }
            }

          /* now we need to determine new element guesses */
          for(int i = 0; i < Local_NTask - 1; i++)
            {
              if(current_glob_rank[i] != desired_glob_rank[i]) /* here we're not yet done */
                {
                  /* find the median of each processor, and then take the median among those values.
                   * This should work reasonably well even for extremely skewed distributions
                   */
                  source_range_len_list[i] = range_right[i] - range_left[i];

                  if(source_range_len_list[i] >= 1)
                    {
                      long long middle                 = (range_left[i] + range_right[i]) / 2;
                      source_median_element_list[i]    = begin[middle];
                      source_tie_breaking_rank_list[i] = middle + noffs[Local_ThisTask];
                    }
                }
            }

          MPI_Alltoall(source_range_len_list, sizeof(long long), MPI_BYTE, range_len_list, sizeof(long long), MPI_BYTE,
                       MPI_CommLocal);
          MPI_Alltoall(source_median_element_list, size, MPI_BYTE, median_element_list, size, MPI_BYTE, MPI_CommLocal);
          MPI_Alltoall(source_tie_breaking_rank_list, sizeof(size_t), MPI_BYTE, tie_breaking_rank_list, sizeof(size_t), MPI_BYTE,
                       MPI_CommLocal);

          if(Local_ThisTask < Local_NTask - 1)
            {
              if(current_glob_rank[Local_ThisTask] !=
                 desired_glob_rank[Local_ThisTask]) /* in this case we're not yet done for this split point */
                {
                  for(int j = 0; j < Local_NTask; j++)
                    max_loc_list[j] = j;

                  /* eliminate the elements that are undefined because the corresponding CPU has zero range left */
                  int nleft = Local_NTask;

                  for(int j = 0; j < nleft; j++)
                    {
                      if(range_len_list[j] < 1)
                        {
                          range_len_list[j] = range_len_list[nleft - 1];
                          if(range_len_list[nleft - 1] >= 1 && j != (nleft - 1))
                            {
                              median_element_list[j]    = median_element_list[nleft - 1];
                              tie_breaking_rank_list[j] = tie_breaking_rank_list[nleft - 1];
                              max_loc_list[j]           = max_loc_list[nleft - 1];
                            }

                          nleft--;
                          j--;
                        }
                    }

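                  /* to refine the guess, two strategies alternate: on odd iterations the
                   * median element of the task with the largest remaining search range is
                   * picked (which guarantees progress for very skewed distributions), on
                   * even iterations the median of the per-task medians is used */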
                  if((iter & 1))
                    {
                      size_t max_range = 0, maxj = 0;

                      for(int j = 0; j < nleft; j++)
                        if(range_len_list[j] > max_range)
                          {
                            max_range = range_len_list[j];
                            maxj      = j;
                          }

                      /* now select the median element from the task which has the largest range */
                      new_element_guess     = median_element_list[maxj];
                      new_tie_breaking_rank = tie_breaking_rank_list[maxj];
                      new_max_loc           = max_loc_list[maxj];
                    }
                  else
                    {
                      /* do a serial sort of the remaining elements (indirectly, so that we also have the order of the
                       * tie-breaking list) */
                      buildIndex(median_element_list, median_element_list + nleft, index_list, comp);

                      /* now select the median of the medians */
                      int mid               = nleft / 2;
                      new_element_guess     = median_element_list[index_list[mid]];
                      new_tie_breaking_rank = tie_breaking_rank_list[index_list[mid]];
                      new_max_loc           = max_loc_list[index_list[mid]];
                    }
                }
              else
                {
                  /* in order to preserve the existing guesses */
                  new_element_guess     = element_guess[Local_ThisTask];
                  new_tie_breaking_rank = element_tie_breaking_rank[Local_ThisTask];
                  new_max_loc           = max_loc[Local_ThisTask];
                }
            }

          MPI_Allgather(&new_element_guess, size, MPI_BYTE, element_guess, size, MPI_BYTE, MPI_CommLocal);
          MPI_Allgather(&new_tie_breaking_rank, sizeof(size_t), MPI_BYTE, element_tie_breaking_rank, sizeof(size_t), MPI_BYTE,
                        MPI_CommLocal);
          MPI_Allgather(&new_max_loc, 1, MPI_INT, max_loc, 1, MPI_INT, MPI_CommLocal);

          iter++;

          if(iter > (MAX_ITER_PARALLEL_SORT - 100) && Local_ThisTask == 0)
            {
              printf("PSORT: iter=%d: ranks_not_found=%d  Local_NTask=%d\n", iter, ranks_not_found, Local_NTask);
              myflush(stdout);

              if(iter > MAX_ITER_PARALLEL_SORT)
                Terminate("can't find the split points. That's odd");
            }
        }
      while(ranks_not_found);

      Mem.myfree(source_median_element_list);
      Mem.myfree(source_tie_breaking_rank_list);
      Mem.myfree(source_range_len_list);
      Mem.myfree(max_loc_list);
      Mem.myfree(index_list);
      Mem.myfree(tie_breaking_rank_list);
      Mem.myfree(median_element_list);

      /* at this point we have found all the elements corresponding to the desired split points, and we can
       * go ahead and determine how many elements of the local CPU have to go to each other CPU */

      if(nmemb * size > (1LL << 31))
        Terminate("currently, local data must be smaller than 2 GB");
      /* note: to lift this limitation, the send/recv count arrays would have to be made 64-bit, and the
       * MPI data exchange through MPI_Alltoallv would have to be modified so that buffers > 2 GB become possible
       */

      int *send_count  = (int *)Mem.mymalloc("send_count", Local_NTask * sizeof(int));
      int *recv_count  = (int *)Mem.mymalloc("recv_count", Local_NTask * sizeof(int));
      int *send_offset = (int *)Mem.mymalloc("send_offset", Local_NTask * sizeof(int));
      int *recv_offset = (int *)Mem.mymalloc("recv_offset", Local_NTask * sizeof(int));

      for(int i = 0; i < Local_NTask; i++)
        send_count[i] = 0;

      int target = 0;

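      /* walk through the locally sorted data and count how many elements fall into the
       * bucket of each target task; since both the local data and the splitters are
       * ordered, the target index only ever moves forward */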
      for(size_t i = 0; i < nmemb; i++)
        {
          while(target < Local_NTask - 1)
            {
              int cmp = comp(begin[i], element_guess[target]) ? -1 : (comp(element_guess[target], begin[i]) ? +1 : 0);
              if(cmp == 0)
                {
                  if(i + noffs[Local_ThisTask] < element_tie_breaking_rank[target])
                    cmp = -1;
                  else if(i + noffs[Local_ThisTask] > element_tie_breaking_rank[target])
                    cmp = +1;
                }
              if(cmp >= 0)
                target++;
              else
                break;
            }
          send_count[target]++;
        }

      MPI_Alltoall(send_count, 1, MPI_INT, recv_count, 1, MPI_INT, MPI_CommLocal);

      size_t nimport = 0;

      recv_offset[0] = 0;
      send_offset[0] = 0;

      for(int j = 0; j < Local_NTask; j++)
        {
          nimport += recv_count[j];

          if(j > 0)
            {
              send_offset[j] = send_offset[j - 1] + send_count[j - 1];
              recv_offset[j] = recv_offset[j - 1] + recv_count[j - 1];
            }
        }

      if(nimport != nmemb)
        Terminate("nimport=%lld != nmemb=%lld", (long long)nimport, (long long)nmemb);

      for(int j = 0; j < Local_NTask; j++)
        {
          send_count[j] *= size;
          recv_count[j] *= size;

          send_offset[j] *= size;
          recv_offset[j] *= size;
        }

      T *basetmp = (T *)Mem.mymalloc("basetmp", nmemb * size);

      /* exchange the data */
      MPI_Alltoallv(begin, send_count, send_offset, MPI_BYTE, basetmp, recv_count, recv_offset, MPI_BYTE, MPI_CommLocal);

      memcpy(static_cast<void *>(begin), static_cast<void *>(basetmp), nmemb * size);
      Mem.myfree(basetmp);

      mycxxsort(begin, begin + nmemb, comp);

      Mem.myfree(recv_offset);
      Mem.myfree(send_offset);
      Mem.myfree(recv_count);
      Mem.myfree(send_count);

      Mem.myfree(range_len_list);
      Mem.myfree(list);
      Mem.myfree(max_loc);
      Mem.myfree(range_right);
      Mem.myfree(range_left);
      Mem.myfree(current_loc_rank);
      Mem.myfree(current_glob_rank);
      Mem.myfree(desired_glob_rank);
      Mem.myfree(element_tie_breaking_rank);
      Mem.myfree(element_guess);
      Mem.myfree(noffs);
      Mem.myfree(nlist);
    }

  MPI_Comm_free(&MPI_CommLocal);

  double tb = Logs.second();
  return Logs.timediff(ta, tb);
}

#endif
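
A minimal usage sketch (not part of the header above): it assumes an initialized MPI environment and GADGET-4's global Mem and Logs objects; the element type key_t, the comparator key_less(), the include path and the function sort_my_keys() are illustrative only.

#include <cstddef>

#include <mpi.h>

#include "sort/parallel_sort.h" /* assumed location of this header in the source tree */

struct key_t
{
  double k; /* hypothetical sort key */
};

inline bool key_less(const key_t &a, const key_t &b) { return a.k < b.k; }

void sort_my_keys(key_t *data, size_t n_local, MPI_Comm comm)
{
  /* sorts the globally distributed data in place; every task keeps exactly n_local elements */
  double elapsed = mycxxsort_parallel(data, data + n_local, key_less, comm);

  (void)elapsed; /* wall-clock time in seconds spent in the sort */
}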