#!/usr/bin/env python3

# This file is part of GNU Taler
# (C) 2024 Taler Systems S.A.
#
# GNU Taler is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation; either version 3, or (at your option) any later version.
#
# GNU Taler is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# GNU Taler; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>

import sqlite3
import sys
import os

v = sys.version_info
if v.major < 3 or (v.major == 3 and v.minor < 11):
    print(
        "FATAL: python version >=3.11 required but running on",
        sys.version,
        file=sys.stderr,
    )
    sys.exit(1)

print("started sqlite3 helper at", os.getcwd(), file=sys.stderr)

# Set to True to get a diagnostic line on stderr for every trace() call.
enable_tracing = False


def trace(*args):
    """Print *args* to stderr, prefixed with "HELPER", if tracing is enabled."""
    if enable_tracing:
        print("HELPER", *args, file=sys.stderr)


# Command codes: the one-byte opcode that the main loop reads from each
# request and dispatches on.
CMD_HELLO = 1
CMD_SHUTDOWN = 2
CMD_OPEN = 3
CMD_CLOSE = 4
CMD_PREPARE = 5
CMD_STMT_GET_ALL = 6
CMD_STMT_GET_FIRST = 7
CMD_STMT_RUN = 8
CMD_EXEC = 9

# Response codes: passed to write_resp(), which sends them as the one-byte
# type of a reply.
RESP_OK = 1
RESP_FAIL = 2
RESP_ROWLIST = 3
RESP_RUNRESULT = 4
RESP_STMT = 5


# Type tags: a one-byte marker that precedes each serialized value, both in
# statement parameters (requests) and in result rows (responses).
TAG_NULL = 1
TAG_INT = 2
TAG_REAL = 3
TAG_TEXT = 4
TAG_BLOB = 5

# Requests are read from file descriptor 0 and responses are written to file
# descriptor 1, both reopened in binary mode because the protocol is binary.
cmdstream = open(0, "rb")
respstream = open(1, "wb")

# Database handle number (taken from the CMD_OPEN request) -> connection.
db_handles = {}

# Statement handle number -> (connection, SQL text).
# Python's sqlite3 library offers no prepared-statement objects, so a
# "prepared statement" is emulated by remembering the statement string.
# The sqlite3 library caches compiled statements internally on its own.
prep_handles = {}


def write_resp(req_id, cmd, payload=None):
    trace("sending response to request", req_id)
    outlen = 4 + 4 + 1 + (0 if payload is None else len(payload))
    respstream.write(outlen.to_bytes(4))
    respstream.write(req_id.to_bytes(4))
    respstream.write(cmd.to_bytes(1))
    if payload is not None:
        respstream.write(payload)
    respstream.flush()


# Database connection assigned by the CMD_OPEN handler; None until a
# database has been opened.
dbconn = None


class PacketWriter:
    """Builds the binary payload of a response piece by piece.

    Each write_* method appends one encoded chunk; reap() returns the
    concatenation of all chunks.  Integers are encoded big-endian; strings
    and blobs are preceded by their byte length as a uint32.
    """

    def __init__(self):
        # Encoded fragments, joined once in reap().
        self.chunks = []

    def write_string(self, s):
        """Append *s* encoded as UTF-8, preceded by its byte length."""
        self.write_bytes(s.encode("utf-8"))

    def write_bytes(self, buf):
        """Append *buf*, preceded by its length as a uint32."""
        self.write_uint32(len(buf))
        self.write_raw_bytes(buf)

    def write_raw_bytes(self, buf):
        """Append *buf* verbatim, without a length prefix."""
        self.chunks.append(buf)

    def write_uint8(self, n):
        """Append *n* as one unsigned byte."""
        self.chunks.append(n.to_bytes(1, "big"))

    def write_uint32(self, n):
        """Append *n* as a 4-byte unsigned integer."""
        self.chunks.append(n.to_bytes(4, "big"))

    def write_uint16(self, n):
        """Append *n* as a 2-byte unsigned integer."""
        self.chunks.append(n.to_bytes(2, "big"))

    def write_int64(self, n):
        """Append *n* as an 8-byte signed integer."""
        self.chunks.append(n.to_bytes(8, "big", signed=True))

    def write_rowlist(self, rows, description):
        """Append a result set.

        Layout: row count (uint16), column count (uint16), one string per
        column name, then for every row one tagged value per column.

        Args:
            rows: sequence of row tuples as returned by a sqlite3 cursor.
            description: the cursor's ``description``; only element 0 (the
                column name) of each entry is used.  ``None``, which the
                sqlite3 module reports for statements that produce no
                result set, is treated as "no columns".

        Raises:
            Exception: if a row's length differs from the column count or a
                value has a type other than None, str, bytes or int.
            OverflowError: if a count or integer does not fit its field.
        """
        columns = () if description is None else description
        self.write_uint16(len(rows))
        self.write_uint16(len(columns))
        for desc in columns:
            self.write_string(desc[0])

        if not columns or not rows:
            return

        for row in rows:
            if len(row) != len(columns):
                raise Exception("invariant violated")
            for val in row:
                self._write_value(val)

    def _write_value(self, val):
        """Append one column value as a type tag followed by its encoding."""
        # NOTE(review): float values (SQLite REAL) are rejected here although
        # a TAG_REAL constant exists; no wire encoding for them is defined in
        # this file, so none is invented.
        if val is None:
            self.write_uint8(TAG_NULL)
        elif isinstance(val, str):
            self.write_uint8(TAG_TEXT)
            self.write_string(val)
        elif isinstance(val, bytes):
            self.write_uint8(TAG_BLOB)
            self.write_bytes(val)
        elif isinstance(val, int):
            self.write_uint8(TAG_INT)
            self.write_int64(val)
        else:
            raise Exception("unknown col type")

    def reap(self):
        """Return everything written so far as a single bytes object."""
        return b"".join(self.chunks)


class PacketReader:
    """Sequential decoder for the payload of one request.

    ``pos`` is the offset of the next unread byte in ``data``.  Integers are
    decoded big-endian; strings and blobs carry a uint32 length prefix.

    NOTE: slicing past the end of ``data`` does not raise, so on a truncated
    packet the slice-based readers return shortened values instead of
    failing; only read_uint8() raises IndexError.
    """

    def __init__(self, data):
        self.data = data
        self.pos = 0

    def _take(self, n):
        """Return the next *n* bytes (fewer at the end) and advance ``pos``."""
        chunk = self.data[self.pos : self.pos + n]
        self.pos += n
        return chunk

    def read_string(self):
        """Read a length-prefixed UTF-8 string."""
        return self.read_blob().decode("utf-8")

    def read_blob(self):
        """Read a length-prefixed byte string."""
        length = self.read_uint32()
        return self._take(length)

    def read_uint16(self):
        """Read a 2-byte unsigned integer."""
        return int.from_bytes(self._take(2), "big")

    def read_uint32(self):
        """Read a 4-byte unsigned integer."""
        return int.from_bytes(self._take(4), "big")

    def read_int64(self):
        """Read an 8-byte signed integer."""
        return int.from_bytes(self._take(8), "big", signed=True)

    def read_uint8(self):
        """Read one unsigned byte; raises IndexError at the end of the data."""
        d = self.data[self.pos]
        self.pos += 1
        return d

    def read_params(self):
        """Decode a named-parameter list into a dict.

        Layout: parameter count (uint16), then per parameter its name
        (string), a type tag (uint8) and the tagged value (nothing follows a
        NULL tag).

        Returns:
            dict mapping parameter name to None, int, str or bytes.

        Raises:
            Exception: on a tag other than NULL, INT, TEXT or BLOB.
        """
        # Read from this reader itself, not from a module-level reader, so
        # the method works on any PacketReader instance.
        num_args = self.read_uint16()
        params = {}
        for _ in range(num_args):
            name = self.read_string()
            tag = self.read_uint8()
            if tag == TAG_NULL:
                params[name] = None
            elif tag == TAG_INT:
                params[name] = self.read_int64()
            elif tag == TAG_TEXT:
                params[name] = self.read_string()
            elif tag == TAG_BLOB:
                params[name] = self.read_blob()
            else:
                raise Exception("tag not understood")
        return params


def read_exactly(n):
    """Read exactly *n* bytes from the command stream.

    Raises:
        Exception: if the stream delivers a different number of bytes.
    """
    data = cmdstream.read(n)
    if len(data) == n:
        return data
    raise Exception("incomplete message")


def handle_query_failure(req_id, e):
    """Report a failed database operation to the peer as a RESP_FAIL reply.

    The payload is the error message (string), the SQLite error code
    (uint16) and the SQLite error name (string).

    Args:
        req_id: id of the request that failed.
        e: the sqlite3.Error that was caught.
    """
    # sqlite_errorcode / sqlite_errorname are only present on exceptions that
    # originate from the SQLite library.  Errors raised by the Python sqlite3
    # module itself (for example a ProgrammingError about bad bindings) do
    # not carry them, so fall back to the generic SQLITE_ERROR (code 1)
    # instead of failing with AttributeError.
    code = getattr(e, "sqlite_errorcode", 1)
    name = getattr(e, "sqlite_errorname", "SQLITE_ERROR")
    pw = PacketWriter()
    pw.write_string(str(e))
    pw.write_uint16(code)
    pw.write_string(name)
    write_resp(req_id, RESP_FAIL, pw.reap())


# Main request loop.
#
# A request frame is: total size (uint32, counting the size field and the
# request id), request id (uint32), command code (uint8) and a
# command-specific payload.  Every command handled below answers with one
# write_resp() call that echoes the request id.
while True:
    trace("reading command")
    buf_sz = cmdstream.read(4)
    if len(buf_sz) == 0:
        # Clean end of input between two frames: the peer is done.
        trace("end of input reached")
        sys.exit(0)
    elif len(buf_sz) != 4:
        raise Exception("incomplete message")
    size = int.from_bytes(buf_sz, "big")
    if size < 9:
        # A frame holds at least size, request id and command byte.  Reject
        # it here so a negative count never reaches read(), which would
        # block until end of input.
        raise Exception("malformed message: frame too short")
    req_id = int.from_bytes(read_exactly(4), "big")
    rest = read_exactly(size - 8)
    pr = PacketReader(rest)
    cmd = pr.read_uint8()
    trace("received command:", cmd, "request_id:", req_id)

    if cmd == CMD_HELLO:
        write_resp(req_id, RESP_OK)
        continue
    if cmd == CMD_OPEN:
        # Only one database may be open at a time.
        if dbconn is not None:
            raise Exception("DB already connected")
        db_handle = pr.read_uint16()
        filename = pr.read_string()
        try:
            # isolation_level=None keeps the sqlite3 module from opening
            # transactions implicitly.  (On Python >= 3.12 the autocommit
            # argument of connect() could be used as well.)
            conn = sqlite3.connect(filename, isolation_level=None)
            # Make sure we are not in a transaction
            conn.commit()
        except sqlite3.Error as e:
            # Report the failure like any other database error instead of
            # terminating the helper.
            handle_query_failure(req_id, e)
            continue
        dbconn = conn
        db_handles[db_handle] = conn
        write_resp(req_id, RESP_OK)
        continue
    if cmd == CMD_CLOSE:
        if dbconn is not None:
            dbconn.close()
            dbconn = None
        # Forget the closed connection and the statements bound to it, so
        # that a later CMD_OPEN is accepted and no stale handle is reused.
        db_handles.clear()
        prep_handles.clear()
        write_resp(req_id, RESP_OK)
        continue
    if cmd == CMD_PREPARE:
        db_id = pr.read_uint16()
        prep_id = pr.read_uint16()
        sql = pr.read_string()
        # Bind the statement to the connection named in the request, the
        # same lookup CMD_EXEC performs.
        prep_handles[prep_id] = (db_handles[db_id], sql)
        write_resp(req_id, RESP_OK)
        continue
    if cmd == CMD_STMT_GET_ALL:
        prep_id = pr.read_uint16()
        params = pr.read_params()
        stmt_conn, stmt = prep_handles[prep_id]
        cursor = stmt_conn.cursor()
        try:
            cursor.execute(stmt, params)
            rows = cursor.fetchall()
        except sqlite3.Error as e:
            handle_query_failure(req_id, e)
            continue
        pw = PacketWriter()
        # description is None when the statement produced no result set.
        pw.write_rowlist(rows, cursor.description or ())
        write_resp(req_id, RESP_ROWLIST, pw.reap())
        continue
    if cmd == CMD_STMT_GET_FIRST:
        prep_id = pr.read_uint16()
        params = pr.read_params()
        stmt_conn, stmt = prep_handles[prep_id]
        cursor = stmt_conn.cursor()
        try:
            cursor.execute(stmt, params)
            row = cursor.fetchone()
        except sqlite3.Error as e:
            handle_query_failure(req_id, e)
            continue
        pw = PacketWriter()
        # Sent as a row list holding zero or one row.
        rows = [row] if row is not None else []
        # description is None when the statement produced no result set.
        pw.write_rowlist(rows, cursor.description or ())
        write_resp(req_id, RESP_ROWLIST, pw.reap())
        continue
    if cmd == CMD_STMT_RUN:
        trace("running statement")
        prep_id = pr.read_uint16()
        params = pr.read_params()
        stmt_conn, stmt = prep_handles[prep_id]
        cursor = stmt_conn.cursor()
        try:
            cursor.execute(stmt, params)
        except sqlite3.Error as e:
            trace("got sqlite error")
            handle_query_failure(req_id, e)
            continue
        trace("running query succeeded")
        if cursor.lastrowid is None:
            write_resp(req_id, RESP_OK)
        else:
            # Report the row id of the last inserted row.
            pw = PacketWriter()
            pw.write_int64(cursor.lastrowid)
            payload = pw.reap()
            write_resp(req_id, RESP_RUNRESULT, payload)
        continue
    if cmd == CMD_EXEC:
        db_id = pr.read_uint16()
        sql = pr.read_string()
        exec_conn = db_handles[db_id]
        try:
            # executescript() accepts several statements in one string.
            exec_conn.executescript(sql)
        except sqlite3.Error as e:
            handle_query_failure(req_id, e)
            continue
        write_resp(req_id, RESP_OK)
        continue

    # NOTE(review): no reply is sent for an unrecognised command, so the
    # peer gets nothing back for this request id.  CMD_SHUTDOWN has no
    # handler and also ends up here - confirm the intended behaviour.
    print("unknown command", file=sys.stderr)
