X-Git-Url: http://drtracing.org/?a=blobdiff_plain;f=src%2Fbindings%2Fpython%2Fbt2%2Fbt2%2Ftrace_collection_message_iterator.py;h=56b48da9ff8061fe4a96e70870495ec711c313a2;hb=94e72386dd8c0269f889ab41eb68c34bef8e60c2;hp=25778fccaa7fbc636fc5f8f82b31000e311b5d3c;hpb=cc81b5abd3b164b25870673c9a0f93f5c8b5461a;p=babeltrace.git diff --git a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py index 25778fcc..56b48da9 100644 --- a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py +++ b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py @@ -20,10 +20,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from bt2 import utils +from bt2 import utils, native_bt import bt2 import itertools -import bt2.message_iterator +from bt2 import message_iterator as bt2_message_iterator +from bt2 import logging as bt2_logging +from bt2 import port as bt2_port +from bt2 import component as bt2_component +from bt2 import value as bt2_value +from bt2 import plugin as bt2_plugin import datetime from collections import namedtuple import numbers @@ -33,20 +38,47 @@ import numbers _ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec']) -class ComponentSpec: - def __init__(self, plugin_name, class_name, params=None, - logging_level=bt2.logging.LoggingLevel.NONE): +class _BaseComponentSpec: + def __init__(self, params, obj, logging_level): + if logging_level is not None: + utils._check_log_level(logging_level) + + self._params = bt2.create_value(params) + self._obj = obj + self._logging_level = logging_level + + @property + def params(self): + return self._params + + @property + def obj(self): + return self._obj + + @property + def logging_level(self): + return self._logging_level + + +class ComponentSpec(_BaseComponentSpec): + def __init__( + self, + plugin_name, + class_name, + params=None, + obj=None, + logging_level=bt2_logging.LoggingLevel.NONE, + ): + if type(params) is str: + params = {'inputs': [params]} + + super().__init__(params, obj, logging_level) + utils._check_str(plugin_name) utils._check_str(class_name) - utils._check_log_level(logging_level) + self._plugin_name = plugin_name self._class_name = class_name - self._logging_level = logging_level - - if type(params) is str: - self._params = bt2.create_value({'paths': [params]}) - else: - self._params = bt2.create_value(params) @property def plugin_name(self): @@ -56,13 +88,102 @@ class ComponentSpec: def class_name(self): return self._class_name - @property - def logging_level(self): - return self._logging_level + +class AutoSourceComponentSpec(_BaseComponentSpec): + _no_obj = object() + + def __init__(self, input, params=None, obj=_no_obj, logging_level=None): + super().__init__(params, obj, logging_level) + self._input = input @property - def params(self): - return self._params + def input(self): + return self._input + + +def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): + # Transform a list of `AutoSourceComponentSpec` in a list of `ComponentSpec` + # using the automatic source discovery mechanism. + inputs = bt2.ArrayValue([spec.input for spec in auto_source_comp_specs]) + + if plugin_set is None: + plugin_set = bt2.find_plugins() + else: + utils._check_type(plugin_set, bt2_plugin._PluginSet) + + res_ptr = native_bt.bt2_auto_discover_source_components( + inputs._ptr, plugin_set._ptr + ) + + if res_ptr is None: + raise bt2._MemoryError('cannot auto discover source components') + + res = bt2_value._create_from_ptr(res_ptr) + + assert type(res) == bt2.MapValue + assert 'status' in res + + status = res['status'] + utils._handle_func_status(status, 'cannot auto-discover source components') + + comp_specs = [] + comp_specs_raw = res['results'] + assert type(comp_specs_raw) == bt2.ArrayValue + + for comp_spec_raw in comp_specs_raw: + assert type(comp_spec_raw) == bt2.ArrayValue + assert len(comp_spec_raw) == 4 + + plugin_name = comp_spec_raw[0] + assert type(plugin_name) == bt2.StringValue + plugin_name = str(plugin_name) + + class_name = comp_spec_raw[1] + assert type(class_name) == bt2.StringValue + class_name = str(class_name) + + comp_inputs = comp_spec_raw[2] + assert type(comp_inputs) == bt2.ArrayValue + + comp_orig_indices = comp_spec_raw[3] + assert type(comp_orig_indices) + + params = bt2.MapValue() + logging_level = bt2.LoggingLevel.NONE + obj = None + + # Compute `params` for this component by piling up params given to all + # AutoSourceComponentSpec objects that contributed in the instantiation + # of this component. + # + # The effective log level for a component is the last one specified + # across the AutoSourceComponentSpec that contributed in its + # instantiation. + for idx in comp_orig_indices: + orig_spec = auto_source_comp_specs[idx] + + if orig_spec.params is not None: + params.update(orig_spec.params) + + if orig_spec.logging_level is not None: + logging_level = orig_spec.logging_level + + if orig_spec.obj is not AutoSourceComponentSpec._no_obj: + obj = orig_spec.obj + + params['inputs'] = comp_inputs + + comp_specs.append( + ComponentSpec( + plugin_name, + class_name, + params=params, + obj=obj, + logging_level=logging_level, + ) + ) + + return comp_specs # datetime.datetime or integral to nanoseconds @@ -77,7 +198,9 @@ def _get_ns(obj): # s -> ns s = obj.timestamp() else: - raise TypeError('"{}" is not an integral number or a datetime.datetime object'.format(obj)) + raise TypeError( + '"{}" is not an integral number or a datetime.datetime object'.format(obj) + ) return int(s * 1e9) @@ -87,23 +210,78 @@ class _CompClsType: FILTER = 1 -class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): - def __init__(self, source_component_specs, filter_component_specs=None, - stream_intersection_mode=False, begin=None, end=None): +class _TraceCollectionMessageIteratorProxySink(bt2_component._UserSinkComponent): + def __init__(self, params, msg_list): + assert type(msg_list) is list + self._msg_list = msg_list + self._add_input_port('in') + + def _user_graph_is_configured(self): + self._msg_iter = self._create_input_port_message_iterator( + self._input_ports['in'] + ) + + def _user_consume(self): + assert self._msg_list[0] is None + self._msg_list[0] = next(self._msg_iter) + + +class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): + def __init__( + self, + source_component_specs, + filter_component_specs=None, + stream_intersection_mode=False, + begin=None, + end=None, + plugin_set=None, + ): utils._check_bool(stream_intersection_mode) self._stream_intersection_mode = stream_intersection_mode self._begin_ns = _get_ns(begin) self._end_ns = _get_ns(end) - - if type(source_component_specs) is ComponentSpec: + self._msg_list = [None] + + # If a single item is provided, convert to a list. + if type(source_component_specs) in ( + ComponentSpec, + AutoSourceComponentSpec, + str, + ): source_component_specs = [source_component_specs] + # Convert any string to an AutoSourceComponentSpec. + def str_to_auto(item): + if type(item) is str: + item = AutoSourceComponentSpec(item) + + return item + + source_component_specs = [str_to_auto(s) for s in source_component_specs] + if type(filter_component_specs) is ComponentSpec: filter_component_specs = [filter_component_specs] elif filter_component_specs is None: filter_component_specs = [] - self._src_comp_specs = source_component_specs + self._validate_source_component_specs(source_component_specs) + self._validate_filter_component_specs(filter_component_specs) + + # Pass any `ComponentSpec` instance as-is. + self._src_comp_specs = [ + spec for spec in source_component_specs if type(spec) is ComponentSpec + ] + + # Convert any `AutoSourceComponentSpec` in concrete `ComponentSpec` instances. + auto_src_comp_specs = [ + spec + for spec in source_component_specs + if type(spec) is AutoSourceComponentSpec + ] + self._src_comp_specs += _auto_discover_source_component_specs( + auto_src_comp_specs, plugin_set + ) + self._flt_comp_specs = filter_component_specs self._next_suffix = 1 self._connect_ports = False @@ -112,38 +290,58 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): self._src_comps_and_specs = [] self._flt_comps_and_specs = [] - self._validate_component_specs(source_component_specs) - self._validate_component_specs(filter_component_specs) self._build_graph() - def _validate_component_specs(self, comp_specs): + def _validate_source_component_specs(self, comp_specs): + for comp_spec in comp_specs: + if ( + type(comp_spec) is not ComponentSpec + and type(comp_spec) is not AutoSourceComponentSpec + ): + raise TypeError( + '"{}" object is not a ComponentSpec or AutoSourceComponentSpec'.format( + type(comp_spec) + ) + ) + + def _validate_filter_component_specs(self, comp_specs): for comp_spec in comp_specs: if type(comp_spec) is not ComponentSpec: - raise TypeError('"{}" object is not a ComponentSpec'.format(type(comp_spec))) + raise TypeError( + '"{}" object is not a ComponentSpec'.format(type(comp_spec)) + ) def __next__(self): - return next(self._msg_iter) + assert self._msg_list[0] is None + self._graph.run_once() + msg = self._msg_list[0] + assert msg is not None + self._msg_list[0] = None + return msg def _create_stream_intersection_trimmer(self, component, port): # find the original parameters specified by the user to create - # this port's component to get the `path` parameter + # this port's component to get the `inputs` parameter for src_comp_and_spec in self._src_comps_and_specs: if component == src_comp_and_spec.comp: break try: - paths = src_comp_and_spec.spec.params['paths'] + inputs = src_comp_and_spec.spec.params['inputs'] except Exception as e: - raise bt2.Error('all source components must be created with a "paths" parameter in stream intersection mode') from e - - params = {'paths': paths} - - # query the port's component for the `trace-info` object which - # contains the stream intersection range for each exposed - # trace - query_exec = bt2.QueryExecutor() - trace_info_res = query_exec.query(src_comp_and_spec.comp.cls, - 'trace-info', params) + raise ValueError( + 'all source components must be created with an "inputs" parameter in stream intersection mode' + ) from e + + params = {'inputs': inputs} + + # query the port's component for the `babeltrace.trace-info` + # object which contains the stream intersection range for each + # exposed trace + query_exec = bt2.QueryExecutor( + src_comp_and_spec.comp.cls, 'babeltrace.trace-info', params + ) + trace_info_res = query_exec.query() begin = None end = None @@ -160,7 +358,9 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): pass if begin is None or end is None: - raise bt2.Error('cannot find stream intersection range for port "{}"'.format(port.name)) + raise RuntimeError( + 'cannot find stream intersection range for port "{}"'.format(port.name) + ) name = 'trimmer-{}-{}'.format(src_comp_and_spec.comp.name, port.name) return self._create_trimmer(begin, end, name) @@ -169,10 +369,12 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): plugin = bt2.find_plugin('utils') if plugin is None: - raise bt2.Error('cannot find "utils" plugin (needed for the muxer)') + raise RuntimeError('cannot find "utils" plugin (needed for the muxer)') if 'muxer' not in plugin.filter_component_classes: - raise bt2.Error('cannot find "muxer" filter component class in "utils" plugin') + raise RuntimeError( + 'cannot find "muxer" filter component class in "utils" plugin' + ) comp_cls = plugin.filter_component_classes['muxer'] return self._graph.add_component(comp_cls, 'muxer') @@ -181,10 +383,12 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): plugin = bt2.find_plugin('utils') if plugin is None: - raise bt2.Error('cannot find "utils" plugin (needed for the trimmer)') + raise RuntimeError('cannot find "utils" plugin (needed for the trimmer)') if 'trimmer' not in plugin.filter_component_classes: - raise bt2.Error('cannot find "trimmer" filter component class in "utils" plugin') + raise RuntimeError( + 'cannot find "trimmer" filter component class in "utils" plugin' + ) params = {} @@ -203,10 +407,10 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): return self._graph.add_component(comp_cls, name, params) def _get_unique_comp_name(self, comp_spec): - name = '{}-{}'.format(comp_spec.plugin_name, - comp_spec.class_name) - comps_and_specs = itertools.chain(self._src_comps_and_specs, - self._flt_comps_and_specs) + name = '{}-{}'.format(comp_spec.plugin_name, comp_spec.class_name) + comps_and_specs = itertools.chain( + self._src_comps_and_specs, self._flt_comps_and_specs + ) if name in [comp_and_spec.comp.name for comp_and_spec in comps_and_specs]: name += '-{}'.format(self._next_suffix) @@ -218,7 +422,7 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): plugin = bt2.find_plugin(comp_spec.plugin_name) if plugin is None: - raise bt2.Error('no such plugin: {}'.format(comp_spec.plugin_name)) + raise ValueError('no such plugin: {}'.format(comp_spec.plugin_name)) if comp_cls_type == _CompClsType.SOURCE: comp_classes = plugin.source_component_classes @@ -227,14 +431,17 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): if comp_spec.class_name not in comp_classes: cc_type = 'source' if comp_cls_type == _CompClsType.SOURCE else 'filter' - raise bt2.Error('no such {} component class in "{}" plugin: {}'.format(cc_type, - comp_spec.plugin_name, - comp_spec.class_name)) + raise ValueError( + 'no such {} component class in "{}" plugin: {}'.format( + cc_type, comp_spec.plugin_name, comp_spec.class_name + ) + ) comp_cls = comp_classes[comp_spec.class_name] name = self._get_unique_comp_name(comp_spec) - comp = self._graph.add_component(comp_cls, name, comp_spec.params, - comp_spec.logging_level) + comp = self._graph.add_component( + comp_cls, name, comp_spec.params, comp_spec.obj, comp_spec.logging_level + ) return comp def _get_free_muxer_input_port(self): @@ -264,7 +471,7 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): if not self._connect_ports: return - if type(port) is bt2.port._InputPort: + if type(port) is bt2_port._InputPort: return if component not in [comp.comp for comp in self._src_comps_and_specs]: @@ -279,13 +486,13 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): self._muxer_comp = self._create_muxer() if self._begin_ns is not None or self._end_ns is not None: - trimmer_comp = self._create_trimmer(self._begin_ns, - self._end_ns, 'trimmer') - self._graph.connect_ports(self._muxer_comp.output_ports['out'], - trimmer_comp.input_ports['in']) - msg_iter_port = trimmer_comp.output_ports['out'] + trimmer_comp = self._create_trimmer(self._begin_ns, self._end_ns, 'trimmer') + self._graph.connect_ports( + self._muxer_comp.output_ports['out'], trimmer_comp.input_ports['in'] + ) + last_flt_out_port = trimmer_comp.output_ports['out'] else: - msg_iter_port = self._muxer_comp.output_ports['out'] + last_flt_out_port = self._muxer_comp.output_ports['out'] # create extra filter components (chained) for comp_spec in self._flt_comp_specs: @@ -296,8 +503,8 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): for comp_and_spec in self._flt_comps_and_specs: in_port = list(comp_and_spec.comp.input_ports.values())[0] out_port = list(comp_and_spec.comp.output_ports.values())[0] - self._graph.connect_ports(msg_iter_port, in_port) - msg_iter_port = out_port + self._graph.connect_ports(last_flt_out_port, in_port) + last_flt_out_port = out_port # Here we create the components, self._graph_port_added() is # called when they add ports, but the callback returns early @@ -326,5 +533,12 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): self._connect_src_comp_port(comp_and_spec.comp, out_port) - # create this trace collection iterator's message iterator - self._msg_iter = self._graph.create_output_port_message_iterator(msg_iter_port) + # Add the proxy sink, passing our message list to share consumed + # messages with this trace collection message iterator. + sink = self._graph.add_component( + _TraceCollectionMessageIteratorProxySink, 'proxy-sink', obj=self._msg_list + ) + sink_in_port = sink.input_ports['in'] + + # connect last filter to proxy sink + self._graph.connect_ports(last_flt_out_port, sink_in_port)