X-Git-Url: http://drtracing.org/?a=blobdiff_plain;f=src%2Fbindings%2Fpython%2Fbt2%2Fbt2%2Ftrace_collection_message_iterator.py;h=097e1c8b6e8b37deb3028183f861809ecf56c7af;hb=0530003ff4f693e0a582a0ed4a15245455398b4a;hp=3003c5f3730f8fdf9e3f4f6b7ba8a75a1e8cc9a2;hpb=5f7f0be0dae6ea6c4d191db180d160eca150b53b;p=babeltrace.git diff --git a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py index 3003c5f3..097e1c8b 100644 --- a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py +++ b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py @@ -24,7 +24,6 @@ from bt2 import utils, native_bt import bt2 import itertools from bt2 import message_iterator as bt2_message_iterator -from bt2 import logging as bt2_logging from bt2 import port as bt2_port from bt2 import component as bt2_component from bt2 import value as bt2_value @@ -39,6 +38,8 @@ _ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec']) class _BaseComponentSpec: + # Base for any component spec that can be passed to + # TraceCollectionMessageIterator. def __init__(self, params, obj, logging_level): if logging_level is not None: utils._check_log_level(logging_level) @@ -61,35 +62,72 @@ class _BaseComponentSpec: class ComponentSpec(_BaseComponentSpec): + # A component spec with a specific component class. def __init__( self, - plugin_name, - class_name, + component_class, params=None, obj=None, - logging_level=bt2_logging.LoggingLevel.NONE, + logging_level=bt2.LoggingLevel.NONE, ): if type(params) is str: params = {'inputs': [params]} super().__init__(params, obj, logging_level) - utils._check_str(plugin_name) - utils._check_str(class_name) + is_cc_object = isinstance( + component_class, + (bt2._SourceComponentClassConst, bt2._FilterComponentClassConst), + ) + is_user_cc_type = isinstance( + component_class, bt2_component._UserComponentType + ) and issubclass( + component_class, (bt2._UserSourceComponent, bt2._UserFilterComponent) + ) - self._plugin_name = plugin_name - self._class_name = class_name + if not is_cc_object and not is_user_cc_type: + raise TypeError( + "'{}' is not a source or filter component class".format( + component_class.__class__.__name__ + ) + ) - @property - def plugin_name(self): - return self._plugin_name + self._component_class = component_class @property - def class_name(self): - return self._class_name + def component_class(self): + return self._component_class + + @classmethod + def from_named_plugin_and_component_class( + cls, + plugin_name, + component_class_name, + params=None, + obj=None, + logging_level=bt2.LoggingLevel.NONE, + ): + plugin = bt2.find_plugin(plugin_name) + + if plugin is None: + raise ValueError('no such plugin: {}'.format(plugin_name)) + + if component_class_name in plugin.source_component_classes: + comp_class = plugin.source_component_classes[component_class_name] + elif component_class_name in plugin.filter_component_classes: + comp_class = plugin.filter_component_classes[component_class_name] + else: + raise KeyError( + 'source or filter component class `{}` not found in plugin `{}`'.format( + component_class_name, plugin_name + ) + ) + + return cls(comp_class, params, obj, logging_level) class AutoSourceComponentSpec(_BaseComponentSpec): + # A component spec that does automatic source discovery. _no_obj = object() def __init__(self, input, params=None, obj=_no_obj, logging_level=None): @@ -178,7 +216,7 @@ def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): params['inputs'] = comp_inputs comp_specs.append( - ComponentSpec( + ComponentSpec.from_named_plugin_and_component_class( plugin_name, class_name, params=params, @@ -220,11 +258,6 @@ def _get_ns(obj): return int(s * 1e9) -class _CompClsType: - SOURCE = 0 - FILTER = 1 - - class _TraceCollectionMessageIteratorProxySink(bt2_component._UserSinkComponent): def __init__(self, params, msg_list): assert type(msg_list) is list @@ -307,6 +340,43 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): self._build_graph() + def _compute_stream_intersections(self): + # Pre-compute the trimmer range to use for each port in the graph, when + # stream intersection mode is enabled. + self._stream_inter_port_to_range = {} + + for src_comp_and_spec in self._src_comps_and_specs: + # Query the port's component for the `babeltrace.trace-infos` + # object which contains the range for each stream, from which we can + # compute the intersection of the streams in each trace. + query_exec = bt2.QueryExecutor( + src_comp_and_spec.spec.component_class, + 'babeltrace.trace-infos', + src_comp_and_spec.spec.params, + ) + trace_infos = query_exec.query() + + for trace_info in trace_infos: + begin = max( + [ + stream['range-ns']['begin'] + for stream in trace_info['stream-infos'] + ] + ) + end = min( + [stream['range-ns']['end'] for stream in trace_info['stream-infos']] + ) + + # Each port associated to this trace will have this computed + # range. + for stream in trace_info['stream-infos']: + # A port name is unique within a component, but not + # necessarily across all components. Use a component + # and port name pair to make it unique across the graph. + port_name = str(stream['port-name']) + key = (src_comp_and_spec.comp.addr, port_name) + self._stream_inter_port_to_range[key] = (begin, end) + def _validate_source_component_specs(self, comp_specs): for comp_spec in comp_specs: if ( @@ -335,49 +405,9 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): return msg def _create_stream_intersection_trimmer(self, component, port): - # find the original parameters specified by the user to create - # this port's component to get the `inputs` parameter - for src_comp_and_spec in self._src_comps_and_specs: - if component == src_comp_and_spec.comp: - break - - try: - inputs = src_comp_and_spec.spec.params['inputs'] - except Exception as e: - raise ValueError( - 'all source components must be created with an "inputs" parameter in stream intersection mode' - ) from e - - params = {'inputs': inputs} - - # query the port's component for the `babeltrace.trace-info` - # object which contains the stream intersection range for each - # exposed trace - query_exec = bt2.QueryExecutor( - src_comp_and_spec.comp.cls, 'babeltrace.trace-info', params - ) - trace_info_res = query_exec.query() - begin = None - end = None - - # find the trace info for this port's trace - try: - for trace_info in trace_info_res: - for stream in trace_info['streams']: - if stream['port-name'] == port.name: - range_ns = trace_info['intersection-range-ns'] - begin = range_ns['begin'] - end = range_ns['end'] - break - except Exception: - pass - - if begin is None or end is None: - raise RuntimeError( - 'cannot find stream intersection range for port "{}"'.format(port.name) - ) - - name = 'trimmer-{}-{}'.format(src_comp_and_spec.comp.name, port.name) + key = (component.addr, port.name) + begin, end = self._stream_inter_port_to_range[key] + name = 'trimmer-{}-{}'.format(component.name, port.name) return self._create_trimmer(begin, end, name) def _create_muxer(self): @@ -421,8 +451,8 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): comp_cls = plugin.filter_component_classes['trimmer'] return self._graph.add_component(comp_cls, name, params) - def _get_unique_comp_name(self, comp_spec): - name = '{}-{}'.format(comp_spec.plugin_name, comp_spec.class_name) + def _get_unique_comp_name(self, comp_cls): + name = comp_cls.name comps_and_specs = itertools.chain( self._src_comps_and_specs, self._flt_comps_and_specs ) @@ -433,27 +463,9 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): return name - def _create_comp(self, comp_spec, comp_cls_type): - plugin = bt2.find_plugin(comp_spec.plugin_name) - - if plugin is None: - raise ValueError('no such plugin: {}'.format(comp_spec.plugin_name)) - - if comp_cls_type == _CompClsType.SOURCE: - comp_classes = plugin.source_component_classes - else: - comp_classes = plugin.filter_component_classes - - if comp_spec.class_name not in comp_classes: - cc_type = 'source' if comp_cls_type == _CompClsType.SOURCE else 'filter' - raise ValueError( - 'no such {} component class in "{}" plugin: {}'.format( - cc_type, comp_spec.plugin_name, comp_spec.class_name - ) - ) - - comp_cls = comp_classes[comp_spec.class_name] - name = self._get_unique_comp_name(comp_spec) + def _create_comp(self, comp_spec): + comp_cls = comp_spec.component_class + name = self._get_unique_comp_name(comp_cls) comp = self._graph.add_component( comp_cls, name, comp_spec.params, comp_spec.obj, comp_spec.logging_level ) @@ -486,7 +498,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): if not self._connect_ports: return - if type(port) is bt2_port._InputPort: + if type(port) is bt2_port._InputPortConst: return if component not in [comp.comp for comp in self._src_comps_and_specs]: @@ -495,8 +507,36 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): self._connect_src_comp_port(component, port) + def _get_greatest_operative_mip_version(self): + def append_comp_specs_descriptors(descriptors, comp_specs): + for comp_spec in comp_specs: + descriptors.append( + bt2.ComponentDescriptor( + comp_spec.component_class, comp_spec.params, comp_spec.obj + ) + ) + + descriptors = [] + append_comp_specs_descriptors(descriptors, self._src_comp_specs) + append_comp_specs_descriptors(descriptors, self._flt_comp_specs) + + if self._stream_intersection_mode: + # we also need at least one `flt.utils.trimmer` component + comp_spec = ComponentSpec.from_named_plugin_and_component_class( + 'utils', 'trimmer' + ) + append_comp_specs_descriptors(descriptors, [comp_spec]) + + mip_version = bt2.get_greatest_operative_mip_version(descriptors) + + if mip_version is None: + msg = 'failed to find an operative message interchange protocol version (components are not interoperable)' + raise RuntimeError(msg) + + return mip_version + def _build_graph(self): - self._graph = bt2.Graph() + self._graph = bt2.Graph(self._get_greatest_operative_mip_version()) self._graph.add_port_added_listener(self._graph_port_added) self._muxer_comp = self._create_muxer() @@ -511,7 +551,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): # create extra filter components (chained) for comp_spec in self._flt_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.FILTER) + comp = self._create_comp(comp_spec) self._flt_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) # connect the extra filter chain @@ -529,9 +569,12 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): # it does not exist yet (it needs the created component to # exist). for comp_spec in self._src_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.SOURCE) + comp = self._create_comp(comp_spec) self._src_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) + if self._stream_intersection_mode: + self._compute_stream_intersections() + # Now we connect the ports which exist at this point. We allow # self._graph_port_added() to automatically connect _new_ ports. self._connect_ports = True