12 #ifndef PARALLEL_SORT_H
13 #define PARALLEL_SORT_H
17 #include "../data/mymalloc.h"
// Index comparator adaptor: orders two indices a and b by applying `comp`
// to the elements at begin + a and begin + b.
// NOTE(review): the struct header and member declarations are elided in this
// excerpt; the member names begin/comp come from the constructor's
// initializer list.
21 template <
typename It,
typename Comp>
29 IdxComp__(It begin_, Comp comp_) : begin(begin_), comp(comp_) {}
// Returns true when element a sorts before element b under comp.
30 bool operator()(std::size_t a, std::size_t b)
const {
return comp(*(begin + a), *(begin + b)); }
// Fills idx with indices into [begin, end) ordered according to comp; the
// referenced elements themselves are not moved.
// NOTE(review): only the loop header is visible here. `num` and the loop
// body are elided; presumably num = end - begin and idx[i] = i, followed by
// sorting idx with an IdxComp__ adaptor. TODO confirm.
36 template <
typename It,
typename T2,
typename Comp>
37 inline void buildIndex(It begin, It end, T2 *idx, Comp comp)
41 for(T2 i = 0; i < num; ++i)
// Binary-search helper for the parallel split-point search.
// Stores in *loc the position within this task's local array base[0..nmemb)
// at which `element` falls under `comp`, restricted to the search window
// described by left/right (right - 1 is the last index examined).
// Elements that compare equal to `element` are ordered by their global
// position (local index + noffs_thistask) relative to tie_breaking_rank.
// This gives equal elements a strict order.
// NOTE(review): assumes base is sorted according to comp, since the search
// bisects it. Presumably the caller sorts it locally on an elided line.
46 template <
typename T,
typename Comp>
47 void get_local_rank(
const T &element, std::size_t tie_breaking_rank,
const T *base,
size_t nmemb,
size_t noffs_thistask,
48 long long left,
long long right,
size_t *loc, Comp comp)
53 if(left == 0 && right == (
int)nmemb + 1)
// Full-range case (branch bodies elided): element sorts after the last
// local entry...
55 if(comp(base[nmemb - 1], element))
// ...or before the first local entry.
60 else if(comp(element, base[0]))
// Windowed case: element sorts after the window's last entry, so its local
// rank is just past that entry.
73 if(comp(base[right - 1], element))
74 *loc = (right - 1) + 1;
// Element sorts before the window's first entry (body elided).
75 else if(comp(element, base[left]))
// Bisection midpoint of the inclusive index range [left, right - 1].
81 long long mid = ((right - 1) + left) / 2;
// Three-way comparison via comp: -1 if base[mid] sorts before element,
// +1 if after, 0 if neither.
83 int cmp = comp(base[mid], element) ? -1 : (comp(element, base[mid]) ? +1 : 0);
// Equal elements (presumably inside an elided `if(cmp == 0)`) are ordered by
// global position: mid + noffs_thistask versus tie_breaking_rank.
86 if(mid + noffs_thistask < tie_breaking_rank)
88 else if(mid + noffs_thistask > tie_breaking_rank)
// Termination cases once the window has shrunk to one or two entries
// (bodies elided).
98 if((right - 1) == left)
113 if((right - 1) == left + 1)
// The message marks this path as unreachable; abort and report the window bounds.
116 Terminate(
"Can't be: -->left=%lld right=%lld\n", left, right);
// Debug-only (CHECK_LOCAL_RANK) cross-check of get_local_rank.
// It scans base[0..nmemb) linearly, applying the same comp comparison and
// global-index tie-break, and builds a count. It aborts via Terminate if
// that count differs from loc.
// NOTE(review): the base/nmemb parameters are on elided lines (133-134).
// The caller passes begin and nmemb. The exact counting rule is in elided
// lines; presumably it counts elements that precede `element`.
129 #ifdef CHECK_LOCAL_RANK
130 template <
typename T,
typename Comp>
131 inline void check_local_rank(
const T &element,
132 size_t tie_breaking_rank,
135 size_t noffs_thistask,
136 long long left,
long long right,
137 size_t loc, Comp comp)
141 for(
size_t i = 0; i < nmemb; i++)
// Same three-way comparison as get_local_rank.
143 int cmp = comp(base[i], element) ? -1 : (comp(element, base[i]) ? +1 : 0);
// Same tie-break: global index of base[i] versus tie_breaking_rank.
147 if(noffs_thistask + i < tie_breaking_rank)
// A mismatch between brute-force count and binary-search result is fatal.
155 if(count != (
long long)loc)
156 Terminate(
"Inconsistency: loc=%lld count=%lld left=%lld right=%lld nmemb=%lld\n", (
long long)loc, count, left, right,
// Parallel sort of the distributed array [begin, end) across the tasks of
// `comm`. The function's signature line is elided in this excerpt.
// Visible outline:
//  1. split off a communicator of participating tasks (MPI_CommLocal);
//  2. iteratively search for Local_NTask - 1 global split points. Each is an
//     (element, tie_breaking_rank) pair whose target global rank is the first
//     global index owned by the next task;
//  3. redistribute elements with MPI_Alltoallv so that every task again holds
//     nmemb elements.
// NOTE(review): T is sent as raw MPI_BYTE data and copied back with memcpy,
// so T must be trivially copyable.
161 template <
typename T,
typename Comp>
// Iteration cap for the split-point search; exceeding it aborts.
164 const int MAX_ITER_PARALLEL_SORT = 500;
165 int ranks_not_found, Local_ThisTask, Local_NTask, Color, new_max_loc;
166 size_t tie_breaking_rank, new_tie_breaking_rank, rank;
167 MPI_Comm MPI_CommLocal;
// Local element count and element size in bytes. All element traffic uses MPI_BYTE.
170 size_t nmemb = end - begin;
171 size_t size =
sizeof(T);
// Group tasks with equal Color into MPI_CommLocal, preserving their order in
// `comm`. How Color is chosen is on elided lines; presumably it depends on
// whether the task holds any data. TODO confirm.
184 MPI_Comm_rank(comm, &thistask);
186 MPI_Comm_split(comm, Color, thistask, &MPI_CommLocal);
187 MPI_Comm_rank(MPI_CommLocal, &Local_ThisTask);
188 MPI_Comm_size(MPI_CommLocal, &Local_NTask);
// The distributed algorithm runs only for Color == 1 groups with more than one task.
190 if(Local_NTask > 1 && Color == 1)
// nlist[t]: element count on local task t. noffs[t]: running offset of task t's elements.
192 size_t *nlist = (
size_t *)
Mem.mymalloc(
"nlist", Local_NTask *
sizeof(
size_t));
193 size_t *noffs = (
size_t *)
Mem.mymalloc(
"noffs", Local_NTask *
sizeof(
size_t));
// Every task learns every task's element count.
195 MPI_Allgather(&nmemb,
sizeof(
size_t), MPI_BYTE, nlist,
sizeof(
size_t), MPI_BYTE, MPI_CommLocal);
// Exclusive prefix sum. noffs[0] is set on an elided line, presumably to 0.
198 for(
int i = 1; i < Local_NTask; i++)
199 noffs[i] = noffs[i - 1] + nlist[i - 1];
// Working storage for the split-point search. Arrays indexed by split i
// (only i < Local_NTask - 1 is used) hold the per-split state:
//  element_guess / element_tie_breaking_rank: current guess for split i;
//  max_loc: task index stored with that guess (compared with Local_ThisTask
//    during window narrowing);
//  desired_glob_rank / current_glob_rank: target and achieved global rank;
//  current_loc_rank: this task's local rank of guess i;
//  range_left / range_right: this task's remaining search window for split i;
//  list: receives other tasks' local ranks via MPI_Alltoall.
201 T *element_guess = (T *)
Mem.mymalloc(
"element_guess", Local_NTask * size);
202 size_t *element_tie_breaking_rank = (
size_t *)
Mem.mymalloc(
"element_tie_breaking_rank", Local_NTask *
sizeof(
size_t));
203 size_t *desired_glob_rank = (
size_t *)
Mem.mymalloc(
"desired_glob_rank", Local_NTask *
sizeof(
size_t));
204 size_t *current_glob_rank = (
size_t *)
Mem.mymalloc(
"current_glob_rank", Local_NTask *
sizeof(
size_t));
205 size_t *current_loc_rank = (
size_t *)
Mem.mymalloc(
"current_loc_rank", Local_NTask *
sizeof(
size_t));
206 long long *range_left = (
long long *)
Mem.mymalloc(
"range_left", Local_NTask *
sizeof(
long long));
207 long long *range_right = (
long long *)
Mem.mymalloc(
"range_right", Local_NTask *
sizeof(
long long));
208 int *max_loc = (
int *)
Mem.mymalloc(
"max_loc", Local_NTask *
sizeof(
int));
210 size_t *list = (
size_t *)
Mem.mymalloc(
"list", Local_NTask *
sizeof(
size_t));
// Per-task scratch for the gathered/exchanged candidates (window length,
// median element, tie-break rank, sort index, source task).
// NOTE(review): range_len_list is size_t* but sized and transferred as
// long long. This relies on sizeof(size_t) == sizeof(long long).
211 size_t *range_len_list = (
size_t *)
Mem.mymalloc(
"range_len_list", Local_NTask *
sizeof(
long long));
213 T *median_element_list = (T *)
Mem.mymalloc(
"median_element_list", Local_NTask * size);
214 size_t *tie_breaking_rank_list = (
size_t *)
Mem.mymalloc(
"tie_breaking_rank_list", Local_NTask *
sizeof(
size_t));
215 int *index_list = (
int *)
Mem.mymalloc(
"index_list", Local_NTask *
sizeof(
int));
216 int *max_loc_list = (
int *)
Mem.mymalloc(
"max_loc_list", Local_NTask *
sizeof(
int));
// Send-side buffers for the per-iteration MPI_Alltoall of new candidates.
// NOTE(review): these two are size_t* sized with sizeof(long long); same
// size assumption as above.
217 size_t *source_range_len_list = (
size_t *)
Mem.mymalloc(
"source_range_len_list", Local_NTask *
sizeof(
long long));
218 size_t *source_tie_breaking_rank_list = (
size_t *)
Mem.mymalloc(
"source_tie_breaking_rank_list", Local_NTask *
sizeof(
long long));
219 T *source_median_element_list = (T *)
Mem.mymalloc(
"source_median_element_list", Local_NTask * size);
// Initialise each split point i. Its target global rank is the first global
// index of task i + 1, and this task's search window starts at the whole
// local array. The range_left initialisation is elided, presumably to 0.
222 for(
int i = 0; i < Local_NTask - 1; i++)
224 desired_glob_rank[i] = noffs[i + 1];
225 current_glob_rank[i] = 0;
227 range_right[i] = nmemb;
// Initial proposal: each task offers the middle element of its split-0
// window, tagged with that element's global index for tie-breaking.
235 long long range_len = range_right[0] - range_left[0];
239 long long mid = (range_left[0] + range_right[0]) / 2;
240 median_element = begin[mid];
241 tie_breaking_rank = mid + noffs[Local_ThisTask];
// Collect all proposals on local task 0.
244 MPI_Gather(&range_len,
sizeof(
long long), MPI_BYTE, range_len_list,
sizeof(
long long), MPI_BYTE, 0, MPI_CommLocal);
245 MPI_Gather(&median_element, size, MPI_BYTE, median_element_list, size, MPI_BYTE, 0, MPI_CommLocal);
246 MPI_Gather(&tie_breaking_rank,
sizeof(
size_t), MPI_BYTE, tie_breaking_rank_list,
sizeof(
size_t), MPI_BYTE, 0, MPI_CommLocal);
248 if(Local_ThisTask == 0)
// Loop body elided; presumably records the proposing task in max_loc_list[j].
250 for(
int j = 0; j < Local_NTask; j++)
// Remove proposals from tasks with an empty window by overwriting them with
// the last entry. The nleft decrement is on an elided line.
254 int nleft = Local_NTask;
256 for(
int j = 0; j < nleft; j++)
258 if(range_len_list[j] < 1)
260 range_len_list[j] = range_len_list[nleft - 1];
261 if(range_len_list[nleft - 1] >= 1 && j != (nleft - 1))
263 median_element_list[j] = median_element_list[nleft - 1];
264 tie_breaking_rank_list[j] = tie_breaking_rank_list[nleft - 1];
265 max_loc_list[j] = max_loc_list[nleft - 1];
// Order the remaining proposals by comp and take the entry at index_list[mid].
// `mid` for this scope is set on an elided line, presumably to the middle
// position. TODO confirm.
274 buildIndex(median_element_list, median_element_list + nleft, index_list, comp);
278 element_guess[0] = median_element_list[index_list[mid]];
279 element_tie_breaking_rank[0] = tie_breaking_rank_list[index_list[mid]];
280 max_loc[0] = max_loc_list[index_list[mid]];
// Share the chosen guess; every split point starts from the same guess.
283 MPI_Bcast(element_guess, size, MPI_BYTE, 0, MPI_CommLocal);
284 MPI_Bcast(&element_tie_breaking_rank[0],
sizeof(
size_t), MPI_BYTE, 0, MPI_CommLocal);
285 MPI_Bcast(&max_loc[0], 1, MPI_INT, 0, MPI_CommLocal);
287 for(
int i = 1; i < Local_NTask - 1; i++)
289 element_guess[i] = element_guess[0];
290 element_tie_breaking_rank[i] = element_tie_breaking_rank[0];
291 max_loc[i] = max_loc[0];
// Split-point search iteration. The opening of the loop is elided; it is
// closed by `while(ranks_not_found)`. First, for every split point that is
// not yet at its desired global rank, locate its guess in this task's window.
298 for(
int i = 0; i < Local_NTask - 1; i++)
300 if(current_glob_rank[i] != desired_glob_rank[i])
302 get_local_rank(element_guess[i], element_tie_breaking_rank[i], begin, nmemb, noffs[Local_ThisTask], range_left[i],
303 range_right[i], ¤t_loc_rank[i], comp);
// Debug builds cross-check the binary search with a linear scan.
305 #ifdef CHECK_LOCAL_RANK
306 check_local_rank(element_guess[i], element_tie_breaking_rank[i], begin, nmemb, noffs[Local_ThisTask], range_left[i],
307 range_right[i], current_loc_rank[i], comp);
// Transpose the local ranks: afterwards list[k] holds task k's local rank of
// guess number Local_ThisTask.
315 MPI_Alltoall(current_loc_rank,
sizeof(
size_t), MPI_BYTE, list,
sizeof(
size_t), MPI_BYTE, MPI_CommLocal);
// Loop body elided; presumably sums list[] into `rank`, the global rank of
// guess Local_ThisTask. That value is then shared so every task knows
// current_glob_rank[i] for all splits.
317 for(
int j = 0; j < Local_NTask; j++)
319 MPI_Allgather(&rank,
sizeof(
size_t), MPI_BYTE, current_glob_rank,
sizeof(
size_t), MPI_BYTE, MPI_CommLocal);
// Narrow each unresolved split's window. A guess that ranks too low moves the
// left edge up; a guess that ranks too high moves the right edge down.
322 for(
int i = 0; i < Local_NTask - 1; i++)
324 if(current_glob_rank[i] != desired_glob_rank[i])
328 if(current_glob_rank[i] < desired_glob_rank[i])
330 range_left[i] = current_loc_rank[i];
// Extra step on the task stored with the guess. The body is elided;
// presumably it moves past the guess element itself. TODO confirm.
332 if(Local_ThisTask == max_loc[i])
336 if(current_glob_rank[i] > desired_glob_rank[i])
337 range_right[i] = current_loc_rank[i];
// For every unresolved split, each task proposes its window length plus the
// middle element of its remaining window, tagged with its global index.
// Values for resolved splits and empty windows are set on elided lines.
342 for(
int i = 0; i < Local_NTask - 1; i++)
344 if(current_glob_rank[i] != desired_glob_rank[i])
349 source_range_len_list[i] = range_right[i] - range_left[i];
351 if(source_range_len_list[i] >= 1)
353 long long middle = (range_left[i] + range_right[i]) / 2;
354 source_median_element_list[i] = begin[middle];
355 source_tie_breaking_rank_list[i] = middle + noffs[Local_ThisTask];
// Transpose so that task i receives every task's proposal for split i.
360 MPI_Alltoall(source_range_len_list,
sizeof(
long long), MPI_BYTE, range_len_list,
sizeof(
long long), MPI_BYTE, MPI_CommLocal);
361 MPI_Alltoall(source_median_element_list, size, MPI_BYTE, median_element_list, size, MPI_BYTE, MPI_CommLocal);
362 MPI_Alltoall(source_tie_breaking_rank_list,
sizeof(
size_t), MPI_BYTE, tie_breaking_rank_list,
sizeof(
size_t), MPI_BYTE,
// Task i (for i < Local_NTask - 1) chooses the next guess for split i while
// that split is still unresolved.
365 if(Local_ThisTask < Local_NTask - 1)
367 if(current_glob_rank[Local_ThisTask] !=
368 desired_glob_rank[Local_ThisTask])
// Loop body elided; presumably records the proposing task in max_loc_list[j].
370 for(
int j = 0; j < Local_NTask; j++)
// Remove proposals from tasks whose window is empty (overwrite with the last
// entry). The nleft decrement is on an elided line.
374 int nleft = Local_NTask;
376 for(
int j = 0; j < nleft; j++)
378 if(range_len_list[j] < 1)
380 range_len_list[j] = range_len_list[nleft - 1];
381 if(range_len_list[nleft - 1] >= 1 && j != (nleft - 1))
383 median_element_list[j] = median_element_list[nleft - 1];
384 tie_breaking_rank_list[j] = tie_breaking_rank_list[nleft - 1];
385 max_loc_list[j] = max_loc_list[nleft - 1];
// Strategy A: take the proposal from the task with the largest remaining
// window. The maxj update and the condition that picks strategy A or B are
// on elided lines.
395 size_t max_range = 0, maxj = 0;
397 for(
int j = 0; j < nleft; j++)
398 if(range_len_list[j] > max_range)
400 max_range = range_len_list[j];
405 new_element_guess = median_element_list[maxj];
406 new_tie_breaking_rank = tie_breaking_rank_list[maxj];
407 new_max_loc = max_loc_list[maxj];
// Strategy B: median of the remaining proposals under comp.
413 buildIndex(median_element_list, median_element_list + nleft, index_list, comp);
417 new_element_guess = median_element_list[index_list[mid]];
418 new_tie_breaking_rank = tie_breaking_rank_list[index_list[mid]];
419 new_max_loc = max_loc_list[index_list[mid]];
// Otherwise keep the current guess. Presumably this covers a resolved split
// or the last task; the else-mapping is elided.
425 new_element_guess = element_guess[Local_ThisTask];
426 new_tie_breaking_rank = element_tie_breaking_rank[Local_ThisTask];
427 new_max_loc = max_loc[Local_ThisTask];
// Publish every task's new guess so element_guess[i], element_tie_breaking_rank[i]
// and max_loc[i] agree on all tasks.
431 MPI_Allgather(&new_element_guess, size, MPI_BYTE, element_guess, size, MPI_BYTE, MPI_CommLocal);
432 MPI_Allgather(&new_tie_breaking_rank,
sizeof(
size_t), MPI_BYTE, element_tie_breaking_rank,
sizeof(
size_t), MPI_BYTE,
434 MPI_Allgather(&new_max_loc, 1, MPI_INT, max_loc, 1, MPI_INT, MPI_CommLocal);
// Print diagnostics during the last 100 iterations before the cap, then abort
// if the cap is exceeded. iter and ranks_not_found are updated on elided lines.
438 if(iter > (MAX_ITER_PARALLEL_SORT - 100) && Local_ThisTask == 0)
440 printf(
"PSORT: iter=%d: ranks_not_found=%d Local_NTask=%d\n", iter, ranks_not_found, Local_NTask);
442 if(iter > MAX_ITER_PARALLEL_SORT)
443 Terminate(
"can't find the split points. That's odd");
// Repeat until every split point has reached its desired global rank.
446 while(ranks_not_found);
// Release the candidate-exchange scratch lists.
448 Mem.myfree(source_median_element_list);
449 Mem.myfree(source_tie_breaking_rank_list);
450 Mem.myfree(source_range_len_list);
451 Mem.myfree(max_loc_list);
452 Mem.myfree(index_list);
453 Mem.myfree(tie_breaking_rank_list);
454 Mem.myfree(median_element_list);
459 if(nmemb * size > (1LL << 31))
460 Terminate(
"currently, local data must be smaller than 2 GB");
// Per-destination element counts and displacements (int, as MPI requires).
465 int *send_count = (
int *)
Mem.mymalloc(
"send_count", Local_NTask *
sizeof(
int));
466 int *recv_count = (
int *)
Mem.mymalloc(
"recv_count", Local_NTask *
sizeof(
int));
467 int *send_offset = (
int *)
Mem.mymalloc(
"send_offset", Local_NTask *
sizeof(
int));
468 int *recv_offset = (
int *)
Mem.mymalloc(
"recv_offset", Local_NTask *
sizeof(
int));
// Loop body elided; presumably zeroes send_count.
470 for(
int i = 0; i < Local_NTask; i++)
// Scan local elements in order and count how many go to each destination
// task `target`. Each element is compared with split point `target` using the
// same comp + global-index tie-break as get_local_rank. The logic that
// advances target is on elided lines.
475 for(
size_t i = 0; i < nmemb; i++)
477 while(target < Local_NTask - 1)
479 int cmp = comp(begin[i], element_guess[target]) ? -1 : (comp(element_guess[target], begin[i]) ? +1 : 0);
482 if(i + noffs[Local_ThisTask] < element_tie_breaking_rank[target])
484 else if(i + noffs[Local_ThisTask] > element_tie_breaking_rank[target])
492 send_count[target]++;
// Exchange counts, then build exclusive prefix-sum displacements.
495 MPI_Alltoall(send_count, 1, MPI_INT, recv_count, 1, MPI_INT, MPI_CommLocal);
501 for(
int j = 0; j < Local_NTask; j++)
503 nimport += recv_count[j];
507 send_offset[j] = send_offset[j - 1] + send_count[j - 1];
508 recv_offset[j] = recv_offset[j - 1] + recv_count[j - 1];
// Each task must receive exactly as many elements as it holds, because the
// received data is copied back over [begin, end).
513 Terminate(
"nimport=%lld != nmemb=%lld", (
long long)nimport, (
long long)nmemb);
// The transfer uses MPI_BYTE, so convert counts and displacements to bytes.
515 for(
int j = 0; j < Local_NTask; j++)
517 send_count[j] *= size;
518 recv_count[j] *= size;
520 send_offset[j] *= size;
521 recv_offset[j] *= size;
// Redistribute: receive into a temporary buffer, then copy it back over the
// caller's array. The nimport == nmemb check guarantees the sizes match.
524 T *basetmp = (T *)
Mem.mymalloc(
"basetmp", nmemb * size);
527 MPI_Alltoallv(begin, send_count, send_offset, MPI_BYTE, basetmp, recv_count, recv_offset, MPI_BYTE, MPI_CommLocal);
529 memcpy(
static_cast<void *
>(begin),
static_cast<void *
>(basetmp), nmemb * size);
// Free the remaining work arrays. Some myfree calls are on elided lines.
534 Mem.myfree(recv_offset);
535 Mem.myfree(send_offset);
536 Mem.myfree(recv_count);
537 Mem.myfree(send_count);
539 Mem.myfree(range_len_list);
542 Mem.myfree(range_right);
543 Mem.myfree(range_left);
544 Mem.myfree(current_loc_rank);
545 Mem.myfree(current_glob_rank);
546 Mem.myfree(desired_glob_rank);
547 Mem.myfree(element_tie_breaking_rank);
548 Mem.myfree(element_guess);
// Release the sub-communicator created by MPI_Comm_split.
553 MPI_Comm_free(&MPI_CommLocal);
// NOTE(review): the lines below look like an appended list of function
// signatures (no template headers, bodies or terminating semicolons) rather
// than compilable declarations. Confirm against the real header.
bool operator()(std::size_t a, std::size_t b) const
IdxComp__(It begin_, Comp comp_)
double timediff(double t0, double t1)
double mycxxsort(T *begin, T *end, Tcomp comp)
void get_local_rank(const T &element, std::size_t tie_breaking_rank, const T *base, size_t nmemb, size_t noffs_thistask, long long left, long long right, size_t *loc, Comp comp)
double mycxxsort_parallel(T *begin, T *end, Comp comp, MPI_Comm comm)
void buildIndex(It begin, It end, T2 *idx, Comp comp)
void myflush(FILE *fstream)