Commit | Line | Data |
---|---|---|
66e983e2 OD |
1 | /* |
2 | * SPDX-License-Identifier: MIT | |
3 | * | |
e8418583 | 4 | * SPDX-FileCopyrightText: 2023 Olivier Dion <odion@efficios.com> |
66e983e2 OD |
5 | */ |
6 | ||
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <mpi.h>
13 | ||
/*
 * Return values[0] + values[1] + ... + values[values_count - 1].
 *
 * The sum is accumulated in uint64_t so that it wraps modulo 2^64 on every
 * platform; a size_t accumulator would silently truncate on 32-bit targets.
 * values may be NULL when values_count is 0.
 */
static uint64_t sum_of(const uint64_t *values, size_t values_count)
{
	uint64_t acc = 0;

	for (size_t k = 0; k < values_count; ++k) {
		acc += values[k];
	}
	return acc;
}
22 | ||
/* Print the command-line synopsis to stderr and exit with a failure status. */
static void usage(void)
{
	fprintf(stderr, "Usage: test-mpi N\n");
	exit(EXIT_FAILURE);
}
28 | ||
/*
 * Allocate an array holding 1, 2, ..., upto.
 *
 * Returns NULL if the byte count would overflow size_t or if malloc()
 * fails (malloc(0) may also return NULL when upto is 0).  The caller owns
 * the returned buffer and must free() it.
 */
static uint64_t *allocate_values(size_t upto)
{
	uint64_t *values;

	/* Guard the size multiplication below against wrap-around. */
	if (upto > SIZE_MAX / sizeof(*values)) {
		return NULL;
	}

	values = malloc(sizeof(*values) * upto);
	if (values == NULL) {
		return NULL;
	}

	for (size_t k = 0; k < upto; ++k) {
		values[k] = (uint64_t)k + 1;
	}
	return values;
}
37 | ||
38 | static void send_values(int target, uint64_t *values, | |
39 | size_t values_count, | |
40 | MPI_Request *request) | |
41 | { | |
42 | MPI_Isend(values, values_count, MPI_UINT64_T, | |
43 | target, 0, MPI_COMM_WORLD, request); | |
44 | } | |
45 | ||
46 | static void recv_answer(int target, uint64_t *value, | |
47 | MPI_Request *request) | |
48 | { | |
49 | MPI_Irecv(value, 1, MPI_UINT64_T, | |
50 | target, 0, MPI_COMM_WORLD, request); | |
51 | } | |
52 | ||
53 | static void send_answer(uint64_t value) | |
54 | { | |
55 | MPI_Send(&value, 1, MPI_UINT64_T, | |
56 | 0, 0, MPI_COMM_WORLD); | |
57 | } | |
58 | ||
59 | static uint64_t *recv_values(size_t chunk_size) | |
60 | { | |
61 | uint64_t *values = (uint64_t*)malloc(sizeof(uint64_t) * chunk_size); | |
62 | MPI_Recv(values, chunk_size, MPI_UINT64_T, | |
63 | 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); | |
64 | return values; | |
65 | } | |
66 | ||
67 | int main(int argc, char *argv[]) | |
68 | { | |
69 | int rank; | |
70 | int size; | |
71 | long long upto; | |
72 | uint64_t *values; | |
73 | ||
74 | if (argc < 2) { | |
75 | usage(); | |
76 | } | |
77 | ||
78 | upto = atoll(argv[1]); | |
79 | ||
80 | if (upto <= 0) { | |
81 | fprintf(stderr, "N must be greater than 0\n"); | |
82 | exit(EXIT_FAILURE); | |
83 | } | |
84 | ||
85 | MPI_Init(&argc, &argv); | |
86 | ||
87 | MPI_Comm_set_errhandler(MPI_COMM_WORLD, | |
88 | MPI_ERRORS_RETURN); | |
89 | ||
90 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
91 | MPI_Comm_size(MPI_COMM_WORLD, &size); | |
92 | ||
93 | size_t chunk_size; | |
94 | size_t rest; | |
95 | uint64_t total; | |
96 | ||
97 | if (size > 1) { | |
98 | chunk_size = upto / (size - 1); | |
99 | rest = upto % (size - 1); | |
100 | } else { | |
101 | chunk_size = 0; | |
102 | rest = upto; | |
103 | } | |
104 | ||
105 | if (rank == 0) { | |
106 | uint64_t sums[size]; | |
107 | MPI_Request requests[size - 1]; | |
108 | ||
109 | values = allocate_values(upto); | |
110 | ||
111 | for (int k=1; k<size; ++k) { | |
112 | send_values(k, | |
113 | values + (chunk_size * (k - 1)), | |
114 | chunk_size, | |
115 | &requests[k-1]); | |
116 | } | |
117 | ||
118 | sums[0] = sum_of(values + chunk_size * (size - 1), | |
119 | rest); | |
120 | ||
121 | MPI_Waitall(size - 1, requests, MPI_STATUS_IGNORE); | |
122 | ||
123 | for (int k=1; k<size; ++k) { | |
124 | recv_answer(k, &sums[k], &requests[k-1]); | |
125 | } | |
126 | ||
127 | MPI_Waitall(size - 1, requests, MPI_STATUS_IGNORE); | |
128 | ||
129 | total = sum_of(sums, size); | |
130 | } else { | |
131 | send_answer(sum_of(recv_values(chunk_size), | |
132 | chunk_size)); | |
133 | } | |
134 | ||
135 | MPI_Finalize(); | |
136 | ||
137 | if (rank == 0){ | |
138 | assert(total == | |
139 | (((uint64_t)upto * ((uint64_t)upto + 1U)) >> 1U)); | |
140 | } | |
141 | ||
142 | return 0; | |
143 | } |