#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
@author: Tomas Hozza <thozza@redhat.com>
@author: Pavel Šimerda <psimerda@redhat.com>
"""

from gi.repository import NMClient
import os, sys, shutil, glob, subprocess
import logging, logging.handlers
import socket, struct
import lockfile

# Shared sink for the output of helper commands whose output is not inspected.
DEVNULL = open("/dev/null", "wb")

# All messages go through the root logger, which reports to syslog as well
# as to a stream handler.
log = logging.getLogger()
log.setLevel(logging.INFO)
for _handler in (logging.handlers.SysLogHandler(), logging.StreamHandler()):
    log.addHandler(_handler)

# NetworkManager reportedly doesn't pass the PATH environment variable.
os.environ['PATH'] = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"

class UserError(Exception):
    """Error whose message is meant to be shown to the user as-is."""

class Config:
    """Global configuration options.

    Options are read from the ``option = value`` lines of the file at
    ``path``. An option is switched on only by the exact value ``yes``; any
    other value switches it off. Unknown options and lines without ``=`` are
    ignored, and a missing or unreadable file leaves the defaults in place.
    """

    path = "/etc/dnssec-trigger/dnssec.conf"
    # The class attributes below are the defaults; __init__ overrides them
    # per instance only for options that appear in the configuration file.
    validate_connection_provided_zones = True
    add_wifi_provided_zones = False

    def __init__(self):
        try:
            with open(self.path) as config_file:
                for line in config_file:
                    if '=' in line:
                        option, value = [part.strip() for part in line.split("=", 1)]
                        if option == "validate_connection_provided_zones":
                            self.validate_connection_provided_zones = (value == "yes")
                        elif option == "add_wifi_provided_zones":
                            self.add_wifi_provided_zones = (value == "yes")
        except IOError:
            # No configuration file: keep the defaults.
            pass
        log.debug(self)

    def __repr__(self):
        # Read the attributes explicitly: vars(self) only contains options
        # that were set on the instance, so defaults would be missing from it.
        return "<Config validate_connection_provided_zones={validate_connection_provided_zones} add_wifi_provided_zones={add_wifi_provided_zones}>".format(
            validate_connection_provided_zones=self.validate_connection_provided_zones,
            add_wifi_provided_zones=self.add_wifi_provided_zones)

class ConnectionList:
    """Filtered, iterable view of NetworkManager's active connections."""

    # Raw active connections, fetched once and shared by every instance.
    nm_connections = None

    def __init__(self, client, only_default=False, skip_wifi=False):
        """Remember the filters and fetch the connection list if needed.

        Raises UserError when the client reports that NetworkManager is not
        running.
        """
        if not client.get_manager_running():
            raise UserError("NetworkManager is not running.")
        # Store the list on the class so that later instances reuse it.
        if self.nm_connections is None:
            self.__class__.nm_connections = client.get_active_connections()
        self.skip_wifi = skip_wifi
        self.only_default = only_default
        log.debug(self)

    def __repr__(self):
        return "<ConnectionList(only_default={only_default}, skip_wifi={skip_wifi}, connections={})>".format(
            list(self), only_default=self.only_default, skip_wifi=self.skip_wifi)

    def __iter__(self):
        """Yield a Connection for every active connection passing the filters."""
        for active in self.nm_connections:
            candidate = Connection(active)
            # Connections marked as ignored or without servers are never used.
            if candidate.ignore or not candidate.servers:
                continue
            # Optional filters selected by the caller.
            if self.skip_wifi and candidate.is_wifi:
                continue
            if self.only_default and not candidate.is_default:
                continue
            yield candidate

    def get_zone_connection_mapping(self):
        """Return a dict mapping each zone to its preferred connection."""
        mapping = {}
        for candidate in self:
            for zone in candidate.zones:
                if zone not in mapping:
                    mapping[zone] = candidate
                else:
                    mapping[zone] = self.get_preferred_connection(mapping[zone], candidate)
        return mapping

    @staticmethod
    def get_preferred_connection(first, second):
        """Pick one of two connections: VPN wins, then default, then *first*."""
        for trait in ("is_vpn", "is_default"):
            in_first = getattr(first, trait)
            in_second = getattr(second, trait)
            if in_second and not in_first:
                return second
            if in_first and not in_second:
                return first
        # Equal on both criteria: keep the connection seen first.
        return first

class Connection:
    """Snapshot of the data we need from a NetworkManager active connection."""

    def __init__(self, connection):
        """Extract type, default flag, UUID, zones and servers from *connection*."""
        devices = connection.get_devices()

        # An object offering get_vpn_state is treated as a VPN connection;
        # otherwise the first device decides, and no device means "ignore".
        if 'get_vpn_state' in dir(connection):
            kind = "vpn"
        elif not devices:
            kind = "ignore"
        else:
            device_type = devices[0].get_device_type().value_name
            kind = "wifi" if device_type == "NM_DEVICE_TYPE_WIFI" else "other"
        self.type = kind

        self.is_default = bool(connection.get_default() or connection.get_default6())
        self.uuid = connection.get_uuid()

        # A missing IP configuration shows up as AttributeError and simply
        # contributes nothing.
        self.zones = []
        for getter in ("get_ip4_config", "get_ip6_config"):
            try:
                self.zones += getattr(connection, getter)().get_domains()
            except AttributeError:
                pass

        self.servers = []
        try:
            ip4_servers = connection.get_ip4_config().get_nameservers()
            self.servers.extend([self.ip4_to_str(number) for number in ip4_servers])
        except AttributeError:
            pass
        try:
            ip6_config = connection.get_ip6_config()
            count = ip6_config.get_num_nameservers()
            self.servers.extend([self.ip6_to_str(ip6_config.get_nameserver(index))
                    for index in range(count)])
        except AttributeError:
            pass

    def __repr__(self):
        return "<Connection(uuid={uuid}, type={type}, default={is_default}, zones={zones}, servers={servers})>".format(
            uuid=self.uuid, type=self.type, is_default=self.is_default,
            zones=self.zones, servers=self.servers)

    @staticmethod
    def ip4_to_str(ip4):
        """Return the dotted-quad text for an IPv4 address given as an integer.

        The integer is packed in native byte order before conversion.
        """
        packed = struct.pack("=I", ip4)
        return socket.inet_ntop(socket.AF_INET, packed)

    @staticmethod
    def ip6_to_str(ip6):
        """Return the text form of a packed IPv6 address.

        *ip6* is handed to socket.inet_ntop unchanged.
        """
        return socket.inet_ntop(socket.AF_INET6, ip6)

    @property
    def ignore(self):
        """True when the connection has no usable device."""
        return "ignore" == self.type

    @property
    def is_vpn(self):
        """True for VPN connections."""
        return "vpn" == self.type

    @property
    def is_wifi(self):
        """True when the first device is a WiFi device."""
        return "wifi" == self.type

class UnboundZoneConfig:
    """A dictionary-like proxy object for Unbound's forward zone configuration.

    ``cache`` maps a zone name (as printed by ``unbound-control
    list_forwards``, minus its last character) to a ``(servers, secure)``
    tuple, where ``servers`` is a set of strings and ``secure`` is False for
    zones installed with the ``+i`` (insecure) flag.
    """

    def __init__(self):
        """Load the current forward zones from the running Unbound.

        Raises subprocess.CalledProcessError when an unbound-control call
        fails.
        """
        # Probe first so that an unreachable daemon fails on the "status"
        # command rather than on "list_forwards".
        subprocess.check_call(["unbound-control", "status"], stdout=DEVNULL, stderr=DEVNULL)
        self.cache = {}
        for line in subprocess.check_output(["unbound-control", "list_forwards"]).decode().split('\n'):
            if line:
                fields = line.split(" ")
                # Drop the last character of the name (the trailing dot).
                name = fields.pop(0)[:-1]
                if fields[0] == 'IN':
                    fields.pop(0)
                # NOTE(review): when the keyword matches, the field that
                # follows it is discarded as well. Confirm against the output
                # of the unbound-control versions in use that this is not the
                # first server or the '+i' flag.
                if fields.pop(0) in ('forward', 'forward:'):
                    fields.pop(0)
                # '+i' is the flag _commit() passes for insecure zones, so
                # seeing it means the zone is NOT validated.
                secure = True
                if fields and fields[0] == '+i':
                    secure = False
                    fields.pop(0)
                self.cache[name] = set(fields), secure
        log.debug(self)

    def __repr__(self):
        return "<UnboundZoneConfig(data={cache})>".format(**vars(self))

    def __iter__(self):
        """Iterate over the names of the known forward zones."""
        return iter(self.cache)

    def add(self, zone, servers, secure):
        """Install a forward zone into Unbound.

        *servers* is an iterable of server address strings; *secure* selects
        whether the zone is validated (True) or installed with ``+i``.
        """

        self._commit(zone, set(servers), secure)

        log.info("Connection provided zone '{}' ({}): {}".format(
            zone, "validated" if secure else "insecure", ', '.join(servers)))

    def remove(self, zone):
        """Remove a forward zone from Unbound."""

        self._commit(zone, None, None)

    def _commit(self, name, servers, secure):
        """Apply the wanted state of zone *name* to Unbound and the cache.

        A falsy *servers* removes the zone. Nothing is done when the cached
        server set already equals *servers*, which also covers removing a
        zone that is not cached.
        """
        # Check the list of servers.
        #
        # Older versions of unbound don't print +i and so we can't distinguish
        # secure and insecure zones properly. Therefore we need to ignore the
        # insecure flag which leads to not being able to switch the zone
        # between secure and insecure unless it's removed or its servers change.
        if self.cache.get(name, [None])[0] == servers:
            log.debug("Connection provided zone '{}' already set to {} ({})".format(name, servers, 'secure' if secure else 'insecure'))
            return

        if servers:
            self.cache[name] = servers, secure
            self._control(["forward_add"] + ([] if secure else ["+i"]) + [name] + list(servers))
        else:
            del self.cache[name]
            self._control(["forward_remove", name])
        # Drop cached answers and pending requests so the change takes effect.
        self._control(["flush_zone", name])
        self._control(["flush_requestlist"])

        log.debug(self)

    def _control(self, args):
        """Run ``unbound-control`` with *args*, discarding its output."""
        subprocess.check_call(["unbound-control"] + args, stdout=DEVNULL, stderr=DEVNULL)

class Store:
    """A proxy object to access stored zones or global servers.

    The items are kept in ``cache`` (a set of strings) and persisted one per
    line in the file ``/run/dnssec-trigger/<name>``. Changes only reach the
    disk when commit() is called.
    """

    def __init__(self, name):
        """Load the store called *name*; a missing file gives an empty store."""
        self.name = name
        self.cache = set()
        self.path = os.path.join("/run/dnssec-trigger", name)
        self.path_tmp = self.path + ".tmp"

        try:
            with open(self.path) as zone_file:
                for line in zone_file:
                    line = line.strip()
                    if line:
                        self.cache.add(line)
        except IOError:
            pass
        log.debug(self)

    def __repr__(self):
        return "<Store(name={name}, {cache})>".format(**vars(self))

    def __iter__(self):
        # Don't return the set itself, as it doesn't support update during
        # iteration.
        return iter(list(self.cache))

    def __bool__(self):
        """A store is true when it holds at least one item."""
        return bool(self.cache)

    # Python 2 looks up __nonzero__ instead of __bool__; without this alias
    # every store would be considered true there.
    __nonzero__ = __bool__

    def add(self, zone):
        """Add zone to the cache."""

        self.cache.add(zone)
        log.debug(self)

    def update(self, zones):
        """Commit a new set of items and return True when it differs"""

        zones = set(zones)

        if zones != self.cache:
            self.cache = zones
            log.debug(self)
            return True

        return False

    def remove(self, zone):
        """Remove zone from the cache.

        Raises KeyError when the zone is not stored.
        """

        self.cache.remove(zone)
        log.debug(self)

    def commit(self):
        """Write data back to disk."""

        # We don't use os.makedirs(..., exist_ok=True) to ensure Python 2 compatibility
        dirname = os.path.dirname(self.path_tmp)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        # Write a temporary file first and rename it over the store, so the
        # store file itself is never left half written.
        with open(self.path_tmp, "w") as zone_file:
            for zone in self.cache:
                zone_file.write("{}\n".format(zone))
        os.rename(self.path_tmp, self.path)

class GlobalForwarders:
    """Set of entries read from a one-item-per-line file."""

    def __init__(self, path=None):
        """Read the non-blank, stripped lines of the file into ``cache``.

        *path* names the file to read. When it is omitted, ``self.path`` must
        already be provided (for example by a subclass); this class does not
        define it itself, so the lookup raises AttributeError otherwise.
        A file that cannot be opened leaves the cache empty.
        """
        if path is not None:
            self.path = path
        self.cache = set()
        try:
            with open(self.path) as zone_file:
                for line in zone_file:
                    line = line.strip()
                    if line:
                        self.cache.add(line)
        except IOError:
            pass

class Application:
    """Command-line driver of dnssec-trigger-script.

    The action option given on the command line (for example ``--update``)
    selects the ``run_*`` method of the same name, which run() executes while
    holding a lock file.
    """

    def __init__(self, argv):
        """Parse *argv* and create the configuration and NetworkManager client.

        ``--debug`` and ``--async`` are optional, are recognised only in that
        order directly after the program name, and are removed from *argv* in
        place. Raises UserError (via usage()) when no known action follows.
        """
        if len(argv) > 1 and argv[1] == '--debug':
            argv.pop(1)
            log.setLevel(logging.DEBUG)
        if len(argv) > 1 and argv[1] == '--async':
            argv.pop(1)
            # The parent exits immediately; the forked child does the work.
            if os.fork():
                sys.exit()
        if len(argv) < 2 or not argv[1].startswith('--'):
            self.usage()
        try:
            # "--update-global-forwarders" -> self.run_update_global_forwarders
            self.method = getattr(self, "run_" + argv[1][2:].replace('-', '_'))
        except AttributeError:
            self.usage()
        self.config = Config()
        self.client = NMClient.Client()

        self.resolvconf = "/etc/resolv.conf"
        self.resolvconf_backup = "/run/dnssec-trigger/resolv.conf.bak"

    def nm_handles_resolv_conf(self):
        """Return True when NetworkManager is expected to write resolv.conf.

        That is the case when NetworkManager is running and its configuration
        file has no line reading exactly ``dns=none`` or ``dns=unbound``.
        """
        if not self.client.get_manager_running():
            log.debug("NetworkManager is not running")
            return False
        try:
            with open("/etc/NetworkManager/NetworkManager.conf") as nm_config_file:
                for line in nm_config_file:
                    if line.strip() in ("dns=none", "dns=unbound"):
                        log.debug("NetworkManager doesn't handle /etc/resolv.conf")
                        return False
        except IOError:
            # An unreadable configuration file is treated like one without
            # a dns= override.
            pass
        log.debug("NetworkManager handles /etc/resolv.conf")
        return True

    def usage(self):
        """Raise UserError carrying the usage message."""
        raise UserError("Usage: dnssec-trigger-script [--debug] [--async] --prepare|--update|--update-global-forwarders|--update-connection-zones|--cleanup")

    def run(self):
        """Run the selected action while holding the script's lock file."""
        with lockfile.FileLock("/run/dnssec-trigger/dnssec-trigger"):
            log.debug("Running: {}".format(self.method.__name__))
            self.method()

    def run_prepare(self):
        """Prepare for dnssec-trigger."""

        if not self.nm_handles_resolv_conf():
            log.info("Backing up /etc/resolv.conf")
            shutil.copy(self.resolvconf, self.resolvconf_backup)

    def run_cleanup(self):
        """Clean up after dnssec-trigger."""

        stored_zones = Store('zones')
        unbound_zones = UnboundZoneConfig()

        # provide upgrade path for previous versions
        old_zones = glob.glob("/run/dnssec-trigger/????????-????-????-????-????????????")
        if old_zones:
            log.info("Reading zones from the legacy zone store")
            for filename in old_zones:
                with open(filename) as source:
                    log.debug("Reading zones from {}".format(filename))
                    for line in source:
                        zone = line.strip()
                        # Skip blank lines: an empty name is the key under
                        # which UnboundZoneConfig stores the root zone ".".
                        if zone:
                            stored_zones.add(zone)
                    os.remove(filename)

        log.debug("clearing unbound configuration")
        for zone in stored_zones:
            unbound_zones.remove(zone)
            stored_zones.remove(zone)
        stored_zones.commit()

        log.debug("recovering /etc/resolv.conf")
        # Clear the immutable flag so the file can be rewritten.
        subprocess.check_call(["chattr", "-i", self.resolvconf])
        if not self.nm_handles_resolv_conf():
            try:
                shutil.copy(self.resolvconf_backup, self.resolvconf)
            except IOError as error:
                log.warning("Cannot restore resolv.conf from {!r}: {}".format(self.resolvconf_backup, error.strerror))
        # NetworkManager currently doesn't support explicit /etc/resolv.conf
        # write out. For now we simply restart the daemon.
        elif os.path.exists("/sys/fs/cgroup/systemd"):
            subprocess.check_call(["systemctl", "--ignore-dependencies", "try-restart", "NetworkManager.service"])
        else:
            subprocess.check_call(["/etc/init.d/NetworkManager", "restart"])

    def run_update(self):
        """Update the global forwarders, then the connection provided zones."""
        self.run_update_global_forwarders()
        self.run_update_connection_zones()

    def run_update_global_forwarders(self):
        """Configure global forwarders using dnssec-trigger-control."""

        # Fail early when the dnssec-trigger daemon cannot be reached.
        subprocess.check_call(["dnssec-trigger-control", "status"], stdout=DEVNULL, stderr=DEVNULL)

        default_connections = ConnectionList(self.client, only_default=True)
        servers = Store('servers')

        # Only resubmit when the server set differs from the stored one.
        if servers.update(sum((connection.servers for connection in default_connections), [])):
            subprocess.check_call(["unbound-control", "flush_zone", "."])
            subprocess.check_call(["dnssec-trigger-control", "submit"] + list(servers))
            servers.commit()
        log.info("Global forwarders: {}".format(' '.join(servers)))

    def run_update_connection_zones(self):
        """Configures forward zones in the unbound using unbound-control."""

        connections = ConnectionList(self.client, skip_wifi=not self.config.add_wifi_provided_zones).get_zone_connection_mapping()
        unbound_zones = UnboundZoneConfig()
        stored_zones = Store('zones')

        # The purpose of the zone store is to keep the list of Unbound zones
        # that are managed by dnssec-trigger-script. We don't want to track
        # zones across Unbound restarts. We want to clear any Unbound zones
        # that are no longer active in NetworkManager.
        log.debug("removing stored zones not present in both unbound and an active connection")
        for zone in stored_zones:
            if zone not in unbound_zones:
                stored_zones.remove(zone)
            elif zone not in connections:
                unbound_zones.remove(zone)
                stored_zones.remove(zone)

        # We need to install zones that are not yet in Unbound. We also need to
        # reinstall zones that are already managed by dnssec-trigger in case their
        # list of nameservers was changed.
        #
        # TODO: In some cases, we don't seem to flush Unbound cache properly,
        # even when Unbound is restarted (and dnssec-trigger as well, because
        # of dependency).
        log.debug("installing connection provided zones")
        for zone in connections:
            # Zones present in Unbound but not in our store are not managed
            # by this script and are left alone.
            if zone in stored_zones or zone not in unbound_zones:
                unbound_zones.add(zone, connections[zone].servers, secure=self.config.validate_connection_provided_zones)
                stored_zones.add(zone)

        stored_zones.commit()

if __name__ == "__main__":
    try:
        Application(sys.argv).run()
    except UserError as error:
        # Errors meant for the user are reported without a traceback.
        log.error(error)
        sys.exit(1)
    except subprocess.CalledProcessError as error:
        # A failed "<daemon>-control status" probe means the daemon cannot be
        # reached; cmd[0][:-8] strips the "-control" suffix to name it.
        if len(error.cmd) == 2 and error.cmd[0].endswith('-control') and error.cmd[1] == "status":
            log.error("Cannot connect to {}.".format(error.cmd[0][:-8]))
            sys.exit(1)
        else:
            raise
