#!/usr/bin/env python3

import argparse
from datetime import datetime

import bt2

class Thread:
    """Per-thread bookkeeping for the signal timeline.

    Attributes:
        signals: signals recorded for this thread; starts empty.
        pendings: stack of signals that are still open; starts empty.
    """

    def __init__(self):
        # Two distinct lists: one for recorded signals, one used as a stack.
        self.signals, self.pendings = [], []

class Signal:
    """A signal emission delimited by a begin and an end trace message.

    Attributes:
        beg: message that opened the signal.
        indent: nesting depth, used as the number of tabs when printing.
        elapse: duration in microseconds; -1 until end() has been called.
        signals, callbacks, pendings: lists that start empty.
    """

    def __init__(self, msg, indent):
        self.beg = msg
        self.indent = indent
        self.elapse = -1
        self.signals = []
        self.callbacks = []
        self.pendings = []

    def end(self, end):
        """Close the signal with message `end` and return self.

        The duration is the difference between the default clock snapshots
        of `end` and of the opening message, converted from ns to us.
        """
        stopped_ns = end.default_clock_snapshot.ns_from_origin
        started_ns = self.beg.default_clock_snapshot.ns_from_origin
        self.elapse = (stopped_ns - started_ns) / 1E3
        return self

    def __str__(self):
        """Return '<duration>us: <signal_type>' with two decimals."""
        duration = self.elapse
        signal_type = self.beg.event["signal_type"]
        return f"{duration:.2f}us: {signal_type}"

class Callback:
    """A callback invocation delimited by a begin and an end trace message.

    Attributes:
        source: 'source_filename:source_line' taken from the begin event.
        beg: message that opened the callback.
        elapse: duration in microseconds; -1 until end() has been called.
        name: always initialised to the empty string.
    """

    def __init__(self, begin):
        payload = begin.event
        self.source = f"{payload['source_filename']}:{payload['source_line']}"
        self.beg = begin
        self.elapse = -1
        self.name = ""

    def end(self, end):
        """Close the callback with message `end` and return self.

        The duration is the difference between the default clock snapshots
        of `end` and of the opening message, converted from ns to us.
        """
        stopped_ns = end.default_clock_snapshot.ns_from_origin
        started_ns = self.beg.default_clock_snapshot.ns_from_origin
        self.elapse = (stopped_ns - started_ns) / 1E3
        return self

def print_timeline(thread):
    """Print every signal of `thread` followed by its callback breakdown.

    For each signal in `thread.signals` this prints, indented by
    `signal.indent` tabs: the signal itself (via str()), a header row, one
    row per callback with the callback's share of the signal's duration in
    percent, and a blank separator line.

    A signal whose duration is zero (identical begin and end timestamps)
    has no meaningful share to report; its callbacks are printed as 0.00
    rather than dividing by zero.
    """
    for signal in thread.signals:
        prefix = "\t" * signal.indent
        total = signal.elapse

        print(f"{prefix}{signal}")
        print("{}\t{:<30s}{:<15}".format(prefix, "callback", "duration (%)"))

        for cb in signal.callbacks:
            # Guard the ratio: `total` can legitimately be 0.
            share = 100 * cb.elapse / total if total else 0.0
            print("{}\t{:<30s}{:<.2f}".format(prefix, cb.source, share))

        print("")

def main(args):
    """Rebuild per-thread signal timelines from a trace and print them.

    Args:
        args: parsed command line. `args.path` is a one-element list holding
            the trace location handed to bt2.TraceCollectionMessageIterator;
            `args.tid` restricts processing to that thread id when it is
            greater than zero.

    Only events whose name starts with "jami:emit_signal" are considered.
    The remainder of the name selects the action:

        ""                 a signal begins (pushed on the thread's stack)
        "_end"             the innermost open signal ends
        "_begin_callback"  a callback begins inside the innermost signal
        "_end_callback"    the innermost open callback ends

    End/callback events that have no matching begin (for instance when the
    capture started in the middle of an emission) are skipped with a warning
    on stderr instead of aborting the whole run with an IndexError.
    """
    import sys  # function-scope import: only the diagnostics below need it

    filter_tid = args.tid
    threads = {}
    prefix = "jami:emit_signal"

    def warn_unmatched(event_name, tid):
        print(f"warning: ignoring unmatched {event_name} on thread {tid}",
              file=sys.stderr)

    for msg in bt2.TraceCollectionMessageIterator(args.path[0]):

        # Same exact-type check as the original: only event messages matter.
        if type(msg) is not bt2._EventMessageConst:
            continue

        event = msg.event
        full_name = event.name

        if not full_name.startswith(prefix):
            continue

        tid = event["vtid"]

        if filter_tid > 0 and tid != filter_tid:
            continue

        if tid not in threads:
            threads[tid] = Thread()

        thread = threads[tid]
        pendings = thread.pendings
        suffix = full_name[len(prefix):]

        if suffix == "":
            # Depth of the new signal is the number of open signals plus one.
            pendings.append(Signal(msg, len(pendings) + 1))

        elif suffix == "_end":
            if not pendings:
                warn_unmatched(full_name, tid)
                continue
            thread.signals.append(pendings.pop().end(msg))

        elif suffix == "_begin_callback":
            if not pendings:
                warn_unmatched(full_name, tid)
                continue
            pendings[-1].pendings.append(Callback(msg))

        elif suffix == "_end_callback":
            if not pendings or not pendings[-1].pendings:
                warn_unmatched(full_name, tid)
                continue
            signal = pendings[-1]
            signal.callbacks.append(signal.pendings.pop().end(msg))

    for tid, thread in threads.items():
        print(f"Thread: {tid}")
        print_timeline(thread)

if __name__ == "__main__":
    # Command line: TRACE is mandatory (collected as a one-element list),
    # TID is optional and defaults to -1.
    cli = argparse.ArgumentParser(description="Show signals timeline.")
    cli.add_argument("path", metavar="TRACE", type=str, nargs=1)
    cli.add_argument("tid", metavar="TID", type=int, nargs='?', default=-1)

    main(cli.parse_args())
