Indicate that lttng-ust 2.13+ is required
[deliverable/lttng-ust-mpi.git] / lttng-auto-mpi-wrappers
CommitLineData
66e983e2
OD
1#!/usr/bin/env python3
2#
3# SPDX-License-Identifier: MIT
4#
e8418583 5# SPDX-FileCopyrightText: 2023 EfficiOS, Inc.
66e983e2
OD
6#
7# Author: Olivier Dion <odion@efficios.com>
8#
9# Auto-generate lttng-ust tracing wrappers for MPI functions (OpenMPI, Cray MPICH).
10#
11# Requires: python-clang (libclang)
12
13import argparse
845aa3e2 14import os
66e983e2
OD
15import re
16
17from string import Template
18
19import clang.cindex
20
def list_function_declarations(root):
    """Return the FUNCTION_DECL cursors that are direct children of ``root``.

    Only the immediate children of ``root`` are inspected; nested cursors are
    not visited. The order of ``root.get_children()`` is preserved.
    """
    function_decl_kind = clang.cindex.CursorKind.FUNCTION_DECL
    declarations = []
    for cursor in root.get_children():
        if cursor.kind == function_decl_kind:
            declarations.append(cursor)
    return declarations
25
def parse_header(header_file):
    """Parse the C header ``header_file`` with libclang and return its root cursor.

    libclang usually still returns a translation unit when the header has
    errors (e.g. an include that cannot be found), in which case the AST may be
    incomplete and fewer wrappers would be generated. Error-level diagnostics
    are therefore reported on stderr instead of being silently dropped; parsing
    itself stays best-effort and does not abort.

    Raises:
        clang.cindex.TranslationUnitLoadError: if libclang cannot create a
            translation unit at all.
    """
    import sys

    translation_unit = clang.cindex.Index.create().parse(header_file)
    for diagnostic in translation_unit.diagnostics:
        if diagnostic.severity >= clang.cindex.Diagnostic.Error:
            print(f"{header_file}: libclang error: {diagnostic.spelling}",
                  file=sys.stderr)
    return translation_unit.cursor
28
def list_functions(root):
    """Return the ``MPI_*`` function declarations under ``root``, one per name.

    A header may declare the same function more than once, and libclang then
    yields one FUNCTION_DECL cursor per declaration. Emitting a wrapper for
    each would define the same C function twice, so only the first
    declaration of each name is kept. Declaration order is preserved.
    """
    seen = set()
    functions = []
    for fn in list_function_declarations(root):
        name = fn.spelling
        # startswith() already rejects empty names.
        if not name.startswith("MPI_") or name in seen:
            continue
        seen.add(name)
        functions.append(fn)
    return functions
35
def exact_definition(arg):
    """Return the C parameter declaration for the parameter cursor ``arg``.

    When the type spelling contains array brackets (e.g. ``int [][3]``), C
    requires them after the parameter name, so everything from the first
    bracket group onward is moved behind the name: ``int  ranges[][3]``.
    Any other type is rendered as ``<type> <name>``.
    """
    type_spelling = arg.type.spelling
    array_suffix = re.search(r'(\[[0-9]*\])+', type_spelling)
    if array_suffix is None:
        return "%s %s" % (type_spelling, arg.spelling)
    base_type = type_spelling[:array_suffix.start()]
    return "%s %s%s" % (base_type, arg.spelling, array_suffix.group())
42
# Names of MPI functions for which no wrapper is generated.
# NOTE(review): MPI_Pcontrol is presumably excluded because it is variadic and
# the generated wrapper has no way to forward its "..." arguments -- confirm.
forbiden_list = set(["MPI_Pcontrol"])
46
# C code inserted verbatim into the wrapper of the named MPI function, after the
# real PMPI_* function has been called and before `ret` is returned.
#
# On successful initialization, look up this process's rank in MPI_COMM_WORLD
# (through PMPI_Comm_rank, so the call is not itself traced) and register the
# application context provider. An application initializes MPI with either
# MPI_Init or MPI_Init_thread, so both wrappers need this snippet; otherwise
# programs using MPI_Init_thread never get a rank nor a registered provider.
_mpi_init_extra_work = """
    if (MPI_SUCCESS == ret) {
        int (*mpi_comm_rank)(MPI_Comm, int *rank);
        MPI_Comm mpi_comm_world;
#ifdef CRAY_MPICH_VERSION
        mpi_comm_world = MPI_COMM_WORLD;
#else
        mpi_comm_world = *(void**)resolve_or_die("ompi_mpi_comm_world_addr");
#endif
        mpi_comm_rank = resolve_or_die("PMPI_Comm_rank");
        mpi_comm_rank(mpi_comm_world, &mpi_rank);
        mpi_provider.priv = lttng_ust_context_provider_register(&mpi_provider);
    }
"""

extra_works = {
    "MPI_Init": _mpi_init_extra_work,
    "MPI_Init_thread": _mpi_init_extra_work,
    # Undo the registration done at initialization, if it happened.
    "MPI_Finalize": """
    if (mpi_provider.priv) {
        lttng_ust_context_provider_unregister(mpi_provider.priv);
    }
""",
}
68
# C code written at the top of the generated wrappers file: includes, helper
# macros, dlsym(RTLD_NEXT) symbol resolution and the "$app.MPI" application
# context provider exposing `mpi_rank` (set by the MPI initialization extra
# work, see `extra_works`). Written as-is, not through string.Template.
_WRAPPERS_PREAMBLE = """/* Auto-generated */
#define _GNU_SOURCE

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <dlfcn.h>

#include <mpi.h>

#include <lttng/ust-events.h>
#include <lttng/ust-ringbuffer-context.h>

#include "lttng/ust-context-provider.h"

#include "lttng-ust-mpi-states.h"

#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)

#define die(fmt, ...) \\
    do { \\
        fprintf(stderr, fmt "\\n", ##__VA_ARGS__); \\
        exit(EXIT_FAILURE); \\
    } while (0)

static void *resolve_or_die(const char *symbol)
{
    void *ret = dlsym(RTLD_NEXT, symbol);
    if (unlikely(!ret)) {
        die("could not resolve `%s': %s", symbol, dlerror());
    }
    return ret;
}

static inline int streq(const char *A, const char *B)
{
    return 0 == strcmp(A, B);
}

/* Return the part of the context name after ':', or NULL if there is none. */
static inline char *context_type(struct lttng_ust_app_context *uctx)
{
    char *suffix = strchr(uctx->ctx_name, ':');

    if (likely(suffix)) {
        suffix = &suffix[1]; /* Skip ':' */
    }

    return suffix;
}

static int mpi_rank = -1;

static size_t mpi_provider_get_size(void *uctx,
                                    struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
                                    size_t offset)
{
    size_t size = 0;
    char *type = context_type(uctx);

    size += lttng_ust_ring_buffer_align(offset, lttng_ust_rb_alignof(char));
    size += sizeof(char);

    if (unlikely(!type)) {
        goto error;
    }

    if (streq(type, "rank")) {
        /* The payload follows the tag: align from offset + size, not offset. */
        size += lttng_ust_ring_buffer_align(offset + size, lttng_ust_rb_alignof(int64_t));
        size += sizeof(int64_t);

    } else {
    error:
        /* Unknown context. */
        (void) size;
    }

    return size;
}

static void mpi_provider_record(void *uctx,
                                struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
                                struct lttng_ust_ring_buffer_ctx *ctx,
                                struct lttng_ust_channel_buffer *lttng_chan_buf)
{
    int sel;
    char sel_char;
    char *type = context_type(uctx);

    if (unlikely(!type)) {
        goto error;
    }

    if (streq(type, "rank")) {
        int64_t v;
        sel = LTTNG_UST_DYNAMIC_TYPE_S64;
        sel_char = (char) sel;
        v = (int64_t) mpi_rank;
        lttng_chan_buf->ops->event_write(ctx, &sel_char, sizeof(sel_char),
                                         lttng_ust_rb_alignof(char));

        lttng_chan_buf->ops->event_write(ctx, &v, sizeof(v), lttng_ust_rb_alignof(v));
    } else {
    error:
        sel = LTTNG_UST_DYNAMIC_TYPE_NONE;
        sel_char = (char) sel;
        lttng_chan_buf->ops->event_write(ctx, &sel_char, sizeof(sel_char),
                                         lttng_ust_rb_alignof(char));
    }
}

static void mpi_provider_get_value(void *uctx,
                                   struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
                                   struct lttng_ust_ctx_value *value)
{
    char *type = context_type(uctx);

    if (unlikely(!type)) {
        goto error;
    }

    if (streq(type, "rank")) {
        value->sel = LTTNG_UST_DYNAMIC_TYPE_S64;
        value->u.s64 = (int64_t) mpi_rank;
    } else {
    error:
        value->sel = LTTNG_UST_DYNAMIC_TYPE_NONE;
    }
}

static struct lttng_ust_context_provider mpi_provider = {
    .struct_size = sizeof(struct lttng_ust_context_provider),
    .name = "$app.MPI",
    .get_size = mpi_provider_get_size,
    .record = mpi_provider_record,
    .get_value = mpi_provider_get_value
};
"""

# Wrapper emitted for each MPI function. The real implementation (P<name>) is
# resolved lazily with dlsym(RTLD_NEXT); concurrent first calls all store the
# same dlsym() result, hence the relaxed atomics. LTTNG_MAKE_API_OBJECT and
# LTTNG_MARK_RETURN_API_OBJECT come from lttng-ust-mpi-states.h.
_WRAPPER_TEMPLATE = Template("""
${ret_type} ${fn_name}(${fn_arguments})
{
    ${ret_type} ret;
    {
        static ${ret_type}(*real_fn)(${fn_arguments}) = NULL;
        if (unlikely(NULL == __atomic_load_n(&real_fn, __ATOMIC_RELAXED))) {
            void *result = resolve_or_die("P${fn_name}");
            __atomic_store_n(&real_fn, result, __ATOMIC_RELAXED);
        }
        LTTNG_MAKE_API_OBJECT(${fn_name}${fn_rest_argument_names});
        ret = real_fn(${fn_pass_argument_names});
        LTTNG_MARK_RETURN_API_OBJECT(ret);
    }
$extra_work
    return ret;
}
""")


def _render_wrapper(fn):
    """Return the C source of the wrapper for the MPI function cursor ``fn``."""
    fn_args = list(fn.get_arguments())

    # An empty C parameter list must be spelled "(void)"; "()" would declare
    # a function with unspecified parameters instead of a prototype.
    fn_arguments = ", ".join(exact_definition(arg) for arg in fn_args) or "void"
    fn_pass_argument_names = ", ".join(arg.spelling for arg in fn_args)
    # Extra LTTNG_MAKE_API_OBJECT arguments: each parameter cast to its type,
    # with every array bracket pair rewritten as a pointer.
    fn_rest_argument_names = "".join(
        ", (%s)%s" % (re.sub(r'\[[0-9]*\]', '*', arg.type.spelling), arg.spelling)
        for arg in fn_args)

    return _WRAPPER_TEMPLATE.substitute(
        fn_name=fn.spelling,
        fn_arguments=fn_arguments,
        fn_pass_argument_names=fn_pass_argument_names,
        fn_rest_argument_names=fn_rest_argument_names,
        ret_type=fn.type.get_result().spelling,
        extra_work=extra_works.get(fn.spelling, ""))


def main():
    """Generate the C file of lttng-ust MPI wrappers.

    Command line: ``lttng-ust-auto-mpi API WRAPPERS``. API is the MPI header
    parsed with libclang. WRAPPERS is the C file to (over)write: a fixed
    preamble followed by one wrapper per ``MPI_*`` function declared in API,
    except the names in ``forbiden_list``.

    Environment:
        LTTNG_UST_MPI_CLANG_LIBRARY_FILE: if set, libclang shared library to
            load instead of the default one.
    """
    clang_library_file = os.getenv("LTTNG_UST_MPI_CLANG_LIBRARY_FILE")
    if clang_library_file is not None:
        clang.cindex.Config.set_library_file(clang_library_file)

    parser = argparse.ArgumentParser(prog="lttng-ust-auto-mpi")
    parser.add_argument("api",
                        help="MPI API header")
    parser.add_argument("wrappers",
                        help="Path to MPI wrappers")
    # Kept under a distinct name so the per-function argument lists in
    # _render_wrapper() can never shadow the command-line namespace.
    cli = parser.parse_args()

    with open(cli.wrappers, "w") as output:
        output.write(_WRAPPERS_PREAMBLE)
        for fn in list_functions(parse_header(cli.api)):
            if fn.spelling in forbiden_list:
                continue
            output.write(_render_wrapper(fn))
# Run the generator only when executed as a script, not when imported.
if "__main__" == __name__:
    main()
This page took 0.03652 seconds and 4 git commands to generate.