Indicate that lttng-ust 2.13+ is required
[deliverable/lttng-ust-mpi.git] / test-mpi.c
CommitLineData
66e983e2
OD
/*
 * SPDX-License-Identifier: MIT
 *
 * SPDX-FileCopyrightText: 2023 Olivier Dion <odion@efficios.com>
 */
6
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <mpi.h>
13
/*
 * Return the sum of the first VALUES_COUNT elements of VALUES.
 *
 * The accumulator is uint64_t (not size_t) so the result wraps modulo 2^64
 * on every platform instead of being truncated to 32 bits on targets where
 * size_t is 32-bit.  VALUES may be NULL when VALUES_COUNT is 0.
 */
static uint64_t sum_of(const uint64_t *values, size_t values_count)
{
	uint64_t acc = 0;

	for (size_t k = 0; k < values_count; ++k) {
		acc += values[k];
	}
	return acc;
}
22
/*
 * Print the command-line synopsis to stderr and terminate the process
 * with EXIT_FAILURE.  Declared with an explicit (void) parameter list so
 * it has a real prototype (an empty list is an obsolescent, unprototyped
 * declaration before C23).
 */
static void usage(void)
{
	fprintf(stderr, "Usage: test-mpi N\n");
	exit(EXIT_FAILURE);
}
28
/*
 * Allocate an array of UPTO uint64_t holding 1, 2, ..., UPTO.
 *
 * The caller owns the returned buffer and must free() it.  If the byte
 * count would overflow size_t or the allocation fails, a diagnostic is
 * printed and the process exits with EXIT_FAILURE.  For UPTO == 0 the
 * result may be NULL (whatever malloc(0) returns).
 */
static uint64_t *allocate_values(size_t upto)
{
	uint64_t *values = NULL;

	/* Refuse counts whose byte size would wrap around before malloc(). */
	if (upto <= SIZE_MAX / sizeof(*values)) {
		values = malloc(upto * sizeof(*values));
	}
	if (values == NULL && upto > 0) {
		fprintf(stderr, "Failed to allocate %zu values\n", upto);
		exit(EXIT_FAILURE);
	}
	for (size_t k = 0; k < upto; ++k) {
		values[k] = (uint64_t)k + 1;
	}
	return values;
}
37
38static void send_values(int target, uint64_t *values,
39 size_t values_count,
40 MPI_Request *request)
41{
42 MPI_Isend(values, values_count, MPI_UINT64_T,
43 target, 0, MPI_COMM_WORLD, request);
44}
45
46static void recv_answer(int target, uint64_t *value,
47 MPI_Request *request)
48{
49 MPI_Irecv(value, 1, MPI_UINT64_T,
50 target, 0, MPI_COMM_WORLD, request);
51}
52
53static void send_answer(uint64_t value)
54{
55 MPI_Send(&value, 1, MPI_UINT64_T,
56 0, 0, MPI_COMM_WORLD);
57}
58
59static uint64_t *recv_values(size_t chunk_size)
60{
61 uint64_t *values = (uint64_t*)malloc(sizeof(uint64_t) * chunk_size);
62 MPI_Recv(values, chunk_size, MPI_UINT64_T,
63 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
64 return values;
65}
66
67int main(int argc, char *argv[])
68{
69 int rank;
70 int size;
71 long long upto;
72 uint64_t *values;
73
74 if (argc < 2) {
75 usage();
76 }
77
78 upto = atoll(argv[1]);
79
80 if (upto <= 0) {
81 fprintf(stderr, "N must be greater than 0\n");
82 exit(EXIT_FAILURE);
83 }
84
85 MPI_Init(&argc, &argv);
86
87 MPI_Comm_set_errhandler(MPI_COMM_WORLD,
88 MPI_ERRORS_RETURN);
89
90 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
91 MPI_Comm_size(MPI_COMM_WORLD, &size);
92
93 size_t chunk_size;
94 size_t rest;
95 uint64_t total;
96
97 if (size > 1) {
98 chunk_size = upto / (size - 1);
99 rest = upto % (size - 1);
100 } else {
101 chunk_size = 0;
102 rest = upto;
103 }
104
105 if (rank == 0) {
106 uint64_t sums[size];
107 MPI_Request requests[size - 1];
108
109 values = allocate_values(upto);
110
111 for (int k=1; k<size; ++k) {
112 send_values(k,
113 values + (chunk_size * (k - 1)),
114 chunk_size,
115 &requests[k-1]);
116 }
117
118 sums[0] = sum_of(values + chunk_size * (size - 1),
119 rest);
120
121 MPI_Waitall(size - 1, requests, MPI_STATUS_IGNORE);
122
123 for (int k=1; k<size; ++k) {
124 recv_answer(k, &sums[k], &requests[k-1]);
125 }
126
127 MPI_Waitall(size - 1, requests, MPI_STATUS_IGNORE);
128
129 total = sum_of(sums, size);
130 } else {
131 send_answer(sum_of(recv_values(chunk_size),
132 chunk_size));
133 }
134
135 MPI_Finalize();
136
137 if (rank == 0){
138 assert(total ==
139 (((uint64_t)upto * ((uint64_t)upto + 1U)) >> 1U));
140 }
141
142 return 0;
143}
This page took 0.027343 seconds and 4 git commands to generate.