#!/usr/bin/env python3
#
# Usage: gphoto-m4-sync [--diff] <dir>...
#        gphoto-m4-sync --help
#
# The gphoto-m4-sync script helps with keeping track of which files in
# which gphoto-m4 tree copy differ from the original gphoto-m4 tree.
#
# In normal operation, gphoto-m4-sync will search for gphoto-m4
# directories anywhere in the directory trees given on the command
# line and compare the gphoto-m4 tree from which gphoto-m4-sync was
# started to those other trees.
#
# When not given the --diff option, gphoto-m4-sync will print a human
# readable report on which files are different in which gphoto-m4
# tree.
#
# Options:
#
#   --diff   Print a list of 'diff' command lines to compare
#            the different files instead. Pipe into something like
#            "| sh | less" to execute.
#
#   --help   Print this help message.
#
# Exit code:
#
#   0    when no  differences have been found among the gphoto-m4 trees
#   1    when any differences have been found among the gphoto-m4 trees
#   2    any other error


########################################################################


import hashlib
import os
import shlex
import sys


########################################################################


class File(object):
    """One file inside a tree, together with its stat info and digest.

    Attributes set by the constructor:
      tree      the tree object the file belongs to (only ``tree.top``
                is read here)
      fname     file name relative to ``tree.top``
      fpath     ``fname`` joined onto ``tree.top``
      statinfo  result of ``os.stat(fpath)``
      digest    hexadecimal SHA-1 digest of the file content
    """

    # Read size used while hashing, so large files are not loaded into
    # memory in one piece.
    _CHUNK_SIZE = 65536

    def __init__(self, tree, fname):
        """Stat and hash the file ``fname`` located below ``tree.top``.

        Raises OSError if the file cannot be stat'ed or read.
        """
        self.tree = tree
        self.fname = fname
        self.fpath = os.path.join(tree.top, fname)

        self.statinfo = os.stat(self.fpath)

        m = hashlib.sha1()
        # The context manager closes the file handle deterministically
        # instead of leaving that to the garbage collector.
        with open(self.fpath, 'rb') as stream:
            for chunk in iter(lambda: stream.read(self._CHUNK_SIZE), b''):
                m.update(chunk)
        self.digest = m.hexdigest()

    def __repr__(self):
        return 'File(%s,%s)' % (repr(self.fname), repr(self.digest))

    def __str__(self):
        return '%s %s' % (self.digest, self.fname)


########################################################################


class BaseTree(object):
    """A directory tree with all tracked files scanned at construction.

    The tree maps file names relative to ``top`` onto ``File`` objects.
    Iterating over a tree yields those relative file names in sorted
    order, ``tree[name]`` returns the ``File`` object, and
    ``name in tree`` tests whether the file was found by the scan.
    """

    def __init__(self, top):
        """Scan the directory tree below ``top`` (made absolute)."""
        self.top = os.path.abspath(top)
        self._files = self.__scan_files()

    def __repr__(self):
        return '%s(%s)[%s]' % (self.__class__.__name__, self.top, self._files)

    def __iter__(self):
        return iter(sorted(self._files))

    def __contains__(self, key):
        # Without this method, ``key in tree`` would fall back to
        # __iter__() and sort all file names for every membership test.
        return key in self._files

    def __getitem__(self, key):
        return self._files[key]

    def __scan_files(self):
        """Return a dict mapping relative file names to File objects.

        Skipped are the ``.git`` directory, backup files ending in
        ``~``, files whose name starts with ``.git``, and the files
        ``Makefile`` and ``Makefile.in``.
        """
        files = {}
        for dirpath, dirnames, filenames in os.walk(self.top, topdown=True):
            # With topdown=True, removing an entry from dirnames in
            # place keeps os.walk() from descending into it.
            if '.git' in dirnames:
                dirnames.remove('.git')

            for fname in filenames:
                # Ignore a bunch of files
                if fname.endswith('~'):
                    continue
                if fname.startswith('.git'):
                    continue
                if fname in ('Makefile.in', 'Makefile'):
                    continue

                abs_fname = os.path.join(dirpath, fname)
                rel_fname = os.path.relpath(abs_fname, start=self.top)

                files[rel_fname] = File(self, rel_fname)
        return files


########################################################################


class GitTree(BaseTree):
    """A tree whose top directory must contain a ``.git`` entry."""

    def __init__(self, top):
        """Scan ``top``; raise AssertionError if ``top/.git`` is missing."""
        git_entry = os.path.join(top, '.git')
        if os.path.exists(git_entry):
            super().__init__(top)
            return
        raise AssertionError("File or directory does not exist: %s" %
                             repr(git_entry))


########################################################################


class NotGitTree(BaseTree):
    """A tree whose top directory must not contain a ``.git`` entry."""

    def __init__(self, top):
        """Scan ``top``; raise AssertionError if ``top/.git`` exists."""
        git_entry = os.path.join(top, '.git')
        if not os.path.exists(git_entry):
            super().__init__(top)
            return
        raise AssertionError("File or directory does exist: %s" %
                             repr(git_entry))


########################################################################


def scan_tree(top):
    """Yield ``(dirpath, NotGitTree)`` for gphoto-m4 copies below ``top``.

    A directory counts as a gphoto-m4 copy when it is named
    ``gphoto-m4`` and directly contains a file ``gp-camlibs.m4``.
    """
    for dirpath, _subdirs, filenames in os.walk(top):
        named_like_copy = os.path.basename(dirpath) == 'gphoto-m4'
        if named_like_copy and 'gp-camlibs.m4' in filenames:
            yield (dirpath, NotGitTree(dirpath))


########################################################################


def print_help():
    """Print the help text stored in the comment header of this script.

    The script file itself is read line by line.  The ``#!`` line and
    any leading lines consisting only of ``#`` or ``# `` are skipped.
    From the first other line on, every line is printed with its first
    two characters (the ``# `` comment prefix) removed, until the first
    empty line ends the header.
    """
    skip_line = True
    skip_lines = ['#', '# ']
    # The context manager makes sure the script file is closed again.
    with open(__file__, 'r') as script:
        for line in script:
            if line.endswith('\n'):
                line = line[:-1]

            if line.startswith('#!'):
                continue
            if skip_line:
                if line in skip_lines:
                    continue
                # First line with actual content: the help text starts.
                skip_line = False
            elif line == '':
                break

            print(line[2:])


########################################################################


class ResultTable(object):
    """Collection of result lines keyed by file name, with totals.

    Attributes:
      lines                   dict mapping key to the stored value
      files_with_differences  number of stored values having
                              ``file_versions > 0``
      differences             sum of ``file_versions`` over all values
    """

    def __init__(self):
        self.lines = {}
        self.files_with_differences = 0
        self.differences = 0

    def __setitem__(self, key, value):
        """Store ``value`` under the new ``key`` and update the totals.

        ``value`` must have a numeric ``file_versions`` attribute.
        Each key may be set only once.
        """
        assert key not in self.lines
        self.lines[key] = value
        if value.file_versions > 0:
            self.files_with_differences += 1
        self.differences += value.file_versions

    def __getitem__(self, key):
        assert self.files_with_differences is not None
        return self.lines[key]

    def items(self):
        """Yield ``(key, value)`` pairs in sorted key order."""
        for k in sorted(self.lines):
            yield k, self.lines[k]

    def close(self):
        """Do nothing; callers invoke this after adding the last line."""
        pass


########################################################################


class ResultLine(object):
    """Per-file comparison result across all trees.

    For every tree index, a line stores a one character flag, the path
    of the file in that tree, and (if the file exists there) its
    digest.  After close() has been called, get_digest() maps each
    distinct digest to a small letter so that the same letter means
    the same file content.
    """

    def __init__(self, fname):
        self.fname = fname

        self.__digest_map = {}      # tree index -> digest
        self.__digests = set()      # all distinct digests seen
        self.__digest_list = None   # sorted digests, set by close()

        self.__flags = {}           # tree index -> flag character

        self.__fpaths = {}          # tree index -> file path or None

    def set_digest(self, index, digest):
        """Record ``digest`` as the file content digest for tree ``index``."""
        self.__digest_map[index] = digest
        self.__digests.add(digest)

    def close(self, file_versions):
        """Finish the line, storing ``file_versions`` as its diff count.

        At least one digest must have been set before.
        """
        self.file_versions = file_versions
        self.__digest_list = sorted(self.__digests)
        assert len(self.__digest_list) > 0
        if file_versions == 0:
            # All files are equal, so we do not need different characters
            # to distinguish different digest values - a space will do as
            # well.
            self.__digest_map = {}

    def get_digest(self, index):
        """Return the letter identifying the content in tree ``index``.

        Returns a space if no digest is known for that tree, or if the
        line has been closed with ``file_versions == 0``.  Raises
        RuntimeError when called before close(), and IndexError if the
        digest's position among the sorted digests is 26 or higher.
        """
        if self.__digest_list is None:
            raise RuntimeError("You need to call ResultLine.close() before ResultLine.get_digest()")
        if index in self.__digest_map:
            dig = self.__digest_map[index]
            idx = self.__digest_list.index(dig)
            return 'abcdefghijklmnopqrstuvwxyz'[idx]
        else:
            return ' '

    def set_flag(self, index, flag, fpath):
        """Record the flag character and file path for tree ``index``."""
        self.__flags[index] = flag
        self.__fpaths[index] = fpath

    def get_flag(self, index):
        return self.__flags[index]

    def get_fpath(self, index):
        return self.__fpaths[index]


########################################################################


def cmd_print_report(result_table, all_files, treelist, trees):
    """Print the human readable report, then exit the program.

    Parameters:
      result_table  maps each file name to a result line offering
                    get_flag(index), get_digest(index) and a
                    ``file_versions`` attribute; also provides the
                    totals ``differences`` and ``files_with_differences``
      all_files     the file names to list (printed in sorted order)
      treelist      tree names for the header; index 0 is the original
      trees         the compared tree copies; only their number is
                    used, they occupy the column indices 1..len(trees)

    Never returns: calls sys.exit(1) if ``result_table.differences`` is
    greater than zero, and sys.exit(0) otherwise.
    """
    # Enumerate list of trees
    print("Trees (0 is the original tree):")
    for i, tree in enumerate(treelist):
        print("  %d. %s" % (i, tree))
    print()

    # Determine maximum length of file name
    fn_maxlen = max((len(fn) for fn in all_files), default=0)

    fmt = "    %%-%ds " % fn_maxlen
    print("File table:")

    # print table head
    print(fmt % '', end='')
    print((' {0:-^%d}' % (3*len(treelist)-1)).format('Tree'))
    print(fmt % 'file name', end='')
    for i in range(len(treelist)):
        print(' %2d' % i, end='')
    print('  file diffs')
    sep_line = ('    ' +
                '-' * (fn_maxlen + 1 + 3*len(treelist) + 2 + len('file diffs')))
    print(sep_line)

    # print table body: column 0 is the original tree, followed by one
    # column per tree copy
    for fname in sorted(all_files):
        result_line = result_table[fname]
        print(fmt % fname, end='')
        for tree_idx in range(len(trees) + 1):
            print(" %s%s" % (result_line.get_flag(tree_idx),
                             result_line.get_digest(tree_idx)), end='')

        if result_line.file_versions > 0:
            print('   %3d' % result_line.file_versions)
        else:
            print('  ok')
    print(sep_line)
    print()

    # The '<' and '>' flags are computed by comparing modification
    # times: '>' is set for a copy whose mtime is greater than that of
    # the original file (i.e. the copy is younger), '<' for a copy
    # whose mtime is smaller (i.e. the copy is older).
    print("Legend:")
    legend = [
        ('N', 'new file'),
        ('O', 'original file'),
        ('/', 'no such file'),
        ('=', 'same content as the original file'),
        ('<', 'file with different content is older than original file'),
        ('>', 'file with different content is younger than original file'),
    ]
    for ch, descr in legend:
        print("    %s  %s" % (ch, descr))
    print("    ")
    print("    Small letters identify file contents: Same letter means same content.")
    print()

    # Determine exit code
    exit_code = 0
    if result_table.differences > 0:
        exit_code = 1

    # Print summary
    print("Summary:")
    if result_table.differences > 0:
        print("  About %d difference(s) found in %d file(s)." %
              (result_table.differences, result_table.files_with_differences))
        print("  ")
        print("  Diff commands for comparing differing files can be obtained with the")
        print("  '--diff' option.")
    else:
        print("  All gphoto-m4 trees are equal.")

    # Finally exit.
    sys.exit(exit_code)


########################################################################


def _shell_dquote_escape(text):
    """Escape ``text`` for use inside a double-quoted POSIX shell word."""
    # The backslash must be handled first so that the backslashes added
    # for the other characters are not escaped a second time.
    for ch in ('\\', '"', '$', '`'):
        text = text.replace(ch, '\\' + ch)
    return text


def print_diff_commands(diff_commands):
    """Print a shell script running one ``diff -u`` per given command.

    ``diff_commands`` is an iterable of tuples
    ``(fname, orig_dig, other_dig, orig_fpath, other_fpath)``.  A false
    ``orig_fpath`` makes the comparison use ``/dev/null`` as the
    original; a false ``other_dig`` makes the label of the other file
    consist of its path only.

    The printed script is meant to be piped into a shell, so file
    names and paths are quoted for the shell.  Paths consisting only of
    characters which are safe in a shell are printed unchanged.
    """
    print("#!/bin/sh")
    print("#")
    print("# This file has been autogenerated by %s" % __file__)
    print("#")
    print("# List of diff commands.  You can pipe these into")
    print("#     | sh | colordiff | less -r '+/comparing '")
    print("# or")
    print("#     | sh | less '+/^comparing '")
    print("# or")
    print("#     | less")
    for fname, orig_dig, other_dig, orig_fpath, other_fpath in diff_commands:
        if orig_fpath:
            orig_label = "%s (digest '%s')" % (orig_fpath, orig_dig)
        else:
            orig_fpath = '/dev/null'
            orig_label = '(no such file)'

        if other_dig:
            other_label = "%s (digest '%s')" % (other_fpath, other_dig)
        else:
            other_label = other_fpath

        print()
        # The message always contains spaces, so shlex.quote() always
        # wraps it in single quotes.
        print("echo %s" % shlex.quote('comparing fname %s' % fname))
        print("""diff -u --label "%s" %s --label "%s" %s"""
              % (_shell_dquote_escape(orig_label), shlex.quote(orig_fpath),
                 _shell_dquote_escape(other_label), shlex.quote(other_fpath)))


########################################################################


def gphoto_m4_sync(dir_list, print_diffs):
    """Compare the original gphoto-m4 tree against all copies found.

    The original tree is the directory containing this script; it must
    be a git checkout.  Tree copies are searched for below every
    directory in ``dir_list``.

    Parameters:
      dir_list     directories to search for gphoto-m4 tree copies
      print_diffs  if true, print a shell script of diff commands;
                   otherwise print the human readable report

    This function does not return normally: it exits with status 2 if
    no tree copy has been found, leaves the exit status to
    cmd_print_report() in report mode, and exits with status 0 after
    printing the diff commands.
    """

    # List all files in this clone of the `gphoto-m4` repository
    orig_top = os.path.dirname(os.path.abspath(__file__))
    orig_tree = GitTree(orig_top)

    # For each `gphoto-m4` directory given on the command line, find
    # all files.
    trees = {}
    for top in dir_list:
        for dirpath, tree in scan_tree(os.path.abspath(top)):
            trees[dirpath] = tree

    if not trees:
        print("No gphoto-m4 trees found in directories given on command line.")
        sys.exit(2)

    # The sorted tree locations define the tree indices 1..len(trees);
    # index 0 is the original tree.
    tree_tops = sorted(trees)

    # Make a list of all files within all `gphoto-m4` trees
    all_files = set(orig_tree)
    for tree in trees.values():
        all_files.update(tree)
    all_files = sorted(all_files)

    # calculate table values
    result_table = ResultTable()
    for fname in all_files:
        result_line = ResultLine(fname)
        file_diffs = 0
        if fname in orig_tree:
            result_line.set_flag(0, 'O', orig_tree[fname].fpath)
            orig_dig = orig_tree[fname].digest
            result_line.set_digest(0, orig_dig)
        else:
            result_line.set_flag(0, '/', None)
            orig_dig = None

        for tree_idx, tree_top in enumerate(tree_tops, start=1):
            tree = trees[tree_top]
            if fname in tree:
                other = tree[fname]
                fpath = other.fpath
                # The flag stays 'N' if the original tree lacks the
                # file, and also if the contents differ while both
                # modification times are equal.
                flag = 'N'
                if orig_dig == other.digest:
                    flag = '='
                elif orig_dig:
                    orig_mtime = orig_tree[fname].statinfo.st_mtime
                    other_mtime = other.statinfo.st_mtime
                    if other_mtime > orig_mtime:
                        flag = '>'
                    elif other_mtime < orig_mtime:
                        flag = '<'
                result_line.set_digest(tree_idx, other.digest)
            else:
                flag = '/'
                fpath = None
            result_line.set_flag(tree_idx, flag, fpath)

            # With an original file, everything but an identical copy
            # counts as a difference (including a missing file).
            # Without an original file, every existing copy counts.
            if orig_dig:
                if flag != '=':
                    file_diffs += 1
            elif flag != '/':
                file_diffs += 1

        result_line.close(file_diffs)
        result_table[fname] = result_line
    result_table.close()

    # Print report
    if not print_diffs:
        cmd_print_report(result_table, all_files,
                         [orig_top] + tree_tops,
                         trees)
        return

    # Print diffs: calculate a small set of diff commands
    all_trees = [orig_tree] + [trees[k] for k in tree_tops]
    diff_commands = []
    for fname, result_line in result_table.items():
        if not result_line.file_versions > 0:
            continue
        line_flags = set()
        for idx_a, tree_a in enumerate(all_trees):
            dig_a = result_line.get_digest(idx_a)
            if dig_a == ' ':
                continue
            for idx_b, tree_b in enumerate(all_trees):
                dig_b = result_line.get_digest(idx_b)
                if dig_b == dig_a or dig_b == ' ':
                    continue
                # Emit only one diff command per pair of file contents
                flag = (fname, dig_a, dig_b)
                rev_flag = (fname, dig_b, dig_a)
                if flag in line_flags or rev_flag in line_flags:
                    continue
                line_flags.add(flag)
                diff_commands.append((fname, dig_a, dig_b,
                                      tree_a[fname].fpath,
                                      tree_b[fname].fpath))
            # Only the first tree which has the file is used as the
            # left hand side of the comparisons.
            break

    print_diff_commands(diff_commands)
    sys.exit(0)


#######################################################################


def main(args):
    """Parse the command line arguments ``args`` and run the program.

    ``args`` is the argument list without the program name.  Options
    are recognized up to the first non-option argument or up to a
    ``--`` argument; all remaining arguments are taken as directories.

    Without any arguments, or with ``--help``, the help text is
    printed and the program exits with status 0.  An unknown option is
    reported on stderr and the program exits with status 2, the exit
    code the script header documents for errors.
    """
    if not args:
        print_help()
        sys.exit(0)

    arg_diff = False
    # Index of the first directory argument.  It stays at len(args) if
    # the arguments consist of options only, so that an option is never
    # mistaken for a directory.
    first_dir = len(args)
    for i, arg in enumerate(args):
        if arg == '--':
            first_dir = i + 1
            break
        elif arg == '--help':
            print_help()
            sys.exit(0)
        elif arg == '--diff':
            arg_diff = True
        elif arg.startswith('--'):
            print("Unhandled command line option '%s'" % arg,
                  file=sys.stderr)
            sys.exit(2)
        else:
            first_dir = i
            break

    dir_list = args[first_dir:]

    gphoto_m4_sync(dir_list, arg_diff)


########################################################################


# Run the command line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    cli_args = sys.argv[1:]  # everything after the program name
    main(cli_args)


########################################################################
