12 #include "gadgetconfig.h"
20 #include "../data/allvars.h"
21 #include "../data/dtypes.h"
22 #include "../data/mymalloc.h"
23 #include "../mpi_utils/mpi_utils.h"
// Reinterpret a buffer pointer (the callers pass void * buffers) as char * so
// that byte offsets can be added to it. The argument is parenthesized so that
// an expression argument such as PCHAR(p + 1) is cast as a whole, instead of
// expanding to (char *)p + 1.
#define PCHAR(a) ((char *)(a))
// Appears to be the body of myMPI_Alltoallv_new_prep: its names sendcnt,
// recvcnt, comm and method match that signature, which is listed at the end of
// this view. It makes each rank's recvcnt[] hold the counts the other ranks
// send to it. The function header, braces, the else-branches and some loop
// bodies are on lines not visible here.
39 MPI_Comm_size(comm, &nranks);
40 MPI_Comm_rank(comm, &rank);
// Methods 0 and 1: exchange the counts with a plain collective.
42 if(method == 0 || method == 1)
43 MPI_Alltoall(sendcnt, 1, MPI_INT, recvcnt, 1, MPI_INT, comm);
// One-sided variant. The condition selecting it is not visible, and neither
// is the body of this first loop.
46 for(
int i = 0; i < nranks; ++i)
// The rank's own count is copied locally; no RMA is needed for it.
48 recvcnt[rank] = sendcnt[rank];
// Expose recvcnt[] as an RMA window so peers can MPI_Put their counts into it.
// NOTE(review): sizeof(MPI_INT) is the size of the MPI_Datatype *handle*, not
// of an int. It only works where MPI_Datatype happens to be an int (e.g.
// MPICH). With pointer-typed handles (e.g. Open MPI) it is 8, which causes two
// problems:
//  - the window claims nranks*8 bytes of an nranks-int buffer;
//  - with disp_unit 8, the MPI_Put below (target displacement `rank`) writes
//    into recvcnt[2*rank], which is the wrong slot and out of bounds for
//    rank >= nranks/2.
// sizeof(int) is presumably intended; confirm.
50 MPI_Win_create(recvcnt, nranks *
sizeof(MPI_INT),
sizeof(MPI_INT), MPI_INFO_NULL, comm, &win);
51 MPI_Win_fence(0, win);
// Visit the other ranks starting at rank + 1 and wrapping around (presumably
// so that not every rank targets the same peer first). Our count for `tgt` is
// written into slot `rank` of tgt's recvcnt window.
52 for(
int i = 1; i < nranks; ++i)
54 int tgt = (rank + i) % nranks;
56 MPI_Put(&sendcnt[tgt], 1, MPI_INT, tgt, rank, 1, MPI_INT, win);
// The closing fence completes all Puts; recvcnt[] is complete after it.
// NOTE(review): no MPI_Win_free(&win) is visible in this view -- confirm one exists.
58 MPI_Win_fence(0, win);
// The loop body is not visible. rdispls is a parameter of this function, so
// this loop presumably fills it from recvcnt[] -- confirm.
65 for(
int i = 0; i < nranks; ++i)
// myMPI_Alltoallv_new: all-to-all-v exchange with a selectable algorithm
// (`method`). Four strategies are visible; the method numbers of the first
// three are on lines not visible in this view:
//   - a direct MPI_Alltoallv call,
//   - a pairwise exchange with blocking MPI_Sendrecv,
//   - a pairwise exchange with MPI_Irecv/MPI_Issend and MPI_Waitall,
//   - method == 10: one-sided MPI_Get from an RMA window over sendbuf.
// Counts and displacements are in elements of sendtype/recvtype. Byte offsets
// are formed as tsz * displacement.
// The following are declared on lines not visible here: tsz, lptask, tag,
// status, n_requests and win. tsz is presumably set from itsz.
74 void myMPI_Alltoallv_new(
void *sendbuf,
int *sendcnt,
int *sdispls, MPI_Datatype sendtype,
void *recvbuf,
int *recvcnt,
int *rdispls,
75 MPI_Datatype recvtype, MPI_Comm comm,
int method)
77 int rank, nranks, itsz;
78 MPI_Comm_size(comm, &nranks);
79 MPI_Comm_rank(comm, &rank);
// Size in bytes of one sendtype element.
80 MPI_Type_size(sendtype, &itsz);
// Collective strategy: hand the exchange straight to MPI_Alltoallv.
84 MPI_Alltoallv(sendbuf, sendcnt, sdispls, sendtype, recvbuf, recvcnt, rdispls, recvtype, comm);
// Blocking pairwise strategy. The statement this type check guards is not
// visible here.
87 if(sendtype != recvtype)
// Loop body not visible. It presumably grows lptask to the smallest power of
// two >= nranks, which the XOR pairing below relies on -- confirm.
90 while(lptask < nranks)
// The local block is copied with memcpy instead of going through MPI.
96 memcpy(
PCHAR(recvbuf) + tsz * rdispls[rank],
PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
// In round ngrp each rank pairs with rank ^ ngrp. Since (rank ^ ngrp) ^ ngrp
// == rank, both partners select each other in the same round.
// NOTE(review): rank ^ ngrp can be >= nranks when nranks is not a power of
// two. No bounds guard is visible before sendcnt[otask]/recvcnt[otask] are
// read -- confirm one exists on the missing lines.
98 for(
int ngrp = 1; ngrp < lptask; ngrp++)
100 int otask = rank ^ ngrp;
// Skip the pair when nothing flows in either direction. The partner skips the
// same round, provided recvcnt[] mirrors the partners' sendcnt[] (e.g. as
// filled by myMPI_Alltoallv_new_prep).
102 if(sendcnt[otask] > 0 || recvcnt[otask] > 0)
103 MPI_Sendrecv(
PCHAR(sendbuf) + tsz * sdispls[otask], sendcnt[otask], sendtype, otask, tag,
104 PCHAR(recvbuf) + tsz * rdispls[otask], recvcnt[otask], recvtype, otask, tag, comm, &status);
// Nonblocking pairwise strategy; send and receive types must match.
109 if(sendtype != recvtype)
110 Terminate(
"bad MPI communication types");
// Loop body not visible (see the lptask note above).
112 while(lptask < nranks)
// Sized for at most one receive and one send request per peer.
116 MPI_Request *requests = (MPI_Request *)
Mem.mymalloc(
"requests", 2 * nranks *
sizeof(MPI_Request));
// The local block is copied directly; the copy is skipped when it is empty.
119 if(recvcnt[rank] > 0)
120 memcpy(
PCHAR(recvbuf) + tsz * rdispls[rank],
PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
// Post every receive before any send. The rank ^ ngrp bounds caveat from the
// blocking strategy applies to both loops here.
122 for(
int ngrp = 1; ngrp < lptask; ngrp++)
124 int otask = rank ^ ngrp;
126 if(recvcnt[otask] > 0)
127 MPI_Irecv(
PCHAR(recvbuf) + tsz * rdispls[otask], recvcnt[otask], recvtype, otask, tag, comm, &requests[n_requests++]);
// Synchronous-mode nonblocking sends: each one completes only after the
// matching receive has started.
130 for(
int ngrp = 1; ngrp < lptask; ngrp++)
132 int otask = rank ^ ngrp;
134 if(sendcnt[otask] > 0)
135 MPI_Issend(
PCHAR(sendbuf) + tsz * sdispls[otask], sendcnt[otask], sendtype, otask, tag, comm, &requests[n_requests++]);
// Wait for all posted requests, then release the request array.
// (n_requests is initialized on a line not visible here.)
138 MPI_Waitall(n_requests, requests, MPI_STATUSES_IGNORE);
139 Mem.myfree(requests);
// One-sided strategy: each rank pulls its blocks out of the senders' buffers.
141 else if(method == 10)
143 if(sendtype != recvtype)
144 Terminate(
"bad MPI communication types");
// disp_at_sender[t] holds the displacement (in elements) of this rank's block
// inside rank t's sendbuf, i.e. t's sdispls[rank]. Entries for peers with
// recvcnt[t] == 0 are never fetched and are not read below.
145 int *disp_at_sender = (
int *)
Mem.mymalloc(
"disp_at_sender", nranks *
sizeof(
int));
146 disp_at_sender[rank] = sdispls[rank];
// Expose sdispls[] so that peers can read the entry for their rank.
// NOTE(review): sizeof(MPI_INT) is the size of the MPI_Datatype handle, not of
// an int. With pointer-typed handles (e.g. Open MPI) it is 8, which causes two
// problems:
//  - disp_unit is 8 and the window claims nranks*8 bytes of an nranks-int
//    array;
//  - the MPI_Get at target displacement `rank` therefore reads sdispls[2*rank],
//    the wrong entry, and past the end for rank >= nranks/2.
// sizeof(int) is presumably intended; confirm.
148 MPI_Win_create(sdispls, nranks *
sizeof(MPI_INT),
sizeof(MPI_INT), MPI_INFO_NULL, comm, &win);
149 MPI_Win_fence(0, win);
// Peers are visited starting at rank + 1 and wrapping around. Displacements
// are only fetched from peers that send us something.
150 for(
int i = 1; i < nranks; ++i)
152 int tgt = (rank + i) % nranks;
153 if(recvcnt[tgt] != 0)
154 MPI_Get(&disp_at_sender[tgt], 1, MPI_INT, tgt, rank, 1, MPI_INT, win);
// The fence completes the Gets; disp_at_sender[] is valid from here on.
156 MPI_Win_fence(0, win);
// The local block is copied directly.
158 if(recvcnt[rank] > 0)
159 memcpy(
PCHAR(recvbuf) + tsz * rdispls[rank],
PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
// Window over sendbuf with disp_unit tsz, so target displacements are in
// elements.
// NOTE(review): `win` is overwritten here, and no MPI_Win_free of the sdispls
// window is visible -- confirm it is freed on a missing line.
// NOTE(review): the window length assumes rank nranks-1's block is the last
// one in sendbuf (non-decreasing sdispls). Other layouts would expose too
// little of the buffer.
160 MPI_Win_create(sendbuf, (sdispls[nranks - 1] + sendcnt[nranks - 1]) * tsz, tsz, MPI_INFO_NULL, comm, &win);
161 MPI_Win_fence(0, win);
// Pull each non-empty block straight from the sender's buffer, at the
// displacement fetched above. The remaining MPI_Get arguments are on lines not
// visible here.
162 for(
int i = 1; i < nranks; ++i)
164 int tgt = (rank + i) % nranks;
165 if(recvcnt[tgt] != 0)
166 MPI_Get(
PCHAR(recvbuf) + tsz * rdispls[tgt], recvcnt[tgt], sendtype, tgt, disp_at_sender[tgt], recvcnt[tgt], sendtype,
169 MPI_Win_fence(0, win);
171 Mem.myfree(disp_at_sender);
// myMPI_Alltoallv: all-to-all-v of untyped data. Counts and displacements are
// size_t values in units of `len` bytes per element. Two strategies are
// visible. The condition choosing between them (presumably based on big_flag)
// is on lines not visible here, as is the end of this function.
177 void myMPI_Alltoallv(
void *sendb,
size_t *sendcounts,
size_t *sdispls,
void *recvb,
size_t *recvcounts,
size_t *rdispls,
int len,
178 int big_flag, MPI_Comm comm)
// Byte-addressable views of the buffers, used for offset arithmetic.
180 char *sendbuf = (
char *)sendb;
181 char *recvbuf = (
char *)recvb;
// Strategy 1: scale the counts and offsets to bytes in int arrays, then call
// MPI_Alltoallv with MPI_BYTE.
186 MPI_Comm_size(comm, &ntask);
188 int *scount = (
int *)
Mem.mymalloc(
"scount", ntask *
sizeof(
int));
189 int *rcount = (
int *)
Mem.mymalloc(
"rcount", ntask *
sizeof(
int));
190 int *soff = (
int *)
Mem.mymalloc(
"soff", ntask *
sizeof(
int));
191 int *roff = (
int *)
Mem.mymalloc(
"roff", ntask *
sizeof(
int));
// NOTE(review): each size_t * int product is stored into an int with no range
// check, so any per-rank count or offset above INT_MAX bytes is silently
// truncated. The pairwise strategy below keeps size_t arithmetic and appears
// to be the path for large exchanges -- confirm the selecting condition.
// Release of scount/rcount/soff/roff is not visible in this view.
193 for(
int i = 0; i < ntask; i++)
195 scount[i] = sendcounts[i] * len;
196 rcount[i] = recvcounts[i] * len;
197 soff[i] = sdispls[i] * len;
198 roff[i] = rdispls[i] * len;
201 MPI_Alltoallv(sendbuf, scount, soff, MPI_BYTE, recvbuf, rcount, roff, MPI_BYTE, comm);
// Strategy 2: pairwise exchange. ntask is declared here although strategy 1
// also uses an ntask, so this presumably sits in a separate scope.
213 int ntask, thistask, ptask;
214 MPI_Comm_size(comm, &ntask);
215 MPI_Comm_rank(comm, &thistask);
// Ends with (1 << ptask) >= ntask, assuming the loop body (not visible) is empty.
217 for(ptask = 0; ntask > (1 << ptask); ptask++)
// Round ngrp pairs thistask with thistask ^ ngrp. Round 0 pairs each rank with
// itself, so the local block also goes through the exchange call.
// NOTE(review): thistask ^ ngrp can be >= ntask when ntask is not a power of
// two. No guard is visible before sendcounts[target] is read -- confirm.
220 for(
int ngrp = 0; ngrp < (1 << ptask); ngrp++)
222 int target = thistask ^ ngrp;
// Exchange only when something flows in either direction. The head of the
// exchange call is on a line not visible here. The visible arguments are its
// receive side: recvcounts[target] * len bytes at byte offset
// rdispls[target] * len, tagged TAG_PDATA + ngrp.
226 if(sendcounts[target] > 0 || recvcounts[target] > 0)
228 recvbuf + rdispls[target] * len, recvcounts[target] * len, MPI_BYTE, target,
TAG_PDATA + ngrp, comm,
// my_int_MPI_Alltoallv: same exchange as myMPI_Alltoallv, but with int counts
// and displacements (in units of `len` bytes per element). Two strategies are
// visible. The condition choosing between them (presumably based on big_flag)
// is on lines not visible here, as is the end of this function.
235 void my_int_MPI_Alltoallv(
void *sendb,
int *sendcounts,
int *sdispls,
void *recvb,
int *recvcounts,
int *rdispls,
int len,
236 int big_flag, MPI_Comm comm)
// Byte-addressable views of the buffers, used for offset arithmetic.
238 char *sendbuf = (
char *)sendb;
239 char *recvbuf = (
char *)recvb;
// Strategy 1: scale the counts and offsets to bytes, then call MPI_Alltoallv
// with MPI_BYTE.
244 MPI_Comm_size(comm, &ntask);
246 int *scount = (
int *)
Mem.mymalloc(
"scount", ntask *
sizeof(
int));
247 int *rcount = (
int *)
Mem.mymalloc(
"rcount", ntask *
sizeof(
int));
248 int *soff = (
int *)
Mem.mymalloc(
"soff", ntask *
sizeof(
int));
249 int *roff = (
int *)
Mem.mymalloc(
"roff", ntask *
sizeof(
int));
// NOTE(review): these are int * int products. They overflow (undefined
// behavior for signed int) once a count or offset exceeds INT_MAX bytes, and
// no range check is visible. Release of the four arrays is not visible in this
// view.
251 for(
int i = 0; i < ntask; i++)
253 scount[i] = sendcounts[i] * len;
254 rcount[i] = recvcounts[i] * len;
255 soff[i] = sdispls[i] * len;
256 roff[i] = rdispls[i] * len;
259 MPI_Alltoallv(sendbuf, scount, soff, MPI_BYTE, recvbuf, rcount, roff, MPI_BYTE, comm);
// Strategy 2: pairwise exchange. ntask is declared here although strategy 1
// also uses an ntask, so this presumably sits in a separate scope.
271 int ntask, thistask, ptask;
272 MPI_Comm_size(comm, &ntask);
273 MPI_Comm_rank(comm, &thistask);
// Ends with (1 << ptask) >= ntask, assuming the loop body (not visible) is empty.
275 for(ptask = 0; ntask > (1 << ptask); ptask++)
// Round ngrp pairs thistask with thistask ^ ngrp; round 0 pairs each rank with itself.
// NOTE(review): thistask ^ ngrp can be >= ntask when ntask is not a power of
// two. No guard is visible before sendcounts[target] is read -- confirm.
278 for(
int ngrp = 0; ngrp < (1 << ptask); ngrp++)
280 int target = thistask ^ ngrp;
// Exchange only when something flows in either direction. The head of the
// exchange call is on a line not visible here. The visible arguments are its
// receive side, tagged TAG_PDATA + ngrp. recvcounts[target] * len is again an
// int product (see the overflow note above).
284 if(sendcounts[target] > 0 || recvcounts[target] > 0)
286 recvbuf + rdispls[target] * len, recvcounts[target] * len, MPI_BYTE, target,
TAG_PDATA + ngrp, comm,
// Signature lines for the exchange routines. Bodies are not part of these lines.
// myMPI_Sendrecv: an MPI_Sendrecv-shaped call with size_t counts. Its body is
// not visible in this view.
int myMPI_Sendrecv(void *sendbuf, size_t sendcount, MPI_Datatype sendtype, int dest, int sendtag, void *recvbuf, size_t recvcount, MPI_Datatype recvtype, int source, int recvtag, MPI_Comm comm, MPI_Status *status)
// my_int_MPI_Alltoallv: byte-wise all-to-all-v with int counts/displacements
// in units of len bytes.
void my_int_MPI_Alltoallv(void *sendb, int *sendcounts, int *sdispls, void *recvb, int *recvcounts, int *rdispls, int len, int big_flag, MPI_Comm comm)
// myMPI_Alltoallv: as above, but with size_t counts/displacements.
void myMPI_Alltoallv(void *sendb, size_t *sendcounts, size_t *sdispls, void *recvb, size_t *recvcounts, size_t *rdispls, int len, int big_flag, MPI_Comm comm)
// myMPI_Alltoallv_new_prep: fills recvcnt[] from the peers' sendcnt[]; rdispls
// is also passed in. The meaning of the int return value is not visible here.
int myMPI_Alltoallv_new_prep(int *sendcnt, int *recvcnt, int *rdispls, MPI_Comm comm, int method)
// myMPI_Alltoallv_new: typed all-to-all-v with a selectable `method`.
void myMPI_Alltoallv_new(void *sendbuf, int *sendcnt, int *sdispls, MPI_Datatype sendtype, void *recvbuf, int *recvcnt, int *rdispls, MPI_Datatype recvtype, MPI_Comm comm, int method)