#! /usr/bin/python3
# -*- coding: utf-8 -*-

# Copyright (C) 2012-2014 by László Nagy
# This file is part of Bear.
#
# Bear is a tool to generate compilation database for clang tooling.
#
# Bear is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Bear is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
""" This module is responsible to capture the compiler invocation of any
build process. The result of that should be a compilation database.

This implementation is using the LD_PRELOAD or DYLD_INSERT_LIBRARIES
mechanisms provided by the dynamic linker. The related library is implemented
in C language and can be found under 'libear' directory.

The 'libear' library is capturing all child process creation and logging the
relevant information about it into separate files in a specified directory.
The input of the library is therefore the output directory which is passed
as an environment variable.

This module implements the build command execution with the 'libear' library
and the post-processing of the output files, which are condensed into a
(possibly empty) compilation database. """

import argparse
import subprocess
import json
import sys
import os
import os.path
import re
import shlex
import itertools
import tempfile
import shutil
import contextlib
import logging


def main():
    """ Command line entry point, returns the exit status of the process.

    Parses the arguments, configures logging from the verbosity count and
    runs the interception. When no build command was given only the usage
    is printed. An interrupt results 1, any other failure results 127. """

    try:
        parser = create_parser()
        args = parser.parse_args()

        # every '-v' lowers the threshold by one logging level (10 units),
        # but never below zero
        verbosity = min(logging.WARNING, 10 * args.verbose)
        logging.basicConfig(format='bear: %(message)s',
                            level=logging.WARNING - verbosity)
        logging.debug(args)

        if args.build:
            return capture(args)

        parser.print_help()
        return 0
    except KeyboardInterrupt:
        return 1
    except Exception:
        logging.exception("Something unexpected had happened.")
        return 127


def capture(args):
    """ Run the build under interception and write the compilation database.

    The build command is executed with the environment prepared by
    'setup_environment', which makes the reports land in a scratch
    directory. Those reports are parsed, post-processed (unless raw entries
    were requested) and dumped as JSON into the file named by 'args.cdb'.

    Returns the exit code of the build command. """

    @contextlib.contextmanager
    def temporary_directory(**kwargs):
        """ Create a scratch directory which is removed on exit. """
        path = tempfile.mkdtemp(**kwargs)
        try:
            yield path
        finally:
            shutil.rmtree(path)

    def post_processing(commands):
        """ Turn the intercepted exec calls into database entries. """
        # the raw exec calls were requested, hand them back untouched
        if args.raw_entries:
            return commands
        # keep the compiler calls only, each of them might produce
        # multiple entries (one per source file)
        compilations = (call for call in commands if compiler_call(call))
        current = itertools.chain.from_iterable(
            format_entry(call) for call in compilations)
        # entries of the previous run are read eagerly, before the output
        # file gets truncated for writing
        previous = iter([])
        if args.append and os.path.exists(args.cdb):
            with open(args.cdb) as handle:
                previous = iter(json.load(handle))
        # drop entries of vanished files and the duplicates
        is_duplicate = duplicate_check(entry_hash)
        return (entry
                for entry in itertools.chain(previous, current)
                if os.path.exists(entry['file']) and not is_duplicate(entry))

    with temporary_directory(prefix='bear-', dir=tempdir()) as workdir:
        # execute the build with the interception library loaded
        environment = setup_environment(workdir, args.libear)
        logging.info('run build in environment: %s', environment)
        exit_code = subprocess.call(args.build, env=environment)
        logging.info('build finished with exit code: %s', exit_code)
        # the reports are parsed lazily, in the order of their file names
        reports = (os.path.join(workdir, name)
                   for name in sorted(os.listdir(workdir)))
        traces = (parse_exec_trace(report) for report in reports)
        entries = post_processing(itertools.chain.from_iterable(traces))
        # materialize the entries while the scratch directory still exists
        with open(args.cdb, 'w+') as handle:
            json.dump(list(entries), handle, sort_keys=True, indent=4)
        return exit_code


def setup_environment(destination, ear_library_path):
    """ Create the environment for the build command.

    Returns a copy of the current environment extended with 'BEAR_OUTPUT'
    (set to the destination) and with the platform specific variables which
    make the dynamic linker preload the given library. """

    if sys.platform == 'darwin':
        preload = {
            'DYLD_INSERT_LIBRARIES': ear_library_path,
            'DYLD_FORCE_FLAT_NAMESPACE': '1'
        }
    else:
        preload = {'LD_PRELOAD': ear_library_path}

    # work on a copy, the environment of this process stays untouched
    environment = dict(os.environ, BEAR_OUTPUT=destination)
    environment.update(preload)
    return environment


def parse_exec_trace(filename):
    """ Read a report file and generate the exec calls stored in it.

    The report is a sequence of groups terminated by the GS (0x1d)
    character. The fields of a group are separated by RS (0x1e) and are
    the pid, the parent pid, the name of the called function, the working
    directory and the command. The arguments of the command are terminated
    by US (0x1f). A dictionary is generated for every non empty group. """

    group_separator = '\x1d'
    record_separator = '\x1e'
    unit_separator = '\x1f'

    logging.debug('parse exec trace file: %s', filename)
    with open(filename, 'r') as handle:
        content = handle.read()

    for group in content.split(group_separator):
        # the split leaves an empty piece behind the last separator
        if not group:
            continue
        fields = group.split(record_separator)
        yield {
            'pid': fields[0],
            'ppid': fields[1],
            'function': fields[2],
            'directory': fields[3],
            # every argument is terminated, so the last piece is dropped
            'command': fields[4].split(unit_separator)[:-1]
        }


def format_entry(entry):
    """ Generate the compilation database entries of an exec call.

    The command of the exec call is classified first. An entry (with the
    'directory', 'command' and 'file' fields) is generated for each source
    file, unless the classification says the call shall be ignored. """

    def abspath(name):
        """ Create normalized absolute path from input filename. """
        if os.path.isabs(name):
            return os.path.normpath(name)
        # relative names are meant relative to the working directory
        return os.path.normpath(os.path.join(entry['directory'], name))

    logging.debug('format this command: %s', entry['command'])
    atoms = classify_parameters(entry['command'], abspath)
    if atoms['action'] > Action.Compile:
        return

    # the compiler name is normalized, only the language is kept
    prefix = ['c++' if atoms['c++'] else 'cc', '-c']
    for source in atoms['files']:
        command = prefix + atoms['compile_options'] + [source]
        logging.debug('formated as: %s', command)
        yield {
            'directory': entry['directory'],
            'command': shell_escape(command),
            'file': abspath(source)
        }


def shell_escape(command):
    """ Takes a command as list and returns a string.

    The elements of the list are the raw arguments of the process (there is
    no shell syntax in them), and the result shall split back into the very
    same arguments. Therefore in every argument the backslash and the double
    quote characters are protected by a backslash, and the argument is put
    between double quotes when it is empty or when it has a character which
    would split or alter an unprotected word.

    Arguments without such characters are returned as they are, so the
    simple commands stay readable. """

    # these can not stand unprotected in a word: shell meta characters and
    # the single quote (which would open a quoted section)
    reserved = {' ', '$', '%', '&', '(', ')', '[', ']', '{', '}', '*', '|',
                '<', '>', '@', '?', '!', "'"}
    # inside double quotes only these two need a backslash
    table = {'\\': '\\\\', '"': '\\"'}

    def needs_quote(word):
        """ Returns true if the argument needs to be protected by quotes. """

        # an empty argument would disappear from the joined string
        if not word:
            return True
        # any white space (not only the space) is a word separator
        return any(char in reserved or char.isspace() for char in word)

    def escape(word):
        """ Do protect argument if that's needed. """

        escaped = ''.join(table.get(char, char) for char in word)
        return '"' + escaped + '"' if needs_quote(word) else escaped

    return " ".join(escape(arg) for arg in command)


def is_source(filename):
    """ Returns true when the extension of the filename is a known source
    file extension. The check is case insensitive. """

    extension = os.path.splitext(filename)[1].lower()
    return extension in {
        '.c', '.cc', '.cp', '.cpp', '.cxx', '.c++', '.m', '.mm', '.i', '.ii',
        '.mii'
    }


def compiler_call(entry):
    """ A predicate to decide the entry is a compiler call or not.

    The decision is made on the first word of the 'command' field, which
    is matched against the known compiler name patterns (optionally with
    directory, cross compilation prefix and version suffix). An entry with
    an empty command is not a compiler call. """

    patterns = [
        re.compile(r'^([^/]*/)*c(c|\+\+)$'),
        re.compile(r'^([^/]*/)*([^-]*-)*[mg](cc|\+\+)(-\d+(\.\d+){0,2})?$'),
        re.compile(r'^([^/]*/)*([^-]*-)*clang(\+\+)?(-\d+(\.\d+){0,2})?$'),
        re.compile(r'^([^/]*/)*llvm-g(cc|\+\+)$'),
    ]
    command = entry['command']
    # a process can be executed with an empty argument list, in which
    # case there is no executable name to look at
    if not command:
        return False
    executable = command[0]
    return any(pattern.match(executable) for pattern in patterns)


def entry_hash(entry):
    """ Create a string from the entry which identifies it uniquely.

    The string is made of the reversed file name, the reversed directory
    name and the command without its first word. """

    # The names are reverted for faster lookup in set.
    reversed_file = entry['file'][::-1]
    reversed_directory = entry['directory'][::-1]
    # The first word of the command is left out on purpose. On OS X the
    # 'cc' and 'c++' compilers are wrappers for 'clang', therefore both
    # calls would be logged and would differ only in that word.
    arguments = shlex.split(entry['command'])[1:]

    parts = (reversed_file, reversed_directory, ' '.join(arguments))
    return '<>'.join(parts)


def duplicate_check(method):
    """ Create a predicate which detects the duplicated entries.

    Entries are dictionaries, which are not hashable. Therefore the given
    method is used to compute a hashable value from an entry, and those
    values are collected in a set.

    The returned predicate answers false for an entry when it sees its
    hash value the first time, and true every time after that. The method
    and the set are stored on the predicate as the 'unique' and 'state'
    attributes. """

    def predicate(entry):
        key = predicate.unique(entry)
        seen = key in predicate.state
        if not seen:
            predicate.state.add(key)
        return seen

    predicate.state = set()
    predicate.unique = method
    return predicate


def tempdir():
    """ Return the default temporary directory.

    It is the value of the first environment variable which is set from
    'TMPDIR', 'TEMP' and 'TMP' (in this order), or '/tmp' if none of
    them is set. """

    for variable in ('TMPDIR', 'TEMP', 'TMP'):
        value = os.environ.get(variable)
        if value is not None:
            return value
    return '/tmp'


def create_parser():
    """ Create the command line parser.

    The general options are registered on the parser itself, the ones for
    testing go into the 'advanced options' group. Everything what is left
    on the command line is collected as the build command. """

    general_options = [
        (('--version',),
         dict(action='version', version='%(prog)s 2.1.5')),
        (('--verbose', '-v'),
         dict(action='count', default=0,
              help="enable verbose output from '%(prog)s'. A second '-v' "
                   "increases\n                verbosity.")),
        (('--cdb', '-o'),
         dict(metavar='<file>', default="compile_commands.json",
              help="The JSON compilation database.")),
        (('--append', '-a'),
         dict(action='store_true',
              help="appends new entries to existing compilation database.")),
    ]
    advanced_options = [
        (('--disable-filter', '-n'),
         dict(dest='raw_entries', action='store_true',
              help="disable filter, unformated output.")),
        (('--libear', '-l'),
         dict(dest='libear', action='store',
              default="/usr/lib/x86_64-linux-gnu/bear/libear.so",
              help="specify libear file location.")),
    ]

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for flags, settings in general_options:
        parser.add_argument(*flags, **settings)

    advanced = parser.add_argument_group('advanced options')
    for flags, settings in advanced_options:
        advanced.add_argument(*flags, **settings)

    # the build command is the rest of the command line
    parser.add_argument(dest='build', nargs=argparse.REMAINDER,
                        help="command to run.")

    return parser


class Action(object):
    """ Enumeration class for compiler action.

    The values are ordered: when more of them apply to a compiler call,
    the greatest one is kept. """

    Link = 0
    Compile = 1
    Ignored = 2


def classify_parameters(command, abspath):
    """ Parses the command line arguments of the given invocation.

    The first element of the command is the compiler name, the rest are
    its arguments. The 'abspath' is a function which makes an absolute
    path from a file name, it is applied on the source files and on the
    include directories given as '-I<dir>'.

    Returns a dictionary with these fields:
      'action'           the greatest 'Action' value the arguments ask for,
      'files'            the absolute path of the source files,
      'compile_options'  the arguments which are kept for compilation,
      'c++'              true when the compiler name is a C++ compiler. """

    state = {
        'action': Action.Link,
        'files': [],
        'compile_options': [],
        'c++': cplusplus_compiler(command[0])
    }

    def action(value):
        """ Record the action, the greater value is the stronger. """
        state['action'] = max(state['action'], value)

    def skip_value(iterator):
        """ Drop the value which belongs to the current argument.

        The value might be missing when the argument is the last one. A
        bare 'next' would raise StopIteration in that case, which turns
        into RuntimeError inside the generator which calls this method. """
        next(iterator, None)

    args = iter(command[1:])
    for arg in args:
        # compiler action parameters are the most important ones...
        if arg == '-c':
            action(Action.Compile)
        elif arg in {'-E', '-S', '-cc1', '-M', '-MM', '-###'}:
            action(Action.Ignored)
        # some preprocessor parameters are ignored...
        elif arg in {'-MD', '-MMD', '-MG', '-MP'}:
            pass
        elif arg in {'-MF', '-MT', '-MQ'}:
            skip_value(args)
        # linker options are ignored...
        elif arg in {'-static', '-shared', '-s', '-rdynamic'}:
            pass
        elif re.match(r'^-[lL].+', arg):
            pass
        elif arg in {'-l', '-L', '-u', '-z', '-T', '-Xlinker'}:
            skip_value(args)
        # parameters which looks source file are taken...
        elif re.match(r'^[^-].+', arg) and is_source(arg):
            state['files'].append(abspath(arg))
        # and consider everything else as compile option.
        else:
            # make include paths absolute
            match = re.match(r'^(-I)(.+)', arg)
            if match:
                arg = match.group(1) + abspath(match.group(2))
            state['compile_options'].append(arg)

    return state


def cplusplus_compiler(name):
    """ Returns true when the compiler name refer to a C++ compiler.

    Such name ends with '++', which might be preceded by directories and
    dash separated prefixes, and might be followed by a version number. """

    pattern = r'^([^/]*/)*(\w*-)*(\w+\+\+)(-(\d+(\.\d+){0,3}))?$'
    return re.match(pattern, name) is not None


if __name__ == "__main__":
    # the result of main becomes the exit status of the process
    exit_status = main()
    sys.exit(exit_status)
