(***********************************************************************)
(*                                                                     *)
(*                      The Cryptokit library                          *)
(*                                                                     *)
(*            Xavier Leroy, projet Cristal, INRIA Rocquencourt         *)
(*                                                                     *)
(*  Copyright 2002 Institut National de Recherche en Informatique et   *)
(*  en Automatique.  All rights reserved.  This file is distributed    *)
(*  under the terms of the GNU Library General Public License, with    *)
(*  the special exception on linking described in file LICENSE.        *)
(*                                                                     *)
(***********************************************************************)

module Bn = CryptokitBignum

(* Utilities *)

(* [seq_equal len get s1 s2] tells whether two byte sequences are equal.
   When the lengths match, every position is visited and the per-byte
   differences are OR-ed together, so the number of steps does not depend
   on where the sequences first differ.  Sequences of different lengths
   compare unequal without inspecting their contents.  [len] and [get]
   abstract over the sequence type (string or bytes). *)
let seq_equal (len: 'a -> int) (get: 'a -> int -> char) (s1: 'a) (s2: 'a) =
  let n = len s1 in
  if n <> len s2 then false
  else begin
    let diff = ref 0 in
    for i = 0 to n - 1 do
      diff := !diff lor (Char.code (get s1 i) lxor Char.code (get s2 i))
    done;
    !diff = 0
  end

(* Timing-safe equality on strings and on mutable byte sequences, both
   delegating to [seq_equal]. *)
let string_equal s1 s2 = seq_equal String.length String.get s1 s2
let bytes_equal b1 b2 = seq_equal Bytes.length Bytes.get b1 b2

(* Overwrite every byte of [b] with '\000', e.g. to erase key material. *)
let wipe_bytes b =
  let n = Bytes.length b in
  Bytes.fill b 0 n '\000'
(* Same for a string: this deliberately mutates the (nominally immutable)
   string in place through [Bytes.unsafe_of_string]. *)
let wipe_string s = Bytes.unsafe_of_string s |> wipe_bytes

(* [shl1_bytes src soff dst doff len] shifts the [len]-byte big-endian
   number stored in [src] at offset [soff] left by one bit and stores the
   result in [dst] at offset [doff].  The bit shifted out of the first byte
   is discarded.  Bytes are processed from last to first and each one is
   read before it is written, so [src] and [dst] may be the same buffer at
   the same offset. *)
let shl1_bytes src soff dst doff len =
  let rec shl1 carry i =
    if i >= 0 then begin
      let n = Char.code (Bytes.get src (soff + i)) in
      (* [n lsl 1] can reach 510: keep only the low 8 bits explicitly
         instead of relying on the byte store to truncate an out-of-range
         [Char.unsafe_chr] value. *)
      Bytes.set dst (doff + i)
        (Char.unsafe_chr (((n lsl 1) lor carry) land 0xFF));
      (* The top bit of this byte becomes the carry into the previous one. *)
      shl1 (n lsr 7) (i - 1)
    end
  in shl1 0 (len - 1)

  (* [mod_power a b c] computes a^b mod c on numbers decoded with
     [Bn.of_bytes] and re-encodes the result on [8 * String.length c] bits.
     [Z.powm_sec] is Zarith's side-channel-hardened modular exponentiation;
     per its documentation it requires a positive exponent and an odd
     modulus (otherwise it raises [Invalid_argument]). *)
  let mod_power a b c =
    let result = Z.powm_sec (Bn.of_bytes a) (Bn.of_bytes b) (Bn.of_bytes c) in
    Bn.to_bytes ~numbits:(8 * String.length c) result
  (* [mod_mult a b c] computes a*b mod c, encoded like [mod_power]. *)
  let mod_mult a b c =
    let product = Bn.mulm (Bn.of_bytes a) (Bn.of_bytes b) (Bn.of_bytes c) in
    Bn.to_bytes ~numbits:(8 * String.length c) product

(* Error reporting *)

(* Errors reported by the library through the [Error] exception.  The notes
   below describe where the code in this file raises them. *)
type error =
  | Wrong_key_size          (* key length not accepted by a cipher *)
  | Wrong_IV_size           (* IV has the wrong length for the cipher *)
  | Wrong_data_length       (* input is not a whole number of blocks, or is
                               too short to contain an authentication tag *)
  | Bad_padding             (* padding check failed while stripping *)
  | Output_buffer_overflow  (* output would exceed Sys.max_string_length *)
  | Incompatible_block_size (* ciphers with different block sizes combined *)
  | Number_too_long
  | Seed_too_short
  | Message_too_long        (* e.g. CTR-mode counter exhausted *)
  | Bad_encoding
  | Compression_error of string * string
  | No_entropy_source
  | Entropy_source_closed
  | Compression_not_supported
  | Invalid_point           (* point is not on the elliptic curve *)

(* Human-readable, lowercase description of an [error] value, used for
   instance when printing [Error] exceptions. *)
let describe_error err =
  match err with
  | Wrong_key_size -> "wrong key size"
  | Wrong_IV_size -> "wrong IV size"
  | Wrong_data_length -> "wrong data length"
  | Bad_padding -> "bad padding"
  | Output_buffer_overflow -> "output buffer overflow"
  | Incompatible_block_size -> "incompatible block size"
  | Number_too_long -> "number too long"
  | Seed_too_short -> "seed too short"
  | Message_too_long -> "message too long"
  | Bad_encoding -> "bad encoding"
  | Compression_error (s1, s2) -> "compression error " ^ s1 ^ ":" ^ s2
  | No_entropy_source -> "no entropy source"
  | Entropy_source_closed -> "entropy source closed"
  | Compression_not_supported -> "compression not supported"
  | Invalid_point -> "point is not on elliptic curve"

(* The single exception raised by the library; its payload says what went
   wrong (see [describe_error]). *)
exception Error of error

(* Register the exception under the name "Cryptokit.Error" so that C code
   can look it up and raise it.  The [Wrong_key_size] payload is only a
   dummy argument needed to build an exception value.  [let ()] (rather
   than [let _]) checks that the registration really returns unit. *)
let () = Callback.register_exception "Cryptokit.Error" (Error Wrong_key_size)

(* Make [Printexc.to_string] and uncaught-exception reports show a readable
   description of [Error]. *)
let () = Printexc.register_printer (function
            | Error e -> Some(describe_error e)
            | _ -> None)

(* Interface with C *)

(* Direction argument passed to [des_cook_key]. *)
type dir = Encrypt | Decrypt

(* XOR primitive.  Judging from its uses in the chaining modes,
   [xor_bytes src soff dst doff len] XORs [len] bytes of [src] starting at
   [soff] into [dst] starting at [doff].  NOTE(review): it is not visible
   from here whether the C stub bounds-checks its arguments, so callers
   should validate ranges themselves.  [xor_string] is the same stub typed
   for an immutable source. *)
external xor_bytes: bytes -> int -> bytes -> int -> int -> unit = "caml_xor_string"
external xor_string: string -> int -> bytes -> int -> int -> unit = "caml_xor_string"
(* Cipher primitives.  [*_cook_key] turns a raw key into an opaque mutable
   buffer (presumably the expanded key schedule).  The block transforms take
   (cooked key, src, src_ofs, dst, dst_ofs) and the stream transforms add a
   length.  Externals with more than five arguments need a separate bytecode
   entry point, hence the two C names. *)
external aes_cook_encrypt_key : string -> bytes = "caml_aes_cook_encrypt_key"
external aes_cook_decrypt_key : string -> bytes = "caml_aes_cook_decrypt_key"
external aes_encrypt : bytes -> bytes -> int -> bytes -> int -> unit = "caml_aes_encrypt"
external aes_decrypt : bytes -> bytes -> int -> bytes -> int -> unit = "caml_aes_decrypt"
external blowfish_cook_key : string -> bytes = "caml_blowfish_cook_key"
external blowfish_encrypt : bytes -> bytes -> int -> bytes -> int -> unit = "caml_blowfish_encrypt"
external blowfish_decrypt : bytes -> bytes -> int -> bytes -> int -> unit = "caml_blowfish_decrypt"
(* The [int] argument of [des_cook_key] is the offset of the 8-byte DES key
   inside the key string (0, 8 or 16 in the triple-DES classes). *)
external des_cook_key : string -> int -> dir -> bytes = "caml_des_cook_key"
external des_transform : bytes -> bytes -> int -> bytes -> int -> unit = "caml_des_transform"
external arcfour_cook_key : string -> bytes = "caml_arcfour_cook_key"
external arcfour_transform : bytes -> bytes -> int -> bytes -> int -> int -> unit = "caml_arcfour_transform_bytecode" "caml_arcfour_transform"
external chacha20_cook_key : string -> bytes -> int64 -> bytes = "caml_chacha20_cook_key"
external chacha20_transform : bytes -> bytes -> int -> bytes -> int -> int -> unit = "caml_chacha20_transform_bytecode" "caml_chacha20_transform"
external chacha20_extract : bytes -> bytes -> int -> int -> unit = "caml_chacha20_extract"

(* Hash primitives, as init / update (context, buf, ofs, len) / final
   triples.  Most contexts are plain [bytes]; SHA-3 and BLAKE3 use abstract
   C-side contexts with an explicit [*_wipe]. *)
external sha1_init: unit -> bytes = "caml_sha1_init"
external sha1_update: bytes -> bytes -> int -> int -> unit = "caml_sha1_update"
external sha1_final: bytes -> string = "caml_sha1_final"
external sha256_init: unit -> bytes = "caml_sha256_init"
external sha224_init: unit -> bytes = "caml_sha224_init"
external sha256_update: bytes -> bytes -> int -> int -> unit = "caml_sha256_update"
external sha256_final: bytes -> string = "caml_sha256_final"
external sha224_final: bytes -> string = "caml_sha224_final"
external sha512_init: unit -> bytes = "caml_sha512_init"
external sha384_init: unit -> bytes = "caml_sha384_init"
external sha512_256_init: unit -> bytes = "caml_sha512_256_init"
external sha512_224_init: unit -> bytes = "caml_sha512_224_init"
external sha512_update: bytes -> bytes -> int -> int -> unit = "caml_sha512_update"
external sha512_final: bytes -> string = "caml_sha512_final"
external sha384_final: bytes -> string = "caml_sha384_final"
external sha512_256_final: bytes -> string = "caml_sha512_256_final"
external sha512_224_final: bytes -> string = "caml_sha512_224_final"
type sha3_context
(* NOTE(review): the meaning of the [bool] argument of [sha3_extract] is
   defined by the C stub; confirm there. *)
external sha3_init: int -> sha3_context = "caml_sha3_init"
external sha3_absorb: sha3_context -> bytes -> int -> int -> unit = "caml_sha3_absorb"
external sha3_extract: bool -> sha3_context -> string = "caml_sha3_extract"
external sha3_wipe: sha3_context -> unit = "caml_sha3_wipe"
external ripemd160_init: unit -> bytes = "caml_ripemd160_init"
external ripemd160_update: bytes -> bytes -> int -> int -> unit = "caml_ripemd160_update"
external ripemd160_final: bytes -> string = "caml_ripemd160_final"
external md5_init: unit -> bytes = "caml_md5_init"
external md5_update: bytes -> bytes -> int -> int -> unit = "caml_md5_update"
external md5_final: bytes -> string = "caml_md5_final"
external blake2b_init: int -> string -> bytes = "caml_blake2b_init"
external blake2b_update: bytes -> bytes -> int -> int -> unit = "caml_blake2b_update"
external blake2b_final: bytes -> int -> string = "caml_blake2b_final"
external blake2s_init: int -> string -> bytes = "caml_blake2s_init"
external blake2s_update: bytes -> bytes -> int -> int -> unit = "caml_blake2s_update"
external blake2s_final: bytes -> int -> string = "caml_blake2s_final"
(* MAC primitives: GHASH works on an abstract context; Poly1305 and SipHash
   follow the init/update/final pattern. *)
type ghash_context
external ghash_init: bytes -> ghash_context = "caml_ghash_init"
external ghash_mult: ghash_context -> bytes -> unit = "caml_ghash_mult"
external poly1305_init: bytes -> bytes = "caml_poly1305_init"
external poly1305_update: bytes -> bytes -> int -> int -> unit = "caml_poly1305_update"
external poly1305_final: bytes -> string = "caml_poly1305_final"
external siphash_init: string -> int -> bytes = "caml_siphash_init"
external siphash_update: bytes -> bytes -> int -> int -> unit = "caml_siphash_update"
external siphash_final: bytes -> int -> string = "caml_siphash_final"
(* Note: [blake3_final] is bound to the C function "caml_blake3_extract". *)
type blake3_context
external blake3_init: string -> blake3_context = "caml_blake3_init"
external blake3_update: blake3_context -> bytes -> int -> int -> unit = "caml_blake3_update"
external blake3_final: blake3_context -> int -> string = "caml_blake3_extract"
external blake3_wipe: blake3_context -> unit = "caml_blake3_wipe"

(* Abstract transform type *)

(* Abstract interface of a data transformation (for instance a block cipher
   wrapped by [Block.cipher]).  Input is pushed with the put_* methods and
   output is pulled with the get_* methods; the two may be interleaved. *)
class type transform =
  object
    (* Granularity of input and output lengths.  The block-cipher wrappers
       in this file report their block size, or 1 when any length is
       accepted. *)
    method input_block_size: int
    method output_block_size: int

    (* Feed input: a sub-range (buffer, offset, length) of a byte buffer, a
       whole string, one character, or one byte given as an int. *)
    method put_substring: bytes -> int -> int -> unit
    method put_string: string -> unit
    method put_char: char -> unit
    method put_byte: int -> unit

    (* [finish] signals the end of input.  [flush] asks the transform to
       process pending input now; the block-cipher wrapper raises
       [Error Wrong_data_length] if only a partial block is pending. *)
    method finish: unit
    method flush: unit

    (* Number of output bytes ready to be read. *)
    method available_output: int

    (* Retrieve output.  [get_string] and [get_substring] take all available
       output.  [get_substring] returns (buffer, offset, length); in
       [buffered_output] that buffer is the object's internal one, so it
       should be consumed before the transform is used again.
       [get_char] and [get_byte] take a single byte ([buffered_output]
       raises [End_of_file] when none is available). *)
    method get_string: string
    method get_substring: bytes * int * int
    method get_char: char
    method get_byte: int

    (* Erase internal state such as keys and buffers. *)
    method wipe: unit
  end

(* [transform_string tr s] runs all of [s] through [tr], finishes it,
   returns everything [tr] produced and then wipes [tr]; the transform
   should not be reused afterwards. *)
let transform_string tr s =
  tr#put_string s;
  tr#finish;
  let output = tr#get_string in
  tr#wipe;
  output

(* [transform_channel tr ?len ic oc] reads data from [ic], runs it through
   [tr] and writes the transformed output to [oc].
   - Without [len], [ic] is read until end of file.
   - With [len = Some l], exactly [l] bytes are read, and [End_of_file] is
     raised if [ic] ends earlier.
   Afterwards [tr] is finished, its remaining output written, and [tr]
   wiped.  The 256-byte scratch buffer is zeroed before returning, and also
   when an exception escapes from reading or transforming. *)
let transform_channel tr ?len ic oc =
  let ibuf = Bytes.create 256 in
  (* Push the first [r] bytes of [ibuf] through [tr] and emit whatever
     output [tr] has produced so far. *)
  let process r =
    tr#put_substring ibuf 0 r;
    let (obuf, opos, olen) = tr#get_substring in
    output oc obuf opos olen in
  let rec transf_to_eof () =
    let r = input ic ibuf 0 256 in
    if r > 0 then begin
      process r;
      transf_to_eof ()
    end in
  let rec transf_bounded numleft =
    if numleft > 0 then begin
      let r = input ic ibuf 0 (min 256 numleft) in
      if r = 0 then raise End_of_file;
      process r;
      transf_bounded (numleft - r)
    end in
  (try
     (match len with
      | None -> transf_to_eof ()
      | Some l -> transf_bounded l)
   with e ->
     (* The scratch buffer may still hold sensitive input: clear it before
        propagating the error. *)
     wipe_bytes ibuf;
     raise e);
  wipe_bytes ibuf;
  tr#finish;
  let (obuf, opos, olen) = tr#get_substring in
  output oc obuf opos olen;
  tr#wipe

(* [compose tr1 tr2] chains two transforms.  Input goes into [tr1], whatever
   [tr1] outputs is immediately fed to [tr2], and output is read from
   [tr2]. *)
class compose (tr1 : transform) (tr2 : transform) =
  object(self)
    method input_block_size = tr1#input_block_size
    method output_block_size = tr2#output_block_size

    (* Move all output currently available from [tr1] into [tr2]. *)
    method private transfer =
      let (data, pos, n) = tr1#get_substring in
      tr2#put_substring data pos n

    method put_substring buf ofs len =
      tr1#put_substring buf ofs len;
      self#transfer
    method put_string s =
      tr1#put_string s;
      self#transfer
    method put_char c =
      tr1#put_char c;
      self#transfer
    method put_byte b =
      tr1#put_byte b;
      self#transfer

    (* Flushing and finishing propagate through [tr1] first, then [tr2]. *)
    method flush =
      tr1#flush;
      self#transfer;
      tr2#flush
    method finish =
      tr1#finish;
      self#transfer;
      tr2#finish

    method available_output = tr2#available_output
    method get_string = tr2#get_string
    method get_substring = tr2#get_substring
    method get_char = tr2#get_char
    method get_byte = tr2#get_byte

    method wipe =
      tr1#wipe;
      tr2#wipe
  end

let compose tr1 tr2 = new compose tr1 tr2

(* Abstract interface of a hash function or MAC: feed data with the add_*
   methods, then read the digest with [result].  [hash_size] gives the
   digest length (for [Block.mac] this is the block size, which is also the
   length of its result). *)
class type hash =
  object
    method hash_size: int
    (* Add a sub-range (buffer, offset, length) of a byte buffer. *)
    method add_substring: bytes -> int -> int -> unit
    method add_string: string -> unit
    method add_char: char -> unit
    (* Add one byte given as an int. *)
    method add_byte: int -> unit
    method result: string
    (* Erase internal state such as keys and buffers. *)
    method wipe: unit
  end

(* [hash_string hash s] returns the digest of [s] computed by [hash], then
   wipes [hash]; the hash should not be reused afterwards. *)
let hash_string hash s =
  hash#add_string s;
  let digest = hash#result in
  hash#wipe;
  digest

(* [hash_channel hash ?len ic] hashes data read from [ic] and returns the
   digest, wiping [hash] afterwards.
   - Without [len], [ic] is read until end of file.
   - With [len = Some l], exactly [l] bytes are hashed, and [End_of_file]
     is raised if [ic] ends earlier.
   The 256-byte scratch buffer is zeroed before returning, and also when an
   exception escapes from reading or hashing. *)
let hash_channel hash ?len ic =
  let ibuf = Bytes.create 256 in
  let rec hash_to_eof () =
    let r = input ic ibuf 0 256 in
    if r > 0 then begin
      hash#add_substring ibuf 0 r;
      hash_to_eof ()
    end in
  let rec hash_bounded numleft =
    if numleft > 0 then begin
      let r = input ic ibuf 0 (min 256 numleft) in
      if r = 0 then raise End_of_file;
      hash#add_substring ibuf 0 r;
      hash_bounded (numleft - r)
    end in
  (try
     (match len with
      | None -> hash_to_eof ()
      | Some l -> hash_bounded l)
   with e ->
     (* The scratch buffer may still hold sensitive input: clear it before
        propagating the error. *)
     wipe_bytes ibuf;
     raise e);
  wipe_bytes ibuf;
  let res = hash#result in
  hash#wipe;
  res

(* A transform that also computes an authentication tag.  It has the same
   put_* / get_* interface as [transform], but input is terminated by
   [finish_and_get_tag], which returns the tag.  [tag_size] gives the tag
   length. *)
class type authenticated_transform =
  object
    method input_block_size: int
    method output_block_size: int
    method tag_size: int

    method put_substring: bytes -> int -> int -> unit
    method put_string: string -> unit
    method put_char: char -> unit
    method put_byte: int -> unit

    (* End the input and return the authentication tag; the transformed
       output can still be read with the get_* methods afterwards. *)
    method finish_and_get_tag: string

    method available_output: int

    method get_string: string
    method get_substring: bytes * int * int
    method get_char: char
    method get_byte: int

    method wipe: unit
  end

(* [auth_transform_string_detached tr s] runs [s] through the authenticated
   transform [tr] and returns the pair (output, tag), wiping [tr]
   afterwards.  The tag is requested before the output is read. *)
let auth_transform_string_detached tr s =
  tr#put_string s;
  let tag = tr#finish_and_get_tag in
  let output = tr#get_string in
  tr#wipe;
  output, tag

(* Same as [auth_transform_string_detached], but the tag is appended to the
   output. *)
let auth_transform_string tr s =
  match auth_transform_string_detached tr s with
  | (output, tag) -> output ^ tag

(* [auth_check_transform_string tr s] treats the last [tr#tag_size] bytes
   of [s] as a tag.  The rest of [s] goes through [tr], and the tag [tr]
   computes is compared, in a timing-safe way, with the received one.
   Returns [Some output] on match and [None] otherwise; [tr] is wiped in
   both cases.  Raises [Error Wrong_data_length] if [s] is shorter than a
   tag. *)
let auth_check_transform_string tr s =
  let total = String.length s in
  let taglen = tr#tag_size in
  if total < taglen then raise (Error Wrong_data_length);
  let body_len = total - taglen in
  tr#put_string (String.sub s 0 body_len);
  let computed = tr#finish_and_get_tag in
  let received = String.sub s body_len taglen in
  let res =
    if string_equal computed received then Some (tr#get_string) else None in
  tr#wipe;
  res

(* Generic handling of output buffering *)

(* [buffered_output initial_buffer_size] manages a growable byte buffer of
   output that has been produced but not yet consumed.  Pending output lives
   in [obuf] between offsets [obeg] (inclusive) and [oend] (exclusive).
   Subclasses append by calling [ensure_capacity n] and then writing [n]
   bytes at [oend]. *)
class buffered_output initial_buffer_size =
  object(self)
    val mutable obuf = Bytes.create initial_buffer_size
    val mutable obeg = 0
    val mutable oend = 0

    (* Make room for [n] more bytes after [oend].  If the pending data plus
       [n] fits, the data is slid to the front of [obuf]; otherwise a larger
       buffer is allocated.  Raises [Error Output_buffer_overflow] if the
       pending data plus [n] exceeds [Sys.max_string_length].
       NOTE(review): the old buffer is not wiped on reallocation, because a
       caller may still hold it from [get_substring]. *)
    method private ensure_capacity n =
      let len = Bytes.length obuf in
      if oend + n > len then begin
        if oend - obeg + n < len then begin
          Bytes.blit obuf obeg obuf 0 (oend - obeg);
          oend <- oend - obeg;
          obeg <- 0
        end else begin
          (* Grow geometrically.  Start from at least 1 so that the doubling
             loop terminates even for a buffer created with size 0, since
             doubling 0 would never grow. *)
          let newlen = ref (max 1 (2 * len)) in
          while oend - obeg + n > !newlen do
            newlen := !newlen * 2
          done;
          if !newlen > Sys.max_string_length then begin
            if oend - obeg + n <= Sys.max_string_length then
              newlen := Sys.max_string_length
            else
              raise (Error Output_buffer_overflow)
          end;
          let newbuf = Bytes.create !newlen in
          Bytes.blit obuf obeg newbuf 0 (oend - obeg);
          obuf <- newbuf;
          oend <- oend - obeg;
          obeg <- 0
        end
      end

    method available_output = oend - obeg

    (* Returns the internal buffer itself together with the pending range,
       and marks everything as consumed.  The caller must use the data
       before this object produces more output. *)
    method get_substring =
      let res = (obuf, obeg, oend - obeg) in obeg <- 0; oend <- 0; res

    method get_string =
      let res = Bytes.sub_string obuf obeg (oend - obeg) in obeg <- 0; oend <- 0; res

    (* Raises [End_of_file] when no output is pending. *)
    method get_char =
      if obeg >= oend then raise End_of_file;
      let r = Bytes.get obuf obeg in
      obeg <- obeg + 1;
      r

    method get_byte =
      Char.code self#get_char

    (* Zero the buffer and drop the pending range, so that a wiped object
       no longer reports (zeroed) output as available. *)
    method wipe =
      wipe_bytes obuf;
      obeg <- 0;
      oend <- 0
  end

(* Combining a transform and a hash to get an authenticated transform *)

(* [transform_then_hash tr h] is an authenticated transform.  Input goes
   through [tr], and everything [tr] outputs is hashed by [h] and copied
   into this object's own output buffer.  The tag is [h]'s digest of the
   output. *)
class transform_then_hash (tr: transform) (h: hash) =
  object(self)
    inherit buffered_output 256 as output_buffer

    method input_block_size = tr#input_block_size
    method output_block_size = tr#output_block_size
    method tag_size = h#hash_size

    (* Drain [tr]'s pending output, hashing it and appending it to our
       buffer. *)
    method private transfer =
      let (data, pos, n) = tr#get_substring in
      h#add_substring data pos n;
      self#ensure_capacity n;
      Bytes.blit data pos obuf oend n;
      oend <- oend + n

    method put_substring buf ofs len =
      tr#put_substring buf ofs len;
      self#transfer
    method put_string s =
      tr#put_string s;
      self#transfer
    method put_char c =
      tr#put_char c;
      self#transfer
    method put_byte b =
      tr#put_byte b;
      self#transfer

    method finish_and_get_tag =
      tr#finish;
      self#transfer;
      h#result

    method wipe =
      output_buffer#wipe;
      tr#wipe;
      h#wipe
  end

let transform_then_hash tr h = new transform_then_hash tr h

(* [transform_and_hash tr h] is an authenticated transform.  Input goes
   through [tr] and, independently, the same input is fed to [h].  The tag
   is [h]'s digest of the input, and output is read directly from [tr]. *)
class transform_and_hash (tr: transform) (h: hash) =
  object
    method input_block_size = tr#input_block_size
    method output_block_size = tr#output_block_size
    method tag_size = h#hash_size

    method put_substring buf ofs len =
      tr#put_substring buf ofs len;
      h#add_substring buf ofs len
    method put_string s =
      tr#put_string s;
      h#add_string s
    method put_char c =
      tr#put_char c;
      h#add_char c
    method put_byte b =
      tr#put_byte b;
      h#add_byte b

    method available_output = tr#available_output
    method get_string = tr#get_string
    method get_substring = tr#get_substring
    method get_char = tr#get_char
    method get_byte = tr#get_byte

    method finish_and_get_tag =
      tr#finish;
      h#result

    method wipe =
      tr#wipe;
      h#wipe
  end

let transform_and_hash tr h = new transform_and_hash tr h

(* Padding schemes *)

module Padding = struct

(* A padding scheme fills the tail of a partially used block ([pad]) and,
   on the receiving side, finds where the padding starts ([strip]). *)
class type scheme =
  object
    (* [pad buf used]: bytes [0 .. used-1] of [buf] hold data; fill the
       remaining bytes with padding. *)
    method pad: bytes -> int -> unit
    (* [strip buf]: number of data bytes in the padded block [buf].  The
       schemes below raise [Error Bad_padding] on malformed padding. *)
    method strip: bytes -> int
  end

(* Every padding byte holds the padding length (1 to 255), as in PKCS#7.
   A full block of padding is added when [used] is 0. *)
class length =
  object
    method pad buffer used =
      let padlen = Bytes.length buffer - used in
      assert (padlen > 0 && padlen < 256);
      Bytes.fill buffer used padlen (Char.chr padlen)
    method strip buffer =
      let blocksize = Bytes.length buffer in
      let padlen = Char.code (Bytes.get buffer (blocksize - 1)) in
      if padlen = 0 || padlen > blocksize then raise (Error Bad_padding);
      (* The other padding bytes, from [blocksize - padlen] up to
         [blocksize - 2], must all equal [padlen]; the last one has already
         been read above. *)
      let rec check i =
        if i <= blocksize - 2 then begin
          if Char.code (Bytes.get buffer i) <> padlen then
            raise (Error Bad_padding);
          check (i + 1)
        end in
      check (blocksize - padlen);
      blocksize - padlen
  end

let length = new length

(* A single 0x80 byte followed by zero bytes up to the end of the block. *)
class _8000 =
  object
    method pad buffer used =
      Bytes.set buffer used '\128';
      Bytes.fill buffer (used + 1) (Bytes.length buffer - used - 1) '\000'
    method strip buffer =
      (* Scan backwards over zero bytes; the first non-zero byte must be
         0x80 and marks the end of the data. *)
      let rec scan pos =
        if pos < 0 then raise (Error Bad_padding)
        else
          match Bytes.get buffer pos with
          | '\000' -> scan (pos - 1)
          | '\128' -> pos
          | _ -> raise (Error Bad_padding)
      in
      scan (Bytes.length buffer - 1)
  end

let _8000 = new _8000

end

(* Block ciphers *)

module Block = struct

(* A block cipher, or a chaining mode built on one.
   [transform src src_ofs dst dst_ofs] processes [blocksize] bytes of [src]
   starting at [src_ofs] and writes the result into [dst] starting at
   [dst_ofs]. *)
class type block_cipher =
  object
    method blocksize: int
    method transform: bytes -> int -> bytes -> int -> unit
    (* Erase key material and any chaining state. *)
    method wipe: unit
  end

(* AES encryption of one 16-byte block.  [key] must be 16, 24 or 32 bytes
   long, otherwise [Error Wrong_key_size] is raised. *)
class aes_encrypt key =
  object
    val ckey =
      match String.length key with
      | 16 | 24 | 32 -> aes_cook_encrypt_key key
      | _ -> raise (Error Wrong_key_size)
    method blocksize = 16
    method transform src src_ofs dst dst_ofs =
      (* Both 16-byte windows must lie inside their buffers before the C
         primitive is called. *)
      let fits buf ofs = ofs >= 0 && ofs <= Bytes.length buf - 16 in
      if not (fits src src_ofs && fits dst dst_ofs)
      then invalid_arg "aes#transform";
      aes_encrypt ckey src src_ofs dst dst_ofs
    method wipe =
      wipe_bytes ckey;
      (* NOTE(review): after zeroing, the last byte of the cooked key is set
         to 16; presumably the C side reads it as metadata. Confirm against
         the stubs. *)
      Bytes.set ckey (Bytes.length ckey - 1) '\016'
  end

(* AES decryption of one 16-byte block.  [key] must be 16, 24 or 32 bytes
   long, otherwise [Error Wrong_key_size] is raised. *)
class aes_decrypt key =
  object
    val ckey =
      match String.length key with
      | 16 | 24 | 32 -> aes_cook_decrypt_key key
      | _ -> raise (Error Wrong_key_size)
    method blocksize = 16
    method transform src src_ofs dst dst_ofs =
      (* Both 16-byte windows must lie inside their buffers before the C
         primitive is called. *)
      let fits buf ofs = ofs >= 0 && ofs <= Bytes.length buf - 16 in
      if not (fits src src_ofs && fits dst dst_ofs)
      then invalid_arg "aes#transform";
      aes_decrypt ckey src src_ofs dst dst_ofs
    method wipe =
      wipe_bytes ckey;
      (* NOTE(review): same trailer reset as in [aes_encrypt]; presumably
         metadata read by the C side. *)
      Bytes.set ckey (Bytes.length ckey - 1) '\016'
  end

(* Blowfish encryption of one 8-byte block.  Keys of 4 to 56 bytes are
   accepted; any other length raises [Error Wrong_key_size]. *)
class blowfish_encrypt key =
  object
    val ckey =
      let kl = String.length key in
      if kl < 4 || kl > 56 then raise (Error Wrong_key_size)
      else blowfish_cook_key key
    method blocksize = 8
    method transform src src_ofs dst dst_ofs =
      let fits buf ofs = ofs >= 0 && ofs <= Bytes.length buf - 8 in
      if not (fits src src_ofs && fits dst dst_ofs)
      then invalid_arg "blowfish#transform";
      blowfish_encrypt ckey src src_ofs dst dst_ofs
    method wipe =
      wipe_bytes ckey
  end

(* Blowfish decryption of one 8-byte block.  It uses the same key cooking
   as encryption; keys of 4 to 56 bytes are accepted. *)
class blowfish_decrypt key =
  object
    val ckey =
      let kl = String.length key in
      if kl < 4 || kl > 56 then raise (Error Wrong_key_size)
      else blowfish_cook_key key
    method blocksize = 8
    method transform src src_ofs dst dst_ofs =
      let fits buf ofs = ofs >= 0 && ofs <= Bytes.length buf - 8 in
      if not (fits src src_ofs && fits dst dst_ofs)
      then invalid_arg "blowfish#transform";
      blowfish_decrypt ckey src src_ofs dst dst_ofs
    method wipe =
      wipe_bytes ckey
  end

(* Single DES on 8-byte blocks.  [direction] selects encryption or
   decryption when the key is cooked.  [key] must be exactly 8 bytes long,
   otherwise [Error Wrong_key_size] is raised. *)
class des direction key =
  object
    val ckey =
      match String.length key with
      | 8 -> des_cook_key key 0 direction
      | _ -> raise (Error Wrong_key_size)
    method blocksize = 8
    method transform src src_ofs dst dst_ofs =
      let fits buf ofs = ofs >= 0 && ofs <= Bytes.length buf - 8 in
      if not (fits src src_ofs && fits dst dst_ofs)
      then invalid_arg "des#transform";
      des_transform ckey src src_ofs dst dst_ofs
    method wipe =
      wipe_bytes ckey
  end

(* [des] specialised to each direction. *)
class des_encrypt = des Encrypt
class des_decrypt = des Decrypt

(* Triple-DES encryption (encrypt K1, decrypt K2, encrypt K3) on 8-byte
   blocks.  K1, K2 and K3 are taken at key offsets 0, 8 and 16.  A 16-byte
   key reuses K1 as K3; other lengths than 16 or 24 raise
   [Error Wrong_key_size]. *)
class triple_des_encrypt key =
  let () =
    match String.length key with
    | 16 | 24 -> ()
    | _ -> raise (Error Wrong_key_size) in
  let ckey1 = des_cook_key key 0 Encrypt in
  let ckey2 = des_cook_key key 8 Decrypt in
  let ckey3 =
    if String.length key <> 24 then ckey1
    else des_cook_key key 16 Encrypt in
  object
    method blocksize = 8
    method transform src src_ofs dst dst_ofs =
      let fits buf ofs = ofs >= 0 && ofs <= Bytes.length buf - 8 in
      if not (fits src src_ofs && fits dst dst_ofs)
      then invalid_arg "triple_des#transform";
      (* The first pass reads [src]; the next two work in place on [dst]. *)
      des_transform ckey1 src src_ofs dst dst_ofs;
      des_transform ckey2 dst dst_ofs dst dst_ofs;
      des_transform ckey3 dst dst_ofs dst dst_ofs
    method wipe =
      wipe_bytes ckey1;
      wipe_bytes ckey2;
      wipe_bytes ckey3
  end

(* Triple-DES decryption, the inverse of [triple_des_encrypt]: decrypt with
   K3 (K1 for a 16-byte key), encrypt with K2, then decrypt with K1. *)
class triple_des_decrypt key =
  let () =
    match String.length key with
    | 16 | 24 -> ()
    | _ -> raise (Error Wrong_key_size) in
  let ckey3 = des_cook_key key 0 Decrypt in
  let ckey2 = des_cook_key key 8 Encrypt in
  let ckey1 =
    if String.length key <> 24 then ckey3
    else des_cook_key key 16 Decrypt in
  object
    method blocksize = 8
    method transform src src_ofs dst dst_ofs =
      let fits buf ofs = ofs >= 0 && ofs <= Bytes.length buf - 8 in
      if not (fits src src_ofs && fits dst dst_ofs)
      then invalid_arg "triple_des#transform";
      (* The first pass reads [src]; the next two work in place on [dst]. *)
      des_transform ckey1 src src_ofs dst dst_ofs;
      des_transform ckey2 dst dst_ofs dst dst_ofs;
      des_transform ckey3 dst dst_ofs dst dst_ofs
    method wipe =
      wipe_bytes ckey1;
      wipe_bytes ckey2;
      wipe_bytes ckey3
  end

(* Chaining modes *)

(* Fresh mutable copy of the initial chaining value.  With no IV it is all
   zeroes; a supplied IV must be exactly [blocksize] bytes long, otherwise
   [Error Wrong_IV_size] is raised. *)
let make_initial_iv blocksize iv =
  match iv with
  | None -> Bytes.make blocksize '\000'
  | Some s ->
      if String.length s = blocksize then Bytes.of_string s
      else raise (Error Wrong_IV_size)

(* CBC-mode encryption over [cipher].  Each plaintext block is XOR-ed into
   the chaining value [iv], which is then encrypted into [dst]; the
   resulting ciphertext block becomes the next chaining value. *)
class cbc_encrypt ?iv:iv_init (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  object
    val iv = make_initial_iv blocksize iv_init
    method blocksize = blocksize
    method transform src src_off dst dst_off =
      (* Validate both ranges up front.  [src] is only read through the C
         primitive [xor_bytes], which may not bounds-check it.  Checking
         before any mutation also keeps [iv] intact when the call is
         rejected. *)
      if src_off < 0 || src_off > Bytes.length src - blocksize
      || dst_off < 0 || dst_off > Bytes.length dst - blocksize
      then invalid_arg "cbc_encrypt#transform";
      xor_bytes src src_off iv 0 blocksize;
      cipher#transform iv 0 dst dst_off;
      Bytes.blit dst dst_off iv 0 blocksize
    method wipe =
      cipher#wipe;
      wipe_bytes iv
  end

(* CBC-mode decryption over [cipher]: plaintext = D(ciphertext) XOR the
   previous ciphertext block (initially [iv]). *)
class cbc_decrypt ?iv:iv_init (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  object
    val iv = make_initial_iv blocksize iv_init
    val next_iv = Bytes.create blocksize
    method blocksize = blocksize
    method transform input ipos output opos =
      (* Save the ciphertext block first: it is the next chaining value and
         [output] may alias [input]. *)
      Bytes.blit input ipos next_iv 0 blocksize;
      cipher#transform input ipos output opos;
      xor_bytes iv 0 output opos blocksize;
      Bytes.blit next_iv 0 iv 0 blocksize
    method wipe =
      cipher#wipe;
      wipe_bytes iv;
      wipe_bytes next_iv
  end

(* CFB-mode encryption processing [chunksize] bytes per call (1 to the
   cipher's block size).  The keystream is the encryption of the shift
   register [iv], which is then fed with the produced ciphertext. *)
class cfb_encrypt ?iv:iv_init chunksize (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  let () = assert (chunksize > 0 && chunksize <= blocksize) in
  object
    val iv = make_initial_iv blocksize iv_init
    val out = Bytes.create blocksize
    method blocksize = chunksize
    method transform input ipos output opos =
      cipher#transform iv 0 out 0;
      Bytes.blit input ipos output opos chunksize;
      xor_bytes out 0 output opos chunksize;
      (* Shift the register left by [chunksize] bytes and append the fresh
         ciphertext. *)
      let keep = blocksize - chunksize in
      Bytes.blit iv chunksize iv 0 keep;
      Bytes.blit output opos iv keep chunksize
    method wipe =
      cipher#wipe;
      wipe_bytes iv;
      wipe_bytes out
  end

(* CFB-mode decryption, the mirror of [cfb_encrypt]: the shift register is
   fed with the incoming ciphertext chunk. *)
class cfb_decrypt ?iv:iv_init chunksize (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  let () = assert (chunksize > 0 && chunksize <= blocksize) in
  object
    val iv = make_initial_iv blocksize iv_init
    val out = Bytes.create blocksize
    method blocksize = chunksize
    method transform input ipos output opos =
      cipher#transform iv 0 out 0;
      (* Update the register from [input] before writing [output], which
         may alias it. *)
      let keep = blocksize - chunksize in
      Bytes.blit iv chunksize iv 0 keep;
      Bytes.blit input ipos iv keep chunksize;
      Bytes.blit input ipos output opos chunksize;
      xor_bytes out 0 output opos chunksize
    method wipe =
      cipher#wipe;
      wipe_bytes iv;
      wipe_bytes out
  end

(* OFB mode.  For each chunk the register [iv] is encrypted in place, and
   its first [chunksize] bytes are XOR-ed into the data.  Encryption and
   decryption are the same operation. *)
class ofb ?iv:iv_init chunksize (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  let () = assert (chunksize > 0 && chunksize <= blocksize) in
  object
    val iv = make_initial_iv blocksize iv_init
    method blocksize = chunksize
    method transform input ipos output opos =
      cipher#transform iv 0 iv 0;
      Bytes.blit input ipos output opos chunksize;
      xor_bytes iv 0 output opos chunksize
    method wipe =
      cipher#wipe;
      wipe_bytes iv
  end

(* [increment_counter c lim pos] adds 1 to the big-endian counter held in
   bytes [lim] .. [pos] (inclusive) of [c], propagating the carry towards
   lower indices.  Bytes before [lim] are never touched, so the counter
   wraps around modulo 256^(pos - lim + 1). *)
let rec increment_counter c lim pos =
  if pos >= lim then begin
    let i = 1 + Char.code (Bytes.get c pos) in
    (* Store the low 8 bits explicitly (256 becomes 0) instead of relying on
       the byte store to truncate an out-of-range [Char.unsafe_chr]. *)
    Bytes.set c pos (Char.unsafe_chr (i land 0xFF));
    if i = 0x100 then increment_counter c lim (pos - 1)
  end

(* Counter (CTR) mode.  The keystream block is the encryption of [iv].
   After each block, the last [nincr] bytes of [iv] (by default all of them,
   or the value of [?inc]) are incremented as a big-endian counter.  For
   counters shorter than 8 bytes, the 256^nincr-th call to [transform]
   raises [Error Message_too_long], after that call has written its output. *)
class ctr ?iv:iv_init ?inc (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  let nincr =
    match inc with
    | Some n ->
        assert (n > 0 && n <= blocksize);
        n
    | None -> blocksize in
  object
    val iv = make_initial_iv blocksize iv_init
    val out = Bytes.create blocksize
    (* Block budget, decremented on every [transform]; reaching zero raises.
       Counters of 8 bytes or more start at 0L, so the budget effectively
       never reaches zero again. *)
    val mutable max_transf =
      if nincr >= 8 then 0L else Int64.shift_left 1L (8 * nincr)
    method blocksize = blocksize
    method transform input ipos output opos =
      cipher#transform iv 0 out 0;
      Bytes.blit input ipos output opos blocksize;
      xor_bytes out 0 output opos blocksize;
      increment_counter iv (blocksize - nincr) (blocksize - 1);
      let remaining = Int64.pred max_transf in
      if remaining = 0L then raise (Error Message_too_long);
      max_transf <- remaining
    method wipe =
      cipher#wipe;
      wipe_bytes iv;
      wipe_bytes out
  end

(* Wrapping of a block cipher as a transform *)

(* [cipher c] wraps the block cipher [c] as a [transform].  Input is
   collected in [ibuf]; a full block is only encrypted when more input
   arrives or on [flush]/[finish], so subclasses still see the last block.
   Total input must be a multiple of the block size: [flush] and [finish]
   raise [Error Wrong_data_length] otherwise. *)
class cipher (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  object(self)
    val ibuf = Bytes.create blocksize
    val mutable used = 0

    inherit buffered_output (max 256 (2 * blocksize)) as output_buffer

    method input_block_size = blocksize
    method output_block_size = blocksize

    method put_substring src ofs len =
      if len > 0 then begin
        let room = blocksize - used in
        if len <= room then begin
          (* Everything fits in the pending block: just accumulate. *)
          Bytes.blit src ofs ibuf used len;
          used <- used + len
        end else begin
          (* Complete the pending block, encrypt it into the output buffer,
             then carry on with the rest of the input. *)
          Bytes.blit src ofs ibuf used room;
          self#ensure_capacity blocksize;
          cipher#transform ibuf 0 obuf oend;
          oend <- oend + blocksize;
          used <- 0;
          self#put_substring src (ofs + room) (len - room)
        end
      end

    method put_string s =
      self#put_substring (Bytes.unsafe_of_string s) 0 (String.length s)

    method put_char c =
      if used >= blocksize then begin
        (* The pending block is full: encrypt it before starting a new one. *)
        self#ensure_capacity blocksize;
        cipher#transform ibuf 0 obuf oend;
        oend <- oend + blocksize;
        used <- 0
      end;
      Bytes.set ibuf used c;
      used <- used + 1

    method put_byte b =
      self#put_char (Char.unsafe_chr b)

    method wipe =
      cipher#wipe;
      output_buffer#wipe;
      wipe_bytes ibuf

    method flush =
      if used <> 0 then begin
        if used <> blocksize then raise (Error Wrong_data_length);
        self#ensure_capacity blocksize;
        cipher#transform ibuf 0 obuf oend;
        used <- 0;
        oend <- oend + blocksize
      end

    method finish =
      self#flush
  end

(* Block cipher with padding *)

(* Block-cipher transform that accepts input of any length.  [finish]
   encrypts a pending full block as-is, then pads the remaining partial (or
   empty) block with [padding] and encrypts it. *)
class cipher_padded_encrypt (padding : Padding.scheme)
                            (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  object(self)
    inherit cipher cipher
    method input_block_size = 1

    method finish =
      (* Encrypt [ibuf] into the output buffer. *)
      let encrypt_pending () =
        self#ensure_capacity blocksize;
        cipher#transform ibuf 0 obuf oend;
        oend <- oend + blocksize in
      if used >= blocksize then begin
        encrypt_pending ();
        used <- 0
      end;
      padding#pad ibuf used;
      encrypt_pending ()
  end

(* Block-cipher transform whose output has arbitrary length.  [finish]
   decrypts the final block, which the parent class keeps back in [ibuf],
   and removes its padding with [padding].  Raises
   [Error Wrong_data_length] if that block is incomplete. *)
class cipher_padded_decrypt (padding : Padding.scheme)
                            (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  object(self)
    inherit cipher cipher
    method output_block_size = 1

    method finish =
      if used <> blocksize then raise (Error Wrong_data_length);
      cipher#transform ibuf 0 ibuf 0;
      let plain_len = padding#strip ibuf in
      self#ensure_capacity plain_len;
      Bytes.blit ibuf 0 obuf oend plain_len;
      oend <- oend + plain_len
  end

(* Wrapping of a block cipher as a MAC, using CBC mode *)

(* CBC-MAC over [cipher]: each full block is XOR-ed into the chaining value
   [iv], which is then encrypted in place; the final [iv] is the MAC.
   Without [?pad], the input must be a multiple of the block size
   ([result] raises [Error Wrong_data_length] otherwise).  With [?pad], the
   last partial block is padded with that scheme. *)
class mac ?iv:iv_init ?(pad: Padding.scheme option) (cipher : block_cipher) =
  let blocksize = cipher#blocksize in
  object(self)
    val iv = make_initial_iv blocksize iv_init
    val buffer = Bytes.create blocksize
    val mutable used = 0

    method hash_size = blocksize

    method add_substring src src_ofs len =
      (* Absorb blocks only while the pending buffer would overflow, so the
         last (possibly full) block stays in [buffer] for [result]. *)
      let pos = ref src_ofs and remaining = ref len in
      while !remaining > 0 && used + !remaining > blocksize do
        let n = blocksize - used in
        Bytes.blit src !pos buffer used n;
        xor_bytes iv 0 buffer 0 blocksize;
        cipher#transform buffer 0 iv 0;
        used <- 0;
        pos := !pos + n;
        remaining := !remaining - n
      done;
      if !remaining > 0 then begin
        Bytes.blit src !pos buffer used !remaining;
        used <- used + !remaining
      end

    method add_string s =
      self#add_substring (Bytes.unsafe_of_string s) 0 (String.length s)

    method add_char c =
      if used >= blocksize then begin
        xor_bytes iv 0 buffer 0 blocksize;
        cipher#transform buffer 0 iv 0;
        used <- 0
      end;
      Bytes.set buffer used c;
      used <- used + 1

    method add_byte b =
      self#add_char (Char.unsafe_chr b)

    method wipe =
      cipher#wipe;
      wipe_bytes buffer;
      wipe_bytes iv

    method result =
      (* Chain the pending [buffer] into [iv]. *)
      let absorb () =
        xor_bytes iv 0 buffer 0 blocksize;
        cipher#transform buffer 0 iv 0 in
      if used = blocksize then begin
        absorb ();
        used <- 0
      end;
      (match pad with
       | None ->
           if used <> 0 then raise (Error Wrong_data_length)
       | Some p ->
           p#pad buffer used;
           absorb ();
           used <- 0);
      Bytes.to_string iv
  end

(* CBC-MAC with [cipher1], whose final value is further transformed by
   [cipher2] and then [cipher3].  All three ciphers must share the same
   block size, otherwise [Error Incompatible_block_size] is raised. *)
class mac_final_triple ?iv ?pad (cipher1 : block_cipher)
                                (cipher2 : block_cipher)
                                (cipher3 : block_cipher) =
  let () =
    let bs = cipher1#blocksize in
    if cipher2#blocksize <> bs || cipher3#blocksize <> bs
    then raise (Error Incompatible_block_size) in
  object
    inherit mac ?iv ?pad cipher1 as super
    method result =
      let tag = Bytes.of_string super#result in
      cipher2#transform tag 0 tag 0;
      cipher3#transform tag 0 tag 0;
      Bytes.unsafe_to_string tag
    method wipe =
      super#wipe;
      cipher2#wipe;
      cipher3#wipe
  end

(* Wrapping of a block cipher as a MAC, in CMAC mode (a.k.a. OMAC1) *)

(* CMAC on top of the CBC-MAC [mac] class: the last block is xor-ed with
   subkey [k1] when it is complete, or padded with [Padding._8000] and
   xor-ed with subkey [k2] when it is partial (including when empty). *)
class cmac ?iv:iv_init (cipher : block_cipher) k1 k2 =
  object (self)
    inherit mac ?iv:iv_init cipher as super

    method result =
      let bs = cipher#blocksize in
      let subkey =
        if used <> bs then begin
          Padding._8000#pad buffer used;
          k2
        end else
          k1 in
      xor_bytes iv 0 buffer 0 bs;
      xor_bytes subkey 0 buffer 0 bs;
      cipher#transform buffer 0 iv 0;
      (* Leave the buffer empty, as the parent [mac#result] also does *)
      used <- 0;
      Bytes.to_string iv

    method wipe =
      super#wipe;
      List.iter wipe_bytes [k1; k2]
  end
end

(* Stream ciphers *)

module Stream = struct

(* A stream cipher: [transform src src_ofs dst dst_ofs len] processes
   [len] bytes of [src] starting at [src_ofs], writing the result to [dst]
   starting at [dst_ofs]; [wipe] erases the cipher's internal state. *)
class type stream_cipher =
  object
    method wipe : unit
    method transform : bytes -> int -> bytes -> int -> int -> unit
  end

(* The ARCfour (RC4) stream cipher.  [key] must be 1 to 256 bytes long,
   otherwise [Error Wrong_key_size] is raised. *)
class arcfour key =
  object
    val ckey =
      let klen = String.length key in
      if klen < 1 || klen > 256 then raise (Error Wrong_key_size)
      else arcfour_cook_key key
    method transform src src_ofs dst dst_ofs len =
      (* [ofs, ofs + len) must lie inside [buf] *)
      let in_bounds buf ofs = ofs >= 0 && ofs <= Bytes.length buf - len in
      if not (len >= 0 && in_bounds src src_ofs && in_bounds dst dst_ofs)
      then invalid_arg "arcfour#transform";
      arcfour_transform ckey src src_ofs dst dst_ofs len
    method wipe = wipe_bytes ckey
  end

(* The Chacha20 stream cipher.
   [key] must be 16 or 32 bytes long, else [Error Wrong_key_size].
   [iv] (the nonce) defaults to 8 zero bytes.  It may be 8 bytes long
   (no bound is placed on [ctr]) or 12 bytes long, as in RFC 8439, where
   the block counter is only 32 bits wide and the initial counter [ctr]
   must therefore be below 2^32.  Any other [iv] raises
   [Error Wrong_IV_size]. *)
class chacha20 ?iv ?(ctr = 0L) key =
  object
    val ckey =
      if not (String.length key = 16 || String.length key = 32)
      then raise (Error Wrong_key_size);
      let iv =
        match iv with
        | None -> Bytes.make 8 '\000'
        | Some s ->
            (* 12-byte nonce => 32-bit counter: ctr must be < 2^32.
               (Previously written as 0x1_000_000L = 2^24, which rejected
               valid initial counters.) *)
            if String.length s = 8
            || String.length s = 12 && ctr < 0x1_0000_0000L
            then Bytes.of_string s
            else raise (Error Wrong_IV_size) in
      chacha20_cook_key key iv ctr
    method transform src src_ofs dst dst_ofs len =
      if len < 0
      || src_ofs < 0 || src_ofs > Bytes.length src - len
      || dst_ofs < 0 || dst_ofs > Bytes.length dst - len
      then invalid_arg "chacha20#transform";
      chacha20_transform ckey src src_ofs dst dst_ofs len
    method wipe =
      wipe_bytes ckey
  end

(* Wrapping of a stream cipher as a cipher *)

(* Wrap a stream cipher as a [transform].  Every byte put is passed
   through the stream cipher immediately into the output buffer, so both
   block sizes are 1 and [flush]/[finish] have nothing left to do. *)
class cipher (cipher : stream_cipher) =
  object (self)
    (* one-byte scratch buffer used by [put_char] *)
    val charbuf = Bytes.create 1

    inherit buffered_output 256 as output_buffer
    method input_block_size = 1
    method output_block_size = 1

    method put_substring src ofs len =
      self#ensure_capacity len;
      cipher#transform src ofs obuf oend len;
      oend <- oend + len

    method put_string s =
      let b = Bytes.unsafe_of_string s in
      self#put_substring b 0 (Bytes.length b)

    method put_char c =
      Bytes.set charbuf 0 c;
      self#ensure_capacity 1;
      cipher#transform charbuf 0 obuf oend 1;
      oend <- oend + 1

    method put_byte b = self#put_char (Char.unsafe_chr b)

    method flush = ()
    method finish = ()

    method wipe =
      cipher#wipe;
      output_buffer#wipe;
      wipe_bytes charbuf
  end

end

(* Hash functions *)

module Hash = struct

(* SHA-1, 20-byte digest. *)
class sha1 =
  object (self)
    val context = sha1_init ()
    method hash_size = 20
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "sha1#add_substring";
      sha1_update context buf pos n
    method add_string s =
      sha1_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = sha1_final context
    method wipe = wipe_bytes context
  end

let sha1 () = new sha1

(* SHA-224: 224-bit = 28-byte digest.  It shares the SHA-256 update
   function and only differs in its init and final steps.
   ([hash_size] was previously 24, which is wrong for SHA-224.) *)
class sha224 =
  object(self)
    val context = sha224_init()
    method hash_size = 28
    method add_substring src ofs len =
      if ofs < 0 || len < 0 || ofs > Bytes.length src - len
      then invalid_arg "sha224#add_substring";
      sha256_update context src ofs len
    method add_string src =
      sha256_update context (Bytes.unsafe_of_string src) 0 (String.length src)
    method add_char c =
      self#add_string (String.make 1 c)
    method add_byte b =
      self#add_char (Char.unsafe_chr b)
    method result =
      sha224_final context
    method wipe =
      wipe_bytes context
  end

let sha224 () = new sha224

(* SHA-256, 32-byte digest. *)
class sha256 =
  object (self)
    val context = sha256_init ()
    method hash_size = 32
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "sha256#add_substring";
      sha256_update context buf pos n
    method add_string s =
      sha256_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = sha256_final context
    method wipe = wipe_bytes context
  end

let sha256 () = new sha256

(* SHA-384, 48-byte digest; uses the SHA-512 update function. *)
class sha384 =
  object (self)
    val context = sha384_init ()
    method hash_size = 48
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "sha384#add_substring";
      sha512_update context buf pos n
    method add_string s =
      sha512_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = sha384_final context
    method wipe = wipe_bytes context
  end

let sha384 () = new sha384

(* SHA-512, 64-byte digest. *)
class sha512 =
  object (self)
    val context = sha512_init ()
    method hash_size = 64
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "sha512#add_substring";
      sha512_update context buf pos n
    method add_string s =
      sha512_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = sha512_final context
    method wipe = wipe_bytes context
  end

let sha512 () = new sha512

(* SHA-512/256, 32-byte digest; uses the SHA-512 update function. *)
class sha512_256 =
  object (self)
    val context = sha512_256_init ()
    method hash_size = 32
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "sha512_256#add_substring";
      sha512_update context buf pos n
    method add_string s =
      sha512_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = sha512_256_final context
    method wipe = wipe_bytes context
  end

let sha512_256 () = new sha512_256

(* SHA-512/224, 28-byte digest; uses the SHA-512 update function. *)
class sha512_224 =
  object (self)
    val context = sha512_224_init ()
    method hash_size = 28
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "sha512_224#add_substring";
      sha512_update context buf pos n
    method add_string s =
      sha512_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = sha512_224_final context
    method wipe = wipe_bytes context
  end

let sha512_224 () = new sha512_224

(* [sha2 sz] creates a SHA-2 hash with an [sz]-bit digest.
   [sz] must be 224, 256, 384 or 512; otherwise [Error Wrong_key_size]. *)
let sha2 sz =
  if sz = 224 then new sha224
  else if sz = 256 then new sha256
  else if sz = 384 then new sha384
  else if sz = 512 then new sha512
  else raise (Error Wrong_key_size)

(* SHA-3 ([official = true]) or Keccak ([official = false]) with an
   [sz]-bit digest; [sz] must be 224, 256, 384 or 512, otherwise
   [Error Wrong_key_size].  Besides the error message, the flag is only
   passed to [sha3_extract]. *)
class sha3 sz official =
  object (self)
    val context =
      match sz with
      | 224 | 256 | 384 | 512 -> sha3_init sz
      | _ -> raise (Error Wrong_key_size)
    method hash_size = sz / 8
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n) then begin
        let name = if official then "sha3" else "keccak" in
        invalid_arg (name ^ "#add_substring")
      end;
      sha3_absorb context buf pos n
    method add_string s =
      sha3_absorb context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = sha3_extract official context
    method wipe = sha3_wipe context
  end

let sha3 sz = new sha3 sz true

let keccak sz = new sha3 sz false

(* RIPEMD-160: 160-bit = 20-byte digest.
   ([hash_size] was previously 32, which is wrong for RIPEMD-160.) *)
class ripemd160 =
  object(self)
    val context = ripemd160_init()
    method hash_size = 20
    method add_substring src ofs len =
      if ofs < 0 || len < 0 || ofs > Bytes.length src - len
      then invalid_arg "ripemd160#add_substring";
      ripemd160_update context src ofs len
    method add_string src =
      ripemd160_update context (Bytes.unsafe_of_string src) 0 (String.length src)
    method add_char c =
      self#add_string (String.make 1 c)
    method add_byte b =
      self#add_char (Char.unsafe_chr b)
    method result =
      ripemd160_final context
    method wipe =
      wipe_bytes context
  end

let ripemd160 () = new ripemd160

(* MD5, 16-byte digest. *)
class md5 =
  object (self)
    val context = md5_init ()
    method hash_size = 16
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "md5#add_substring";
      md5_update context buf pos n
    method add_string s =
      md5_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = md5_final context
    method wipe = wipe_bytes context
  end

let md5 () = new md5

(* BLAKE2b with an [sz]-bit digest (a multiple of 8, from 8 to 512) and an
   optional [key] of at most 64 bytes ("" for unkeyed hashing).
   Invalid parameters raise [Error Wrong_key_size]. *)
class blake2b sz key =
  object (self)
    val context =
      let size_ok = sz >= 8 && sz <= 512 && sz mod 8 = 0 in
      if size_ok && String.length key <= 64
      then blake2b_init (sz / 8) key
      else raise (Error Wrong_key_size)
    method hash_size = sz / 8
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "blake2b#add_substring";
      blake2b_update context buf pos n
    method add_string s =
      blake2b_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = blake2b_final context (sz / 8)
    method wipe = wipe_bytes context
  end

let blake2b sz = new blake2b sz ""
let blake2b512 () = blake2b 512

(* BLAKE2s with an [sz]-bit digest (a multiple of 8, from 8 to 256) and an
   optional [key] of at most 32 bytes ("" for unkeyed hashing).
   Invalid parameters raise [Error Wrong_key_size]. *)
class blake2s sz key =
  object (self)
    val context =
      let size_ok = sz >= 8 && sz <= 256 && sz mod 8 = 0 in
      if size_ok && String.length key <= 32
      then blake2s_init (sz / 8) key
      else raise (Error Wrong_key_size)
    method hash_size = sz / 8
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "blake2s#add_substring";
      blake2s_update context buf pos n
    method add_string s =
      blake2s_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = blake2s_final context (sz / 8)
    method wipe = wipe_bytes context
  end

let blake2s sz = new blake2s sz ""
let blake2s256 () = blake2s 256

(* BLAKE3 with an [sz]-bit output (positive, multiple of 8) and a [key]
   that is either empty (plain hashing) or exactly 32 bytes.
   Invalid parameters raise [Error Wrong_key_size].
   Note the (key, size) parameter order. *)
class blake3 key sz =
  object (self)
    val context =
      let klen = String.length key in
      if sz > 0 && sz mod 8 = 0 && (klen = 0 || klen = 32)
      then blake3_init key
      else raise (Error Wrong_key_size)
    method hash_size = sz / 8
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "blake3#add_substring";
      blake3_update context buf pos n
    method add_string s =
      blake3_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = blake3_final context (sz / 8)
    method wipe = blake3_wipe context
  end

let blake3 sz = new blake3 "" sz
let blake3_256 () = blake3 256

end

(* High-level entry points for ciphers *)

module Cipher = struct

(* Direction of a cipher; a re-export of the top-level [dir] type. *)
type direction = dir = Encrypt | Decrypt

(* Block chaining modes accepted by [make_block_cipher] (default: CBC)
   and by the per-algorithm entry points below. *)
type chaining_mode =
    ECB           (* each block processed independently *)
  | CBC           (* cipher block chaining *)
  | CFB of int    (* cipher feedback; the int is passed to Block.cfb_* *)
  | OFB of int    (* output feedback; the int is passed to Block.ofb *)
  | CTR           (* counter mode *)
  | CTR_N of int  (* counter mode; the int is passed as ~inc to Block.ctr *)

(* Build a [transform] from a raw block cipher.
   1. Wrap it in the chaining [mode] (CBC by default), picking the
      encrypt or decrypt variant from [dir] where the mode needs one.
   2. Without [pad], wrap the result in [Block.cipher].  With [pad],
      use the padded encryption or decryption wrapper instead. *)
let make_block_cipher ?(mode = CBC) ?pad ?iv dir block_cipher =
  let chained_cipher =
    match mode with
    | ECB -> block_cipher
    | CBC ->
        (match dir with
         | Encrypt -> new Block.cbc_encrypt ?iv block_cipher
         | Decrypt -> new Block.cbc_decrypt ?iv block_cipher)
    | CFB n ->
        (match dir with
         | Encrypt -> new Block.cfb_encrypt ?iv n block_cipher
         | Decrypt -> new Block.cfb_decrypt ?iv n block_cipher)
    | OFB n -> new Block.ofb ?iv n block_cipher
    | CTR -> new Block.ctr ?iv block_cipher
    | CTR_N n -> new Block.ctr ?iv ~inc:n block_cipher in
  match pad, dir with
  | None, _ -> new Block.cipher chained_cipher
  | Some p, Encrypt -> new Block.cipher_padded_encrypt p chained_cipher
  | Some p, Decrypt -> new Block.cipher_padded_decrypt p chained_cipher

(* Feedback and counter modes (CFB, OFB, CTR, CTR_N) always use the
   underlying block cipher in the encryption direction, even to decrypt.
   ECB, CBC and the default mode keep [dir]. *)
let normalize_dir mode dir =
  match mode with
  | None | Some (ECB | CBC) -> dir
  | Some (CFB _ | OFB _ | CTR | CTR_N _) -> Encrypt

(* High-level block cipher constructors: build the raw cipher in the
   direction chosen by [normalize_dir], then chain/pad it. *)
let aes ?mode ?pad ?iv key dir =
  let raw =
    match normalize_dir mode dir with
    | Encrypt -> new Block.aes_encrypt key
    | Decrypt -> new Block.aes_decrypt key in
  make_block_cipher ?mode ?pad ?iv dir raw

let blowfish ?mode ?pad ?iv key dir =
  let raw =
    match normalize_dir mode dir with
    | Encrypt -> new Block.blowfish_encrypt key
    | Decrypt -> new Block.blowfish_decrypt key in
  make_block_cipher ?mode ?pad ?iv dir raw

let des ?mode ?pad ?iv key dir =
  let raw = new Block.des (normalize_dir mode dir) key in
  make_block_cipher ?mode ?pad ?iv dir raw

let triple_des ?mode ?pad ?iv key dir =
  let raw =
    match normalize_dir mode dir with
    | Encrypt -> new Block.triple_des_encrypt key
    | Decrypt -> new Block.triple_des_decrypt key in
  make_block_cipher ?mode ?pad ?iv dir raw

(* Stream cipher constructors.  The direction argument is ignored: the
   same keystream transform is used for encryption and decryption. *)
let arcfour key _dir = new Stream.cipher (new Stream.arcfour key)

let chacha20 ?iv ?ctr key _dir =
  let raw = new Stream.chacha20 ?iv ?ctr key in
  new Stream.cipher raw

end

(* The hmac construction *)

(* HMAC over the hash class [H.h], whose input block size is
   [H.blocksize] bytes:
     HMAC(K, m) = H((K' xor opad) || H((K' xor ipad) || m))
   where K' is K (hashed first if longer than the block size), extended
   with zero bytes to the block size, ipad = 0x36 bytes, opad = 0x5C bytes. *)
module HMAC(H: sig class h: hash  val blocksize: int end) =
  struct
    (* [hmac_pad key byte] is K' xor-ed byte-wise with [byte], as a fresh
       [H.blocksize]-byte buffer. *)
    let hmac_pad key byte =
      let k =
        if String.length key <= H.blocksize then key
        else hash_string (new H.h) key in
      let padded = Bytes.make H.blocksize (Char.chr byte) in
      xor_string k 0 padded 0 (String.length k);
      padded

    class hmac key =
      object (self)
        inherit H.h as super
        (* The inner hash starts with K' xor ipad *)
        initializer begin
          let ipad = hmac_pad key 0x36 in
          self#add_substring ipad 0 (Bytes.length ipad);
          wipe_bytes ipad
        end
        (* Outer hash over (K' xor opad) || inner digest *)
        method result =
          let outer = new H.h in
          let opad = hmac_pad key 0x5C in
          outer#add_substring opad 0 (Bytes.length opad);
          wipe_bytes opad;
          outer#add_string super#result;
          let tag = outer#result in
          outer#wipe;
          tag
      end
  end

(* High-level entry points for MACs *)

module MAC = struct

(* HMAC instances; [blocksize] is the input block size, in bytes, of the
   underlying hash function. *)
module HMAC_SHA1 =
  HMAC(struct class h = Hash.sha1  let blocksize = 64 end)
module HMAC_SHA256 =
  HMAC(struct class h = Hash.sha256  let blocksize = 64 end)
module HMAC_SHA384 =
  HMAC(struct class h = Hash.sha384  let blocksize = 128 end)
module HMAC_SHA512 =
  HMAC(struct class h = Hash.sha512  let blocksize = 128 end)
module HMAC_RIPEMD160 =
  HMAC(struct class h = Hash.ripemd160  let blocksize = 64 end)
module HMAC_MD5 =
  HMAC(struct class h = Hash.md5  let blocksize = 64 end)

(* [hmac_xxx key] starts a fresh HMAC computation keyed with [key]. *)
let hmac_sha1 = new HMAC_SHA1.hmac
let hmac_sha256 = new HMAC_SHA256.hmac
let hmac_sha384 = new HMAC_SHA384.hmac
let hmac_sha512 = new HMAC_SHA512.hmac
let hmac_ripemd160 = new HMAC_RIPEMD160.hmac
let hmac_md5 = new HMAC_MD5.hmac

(* Keyed BLAKE hashes used as MACs.  These take (size, key), whereas
   [Hash.blake3] takes (key, size). *)
let blake2b sz key = new Hash.blake2b sz key
let blake2b512 key = blake2b 512 key

let blake2s sz key = new Hash.blake2s sz key
let blake2s256 key = blake2s 256 key

let blake3 sz key = new Hash.blake3 key sz
let blake3_256 key = blake3 256 key

(* CBC-MACs over a block cipher used in the encryption direction. *)
let aes ?iv ?pad key =
  let c = new Block.aes_encrypt key in
  new Block.mac ?iv ?pad c
let des ?iv ?pad key =
  let c = new Block.des_encrypt key in
  new Block.mac ?iv ?pad c
let triple_des ?iv ?pad key =
  let c = new Block.triple_des_encrypt key in
  new Block.mac ?iv ?pad c
(* The MAC is computed in three steps:
   - CBC-MAC with single DES keyed by bytes 0-7 of [key];
   - DES-decrypt the final block with bytes 8-15;
   - DES-encrypt it again with bytes 16-23, or with bytes 0-7 when [key]
     is only 16 bytes long.
   Other key lengths raise [Error Wrong_key_size]. *)
let des_final_triple_des ?iv ?pad key =
  let kl = String.length key in
  if not (kl = 16 || kl = 24) then raise (Error Wrong_key_size);
  let subkey i = String.sub key (8 * i) 8 in
  let k1 = subkey 0 in
  let k2 = subkey 1 in
  let k3 = if kl = 24 then subkey 2 else k1 in
  let c1 = new Block.des_encrypt k1
  and c2 = new Block.des_decrypt k2
  and c3 = new Block.des_encrypt k3 in
  (* The raw subkey copies are not needed once the ciphers are built *)
  List.iter wipe_string [k1; k2; k3];
  new Block.mac_final_triple ?iv ?pad c1 c2 c3

(* AES-CMAC.  The subkeys are derived from L = AES(K, 0^128):
     K1 = L << 1,  xor-ed with Rb = 0x00...0087 if the top bit of L is set;
     K2 = K1 << 1, xor-ed with Rb if the top bit of K1 is set. *)
let aes_cmac ?iv key =
  let cipher = new Block.aes_encrypt key in
  let rb = Bytes.make 16 '\000' in
  let l = Bytes.create 16 in
  cipher#transform rb 0 l 0;            (* l = AES(K, 000...000) *)
  Bytes.set rb 15 '\x87';               (* rb = the Rb constant *)
  let top_bit_set buf = Char.code (Bytes.get buf 0) land 0x80 <> 0 in
  (* Doubling in GF(2^128): shift left by one bit, reduce by Rb *)
  let double src =
    let dst = Bytes.create 16 in
    shl1_bytes src 0 dst 0 16;
    if top_bit_set src then xor_bytes rb 0 dst 0 16;
    dst in
  let k1 = double l in
  let k2 = double k1 in
  wipe_bytes l;
  new Block.cmac ?iv cipher k1 k2

(* SipHash keyed with a 16-byte [key], producing a 64-bit ([sz = 64]) or
   128-bit ([sz = 128]) tag; anything else raises [Error Wrong_key_size]. *)
class siphash sz key =
  object (self)
    val context =
      if (sz = 64 || sz = 128) && String.length key = 16
      then siphash_init key (sz / 8)
      else raise (Error Wrong_key_size)
    method hash_size = sz / 8
    method add_substring buf pos n =
      if not (pos >= 0 && n >= 0 && pos <= Bytes.length buf - n)
      then invalid_arg "siphash#add_substring";
      siphash_update context buf pos n
    method add_string s =
      siphash_update context (Bytes.unsafe_of_string s) 0 (String.length s)
    method add_char ch = self#add_string (String.make 1 ch)
    method add_byte code = self#add_char (Char.unsafe_chr code)
    method result = siphash_final context (sz / 8)
    method wipe = wipe_bytes context
  end

let siphash key = new siphash 64 key
let siphash128 key = new siphash 128 key

end

(* Authenticated encryption with associated data *)

module AEAD = struct

(* Direction of an AEAD transform; a re-export of the top-level [dir] type. *)
type direction = dir = Encrypt | Decrypt

(* AES-GCM *)

(* The H multiplier for GHASH is derived from the AES key by
   encrypting the all-zero block. *)

let ghash_multiplier (aes: Block.block_cipher) =
  let hkey = Bytes.make 16 '\000' in
  (* encrypt the zero block in place: hkey = AES(K, 0^128) *)
  aes#transform hkey 0 hkey 0;
  ghash_init hkey

(* Add a block to the rolling MAC.  len must be between 0 and 16.
   If less than 16, we logically pad with zeros at the end,
   i.e. we "xor with zero" (= keep unchanged) the MAC bytes
   between len and 16. *)

(* [ghash_block h mac buf ofs len]: xor bytes [ofs, ofs+len) of [buf] into
   the first [len] bytes of [mac], then apply [ghash_mult h] to [mac]. *)
let ghash_block h mac buf ofs len =
  xor_bytes buf ofs mac 0 len;
  ghash_mult h mac

(* Same as [ghash_block], reading the block from a string. *)
let ghash_block_s h mac str ofs len =
  xor_string str ofs mac 0 len;
  ghash_mult h mac

(* Hash the given string, with zero padding.  Used for the non-encrypted
   authenticated data and for counter generation. *)

(* Returns a fresh 16-byte MAC: [msg] hashed block by block, with the
   final partial block (if any) implicitly zero-padded. *)
let ghash_string h msg =
  let mac = Bytes.make 16 '\000' in
  let total = String.length msg in
  let rec loop pos =
    let remaining = total - pos in
    if remaining >= 16 then begin
      ghash_block_s h mac msg pos 16;
      loop (pos + 16)
    end else if remaining > 0 then
      ghash_block_s h mac msg pos remaining in
  loop 0;
  mac

(* Produce the final authentication tag *)

(* Hash the block holding the header and ciphertext lengths (in bits,
   big-endian 64-bit each), then xor the MAC with [e0], the encryption
   of the initial counter.  Returns the 16-byte tag. *)
let ghash_final h mac headerlen cipherlen e0 =
  let lens = Bytes.create 16 in
  let bits n = Int64.mul n 8L in
  Bytes.set_int64_be lens 0 (bits headerlen);
  Bytes.set_int64_be lens 8 (bits cipherlen);
  ghash_block h mac lens 0 16;
  let tag = Bytes.sub mac 0 16 in
  xor_bytes e0 0 tag 0 16;
  Bytes.to_string tag

(* Initial value of the counter *)

(* A 12-byte IV is used directly as IV || 0x00000001.  Any other length is
   GHASHed, together with a final block holding its length in bits. *)
let counter0 h iv =
  match String.length iv with
  | 12 -> Bytes.of_string (iv ^ "\000\000\000\001")
  | ivlen ->
      let j0 = ghash_string h iv in
      let lenblock = Bytes.make 16 '\000' in
      Bytes.set_int64_be lenblock 8 (Int64.mul (Int64.of_int ivlen) 8L);
      ghash_block h j0 lenblock 0 16;
      j0

(* Encryption of the initial counter *)

(* Returns a fresh 16-byte block E(K, counter0); [counter0] is unchanged. *)
let enc_initial_counter (aes: Block.block_cipher) counter0 =
  let e0 = Bytes.create 16 in
  aes#transform counter0 0 e0 0;
  e0

(* CTR encryption / decryption *)

(* Process [len] (at most 16) bytes.  The steps are:
   - increment the counter in bytes 12..15 of [ctr];
   - encrypt [ctr] into the scratch block [keystream];
   - xor [len] bytes of [src] at [soff] into that keystream;
   - copy the first [len] bytes of the result to [dst] at [doff]. *)
let ctr_enc_dec (aes: Block.block_cipher) ctr keystream src soff dst doff len =
  Block.increment_counter ctr 12 15;
  aes#transform ctr 0 keystream 0;
  xor_bytes src soff keystream 0 len;
  Bytes.blit keystream 0 dst doff len

(* AES-GCM encryption.
   [header] is the additional authenticated data (authenticated but not
   encrypted), [iv] the nonce (see [counter0]), [key] the AES key.
   Data is encrypted in CTR mode and the resulting ciphertext is
   authenticated with GHASH.  [finish_and_get_tag] processes the
   remaining buffered bytes and returns the 16-byte tag. *)
class aes_gcm_encrypt ?(header = "") ~iv key =
  (* The AES block cipher *)
  let aes = new Block.aes_encrypt key in
  (* The multiplier for the GHASH MAC *)
  let h = ghash_multiplier aes in
  (* The counter for use in CTR mode. *)
  let ctr = counter0 h iv in
  (* The encryption of the initial counter, to be used for the final MAC.
     Computed before any increment of [ctr]. *)
  let e0 = enc_initial_counter aes ctr in
  (* The current MAC, initialized with the header
     (the non-encrypted authenticated data) *)
  let mac = ghash_string h header in
  (* Lengths of the authenticated data and the encrypted data *)
  let headerlen = Int64.of_int (String.length header)
  and cipherlen = ref 0L in
  (* A wrapper around the block cipher that
     - performs encryption in CTR mode
     - updates the MAC with the ciphertext just produced
     - updates the length of encrypted data, rejecting more than
       0xfffffffe0 = 2^36 - 32 bytes (the GCM plaintext limit) *)
  let enc_wrapped : Block.block_cipher =
    let buf = Bytes.create 16 in
    object
      method blocksize = 16
      method wipe = aes#wipe
      method transform src soff dst doff =
        ctr_enc_dec aes ctr buf src soff dst doff 16;
        ghash_block h mac dst doff 16;
        cipherlen := Int64.(add !cipherlen 16L);
        if !cipherlen > 0xfffffffe0L then raise (Error Message_too_long)
    end in
  object(self)
    inherit (Block.cipher enc_wrapped)
    method input_block_size = 1
    method output_block_size = 1
    method tag_size = 16
    method finish_and_get_tag =
      (* [used] bytes are still buffered in [ibuf]: handle them with a
         partial keystream block *)
      if used > 0 then begin
        let buf = Bytes.create 16 in
        (* Encrypt final block *)
        self#ensure_capacity used;
        ctr_enc_dec aes ctr buf ibuf 0 obuf oend used;
        (* Hash final block padded with zeros *)
        ghash_block h mac obuf oend used;
        oend <- oend + used;
        cipherlen := Int64.(add !cipherlen (of_int used));
        if !cipherlen > 0xfffffffe0L then raise (Error Message_too_long)
      end;
      (* Produce authentication tag *)
      ghash_final h mac headerlen !cipherlen e0
  end

(* AES-GCM decryption.
   [header] is the additional authenticated data, [iv] the nonce,
   [key] the AES key.  GHASH is computed over the incoming ciphertext
   before it is decrypted in CTR mode.  [finish_and_get_tag] returns the
   computed 16-byte tag; comparing it with the transmitted tag is left to
   the caller.
   The GCM limit of 2^36 - 32 bytes of encrypted data is enforced here,
   as on the encryption side.  Beyond it the 32-bit counter would wrap,
   and no valid ciphertext can be that long anyway. *)
class aes_gcm_decrypt ?(header = "") ~iv key =
  (* The AES block cipher (CTR mode only ever encrypts the counter) *)
  let aes = new Block.aes_encrypt key in
  (* The multiplier for the GHASH MAC *)
  let h = ghash_multiplier aes in
  (* The counter for use in CTR mode. *)
  let ctr = counter0 h iv in
  (* The encryption of the initial counter, to be used for the final MAC *)
  let e0 = enc_initial_counter aes ctr in
  (* The current MAC, initialized with the header
     (the non-encrypted authenticated data) *)
  let mac = ghash_string h header in
  (* Lengths of the authenticated data and the encrypted data *)
  let headerlen = Int64.of_int (String.length header)
  and cipherlen = ref 0L in
  (* Maximal amount of encrypted data: 2^36 - 32 bytes *)
  let max_cipherlen = 0xfffffffe0L in
  (* A wrapper around the block cipher that
     - updates the MAC with the incoming ciphertext
     - performs decryption in CTR mode
     - updates (and bounds) the length of encrypted data *)
  let dec_wrapped : Block.block_cipher =
    let buf = Bytes.create 16 in
    object
      method blocksize = 16
      method wipe = aes#wipe
      method transform src soff dst doff =
        ghash_block h mac src soff 16;
        ctr_enc_dec aes ctr buf src soff dst doff 16;
        cipherlen := Int64.(add !cipherlen 16L);
        if !cipherlen > max_cipherlen then raise (Error Message_too_long)
    end in
  object(self)
    inherit (Block.cipher dec_wrapped)
    method input_block_size = 1
    method output_block_size = 1
    method tag_size = 16
    method finish_and_get_tag =
      if used > 0 then begin
        let buf = Bytes.create 16 in
        (* Hash final block padded with zeros *)
        ghash_block h mac ibuf 0 used;
        (* Decrypt final block *)
        self#ensure_capacity used;
        ctr_enc_dec aes ctr buf ibuf 0 obuf oend used;
        oend <- oend + used;
        cipherlen := Int64.(add !cipherlen (of_int used));
        if !cipherlen > max_cipherlen then raise (Error Message_too_long)
      end;
      (* Produce authentication tag *)
      ghash_final h mac headerlen !cipherlen e0
  end

(* AES-GCM in the requested direction, as an [authenticated_transform]. *)
let aes_gcm ?header ~iv key dir : authenticated_transform =
  match dir with
  | Decrypt -> (new aes_gcm_decrypt ?header ~iv key :> authenticated_transform)
  | Encrypt -> (new aes_gcm_encrypt ?header ~iv key :> authenticated_transform)

(* Chacha20-Poly1305 *)

(* Feed zero bytes to [h] so that a stream whose length is congruent to
   [n] modulo 16 ends on a 16-byte boundary. *)
let poly1305_update_pad h n =
  match (16 - n) land 15 with
  | 0 -> ()
  | padlen -> poly1305_update h (Bytes.make padlen '\000') 0 padlen

(* Derive the Poly1305 key from the first 64-byte keystream block of
   [cha] (advancing it to block 1), then absorb the header, zero-padded
   to a multiple of 16 bytes. *)
let poly1305_init_hash cha header =
  let block0 = Bytes.make 64 '\000' in
  cha#transform block0 0 block0 0 64;
  let h = poly1305_init block0 in   (* only the first 32 bytes are used *)
  wipe_bytes block0;
  let hlen = String.length header in
  poly1305_update h (Bytes.unsafe_of_string header) 0 hlen;
  poly1305_update_pad h (hlen land 0xF);
  h

(* Finish the Poly1305 computation and return the tag.  The steps are:
   - pad the ciphertext to a multiple of 16 bytes;
   - absorb both lengths as 64-bit little-endian numbers;
   - return the final hash. *)
let poly1305_finish_and_get_tag h headerlen cipherlen =
  let tail = Int64.(to_int (logand cipherlen 15L)) in
  poly1305_update_pad h tail;
  let lengths = Bytes.create 16 in
  Bytes.set_int64_le lengths 0 headerlen;
  Bytes.set_int64_le lengths 8 cipherlen;
  poly1305_update h lengths 0 16;
  poly1305_final h

(* ChaCha20-Poly1305 encryption.
   [header] is the additional authenticated data, [iv] the nonce (as
   accepted by [Stream.chacha20]), [key] the key.  Keystream block 0
   provides the Poly1305 key (see [poly1305_init_hash]).  The ciphertext
   is hashed as it is produced, and [finish_and_get_tag] returns the
   16-byte tag. *)
class chapoly_encrypt ?(header = "") ~iv key =
  (* The Chacha20 stream cipher *)
  let cha = new Stream.chacha20 ~iv key in
  (* The Poly1305 hash *)
  let h = poly1305_init_hash cha header in
  (* Lengths of the authenticated data and the encrypted data *)
  let headerlen = Int64.of_int (String.length header)
  and cipherlen = ref 0L in
  (* Maximum length for encrypted data.  With a 12-byte nonce the block
     counter is 32 bits wide and block 0 was consumed for the Poly1305 key,
     so only 2^32 - 1 keystream blocks of 64 bytes remain, that is
     2^38 - 64 bytes (RFC 8439, section 2.8).  The former bound 2^38 let
     the last 64 bytes run past the counter limit. *)
  let maxlen =
    if String.length iv = 12 then 0x3F_FFFF_FFC0L else Int64.max_int in
  (* The stream cipher that wraps Chacha20 with hash updates *)
  let enc = object
    method transform src soff dst doff len =
      cha#transform src soff dst doff len;
      poly1305_update h dst doff len;
      cipherlen := Int64.(add !cipherlen (of_int len));
      if !cipherlen > maxlen then raise (Error Message_too_long)
    method wipe =
      cha#wipe; wipe_bytes h
  end in
  object(self)
    inherit (Stream.cipher enc)
    method input_block_size = 1
    method output_block_size = 1
    method tag_size = 16
    method finish_and_get_tag =
      poly1305_finish_and_get_tag h headerlen !cipherlen
  end

(* ChaCha20-Poly1305 decryption.
   The incoming ciphertext is hashed before being decrypted.
   [finish_and_get_tag] returns the computed 16-byte tag; comparing it
   with the transmitted one is left to the caller. *)
class chapoly_decrypt ?(header = "") ~iv key =
  (* The Chacha20 stream cipher *)
  let cha = new Stream.chacha20 ~iv key in
  (* The Poly1305 hash *)
  let h = poly1305_init_hash cha header in
  (* Lengths of the authenticated data and the encrypted data *)
  let headerlen = Int64.of_int (String.length header)
  and cipherlen = ref 0L in
  (* Maximum length for encrypted data.  With a 12-byte nonce only
     2^32 - 1 keystream blocks of 64 bytes follow the Poly1305 key block:
     2^38 - 64 bytes (RFC 8439, section 2.8).  Longer input cannot be a
     valid ciphertext and would run the 32-bit counter past its limit. *)
  let maxlen =
    if String.length iv = 12 then 0x3F_FFFF_FFC0L else Int64.max_int in
  (* The stream cipher that wraps Chacha20 with hash updates *)
  let enc = object
    method transform src soff dst doff len =
      poly1305_update h src soff len;
      cha#transform src soff dst doff len;
      cipherlen := Int64.(add !cipherlen (of_int len));
      if !cipherlen > maxlen then raise (Error Message_too_long)
    method wipe =
      cha#wipe; wipe_bytes h
  end in
  object(self)
    inherit (Stream.cipher enc)
    method input_block_size = 1
    method output_block_size = 1
    method tag_size = 16
    method finish_and_get_tag =
      poly1305_finish_and_get_tag h headerlen !cipherlen
  end

(* ChaCha20-Poly1305 in the requested direction, as an
   [authenticated_transform]. *)
let chacha20_poly1305 ?header ~iv key dir : authenticated_transform =
  match dir with
  | Decrypt -> (new chapoly_decrypt ?header ~iv key :> authenticated_transform)
  | Encrypt -> (new chapoly_encrypt ?header ~iv key :> authenticated_transform)

end

(* Random number generation *)

module Random = struct

(* A random number generator: [random_bytes buf ofs len] fills
   [len] bytes of [buf] starting at [ofs]; [wipe] releases or erases
   the generator's state. *)
class type rng =
  object
    method wipe: unit
    method random_bytes: bytes -> int -> int -> unit
  end

(* [string rng len] returns [len] fresh bytes drawn from [rng]. *)
let string rng len =
  let buf = Bytes.create len in
  rng#random_bytes buf 0 len;
  Bytes.unsafe_to_string buf

(* Bindings to the C stubs for the operating system's RNG.
   [get_system_rng] appears to signal unavailability by raising
   [Not_found] (that is what [system_rng] catches);
   [system_rng_random_bytes] returns [false] on failure. *)
type system_rng_handle
external get_system_rng: unit -> system_rng_handle = "caml_get_system_rng"
external close_system_rng: system_rng_handle -> unit = "caml_close_system_rng"
external system_rng_random_bytes:
  system_rng_handle -> bytes -> int -> int -> bool
  = "caml_system_rng_random_bytes"

(* RNG backed by the operating system.  [wipe] closes the handle. *)
class system_rng =
  object (self)
    val h = get_system_rng ()
    method random_bytes buf ofs len =
      if not (ofs >= 0 && len >= 0 && ofs <= Bytes.length buf - len)
      then invalid_arg "random_bytes";
      if not (system_rng_random_bytes h buf ofs len)
      then raise (Error Entropy_source_closed)
    method wipe = close_system_rng h
  end

(* Raises [Error No_entropy_source] if no system RNG is available. *)
let system_rng () =
  match new system_rng with
  | rng -> rng
  | exception Not_found -> raise (Error No_entropy_source)

(* RNG reading from a device file such as /dev/random.  Short reads are
   retried until [len] bytes are obtained; end of file raises
   [Error Entropy_source_closed].  [wipe] closes the file descriptor. *)
class device_rng filename =
  object (self)
    val fd = Unix.openfile filename [Unix.O_RDONLY; Unix.O_CLOEXEC] 0
    method random_bytes buf ofs len =
      if len > 0 then begin
        let got = Unix.read fd buf ofs len in
        if got = 0 then raise (Error Entropy_source_closed)
        else if got < len then self#random_bytes buf (ofs + got) (len - got)
      end
    method wipe = Unix.close fd
  end

let device_rng filename = new device_rng filename

(* Bindings to the C stubs for a hardware RNG.
   [hardware_rng_random_bytes] returns [false] on failure. *)
external hardware_rng_available: unit -> bool = "caml_hardware_rng_available"
external hardware_rng_random_bytes: bytes -> int -> int -> bool = "caml_hardware_rng_random_bytes"

(* RNG backed by the hardware RNG stubs; it holds no state to wipe. *)
class hardware_rng =
  object
    method random_bytes buf ofs len =
      if not (ofs >= 0 && len >= 0 && ofs <= Bytes.length buf - len)
      then invalid_arg "hardware_rng#random_bytes";
      if not (hardware_rng_random_bytes buf ofs len)
      then raise (Error Entropy_source_closed)
    method wipe = ()
  end

(* Raises [Error No_entropy_source] when no hardware RNG is available. *)
let hardware_rng () =
  if not (hardware_rng_available ()) then raise (Error No_entropy_source);
  new hardware_rng

(* Placeholder RNG: every request raises [Error No_entropy_source]. *)
class no_rng =
  object
    method random_bytes (_buf : bytes) (_ofs : int) (_len : int) : unit =
      raise (Error No_entropy_source)
    method wipe = ()
  end

(* The best available entropy source, chosen once at module
   initialisation, tried in order:
   - the system RNG;
   - /dev/random;
   - the hardware RNG;
   - [no_rng], which fails on every request. *)
let secure_rng =
  match new system_rng with
  | rng -> rng
  | exception Not_found ->
      match new device_rng "/dev/random" with
      | rng -> rng
      | exception Unix.Unix_error (_, _, _) ->
          if hardware_rng_available ()
          then new hardware_rng
          else new no_rng

(* Deterministic PRNG: the Chacha20 keystream keyed by [seed], which must
   be at least 16 bytes long.  A seed longer than 32 bytes is truncated
   to 32, one of 17-31 bytes is zero-extended to 32, and one of exactly
   16 bytes is used as is.  The nonce is 8 zero bytes and the counter
   starts at 0. *)
class pseudo_rng seed =
  let () = if String.length seed < 16 then raise (Error Seed_too_short) in
  object
    val ckey =
      let l = String.length seed in
      let key =
        if l >= 32 then String.sub seed 0 32
        else if l > 16 then seed ^ String.make (32 - l) '\000'
        else seed in
      chacha20_cook_key key (Bytes.make 8 '\000') 0L
    method random_bytes buf ofs len =
      if ofs >= 0 && len >= 0 && ofs <= Bytes.length buf - len
      then chacha20_extract ckey buf ofs len
      else invalid_arg "pseudo_rng#random_bytes"
    method wipe =
      wipe_bytes ckey;
      wipe_string seed
  end

let pseudo_rng seed = new pseudo_rng seed

(* Deterministic PRNG: AES in counter mode, keyed with the first 16 bytes
   of [seed] (which must be at least 16 bytes long), with a 16-byte
   counter starting at zero.
   - [random_bytes] validates its arguments up front, like the other RNGs.
     Previously a bad [ofs] advanced the generator before [Bytes.blit]
     failed.
   - [wipe] also erases the AES key schedule and the counter, not just
     the output buffer and the seed. *)
class pseudo_rng_aes_ctr seed =
  let () = if String.length seed < 16 then raise (Error Seed_too_short) in
  object
    val cipher = new Block.aes_encrypt (String.sub seed 0 16)
    (* Counter block, incremented after each encryption *)
    val ctr = Bytes.make 16 '\000'
    (* Current keystream block, and how many of its bytes are consumed *)
    val obuf = Bytes.create 16
    val mutable opos = 16

    method random_bytes buf ofs len =
      if len < 0 || ofs < 0 || ofs > Bytes.length buf - len
      then invalid_arg "pseudo_rng_aes_ctr#random_bytes";
      let rec fill ofs len =
        if len > 0 then begin
          if opos >= 16 then begin
            (* Encrypt the counter *)
            cipher#transform ctr 0 obuf 0;
            (* Increment the counter *)
            Block.increment_counter ctr 0 15;
            (* We have 16 fresh bytes of pseudo-random data *)
            opos <- 0
          end;
          let r = min (16 - opos) len in
          Bytes.blit obuf opos buf ofs r;
          opos <- opos + r;
          fill (ofs + r) (len - r)
        end in
      fill ofs len

    method wipe =
      cipher#wipe;
      wipe_bytes ctr;
      wipe_bytes obuf;
      wipe_string seed
  end

let pseudo_rng_aes_ctr seed = new pseudo_rng_aes_ctr seed

end

(* Key derivation functions *)

module KD = struct

(* Big-endian 4-byte encoding of a 32-bit integer. *)
let int2bytes i =
  String.init 4 (fun k ->
      let shift = 8 * (3 - k) in
      Char.chr
        (Int32.to_int (Int32.logand (Int32.shift_right_logical i shift) 0xFFl)))

let derive fn len ctr =
  let rec deriv accu ctr l =
    if l <= 0 then
      String.sub (String.concat "" (List.rev accu)) 0 len
    else begin
      let s = fn ctr in
      deriv (s :: accu) (Int32.succ ctr) (l - String.length s)
    end
  in deriv [] ctr len

let kdf1 (h: unit -> hash) ?(otherinfo = "") secret len =
  let fn ctr =
    hash_string (h ()) (secret ^ int2bytes ctr ^ otherinfo) in
  derive fn len 0l

let kdf2 (h: unit -> hash) ?(otherinfo = "") secret len =
  let fn ctr =
    hash_string (h ()) (secret ^ int2bytes ctr ^ otherinfo) in
  derive fn len 1l

let kdf3 (h: unit -> hash) ?(otherinfo = "") secret len =
  let fn ctr =
    hash_string (h ()) (int2bytes ctr ^ secret ^ otherinfo) in
  derive fn len 0l

let pbkdf2 (keyed_hash: string -> hash) pwd salt count len =
  let prf s =
    hash_string (keyed_hash pwd) s in
  let rec iterate u r n =
    if n <= 0 then Bytes.to_string r else begin
      let u = prf u in
      xor_string u 0 r 0 (String.length u);
      iterate u r (n - 1)
    end in
  let fn ctr =
    let u = prf (salt ^ int2bytes ctr) in
    let r = Bytes.of_string u in
    iterate u r (count - 1) in
  derive fn len 1l

end


(* RSA operations *)

module RSA = struct

(* Public key.  [size] is the modulus size in bits; [n] (modulus) and [e]
   (public exponent) are byte strings decoded with [Bn.of_bytes]. *)
type public_key =
  { size: int;
    n: string;
    e: string
  }

(* Private key, same encoding.  [d] is the private exponent, [p] > [q] the
   prime factors, and [dp] = d mod (p-1), [dq] = d mod (q-1),
   [qinv] = q^-1 mod p the CRT components (as built by [new_key]). *)
type private_key =
  { size: int;
    n: string;
    d: string;
    p: string;
    q: string;
    dp: string;
    dq: string;
    qinv: string
  }

(* Overwrite every component of [k] with zero bytes. *)
let wipe_key (k: private_key) =
  wipe_string k.n;
  wipe_string k.d;
  wipe_string k.p;
  wipe_string k.q;
  wipe_string k.dp;
  wipe_string k.dq;
  wipe_string k.qinv

(* Raw (unpadded) RSA public operation: msg^e mod n, encoded on [key.size]
   bits.  Raises [Error Message_too_long] if msg >= n.  Every temporary
   bignum is wiped on both the success and the error path. *)
let encrypt (key: public_key) msg =
  let msg = Bn.of_bytes msg in
  let n = Bn.of_bytes key.n in
  if Z.compare msg n >= 0 then begin
    Bn.wipe msg; Bn.wipe n;
    raise (Error Message_too_long)
  end;
  let e = Bn.of_bytes key.e in
  let r = Z.powm_sec msg e n in
  let s = Bn.to_bytes ~numbits:key.size r in
  Bn.wipe msg; Bn.wipe n; Bn.wipe e; Bn.wipe r;
  s

let unwrap_signature = encrypt

(* Raw RSA private operation: msg^d mod n, encoded on [key.size] bits.
   The message is range-checked before the private exponent is decoded, so
   rejecting an oversized message leaves no unwiped copy of [d] behind.
   Raises [Error Message_too_long] if msg >= n. *)
let decrypt (key: private_key) msg =
  let msg = Bn.of_bytes msg in
  let n = Bn.of_bytes key.n in
  if Z.compare msg n >= 0 then begin
    Bn.wipe msg; Bn.wipe n;
    raise (Error Message_too_long)
  end;
  let d = Bn.of_bytes key.d in
  let r = Z.powm_sec msg d n in
  let s = Bn.to_bytes ~numbits:key.size r in
  Bn.wipe msg; Bn.wipe n; Bn.wipe d; Bn.wipe r;
  s

let sign = decrypt

(* Same result as [decrypt], computed with the CRT components via
   [Bn.mod_power_CRT].  As in [decrypt], the private components are only
   decoded once the message is known to be in range. *)
let decrypt_CRT (key: private_key) msg =
  let msg = Bn.of_bytes msg in
  let n = Bn.of_bytes key.n in
  if Z.compare msg n >= 0 then begin
    Bn.wipe msg; Bn.wipe n;
    raise (Error Message_too_long)
  end;
  let p = Bn.of_bytes key.p in
  let q = Bn.of_bytes key.q in
  let dp = Bn.of_bytes key.dp in
  let dq = Bn.of_bytes key.dq in
  let qinv = Bn.of_bytes key.qinv in
  let r = Bn.mod_power_CRT msg p q dp dq qinv in
  let s = Bn.to_bytes ~numbits:key.size r in
  Bn.wipe msg; Bn.wipe n; Bn.wipe p; Bn.wipe q;
  Bn.wipe dp; Bn.wipe dq; Bn.wipe qinv; Bn.wipe r;
  s

let sign_CRT = decrypt_CRT

(* Generate a [numbits]-bit key pair (numbits even and >= 32, else
   [Error Wrong_key_size]).  With [?e], the public exponent is fixed and the
   primes are chosen so that e is coprime with p-1 and q-1; otherwise a
   random exponent > 1 coprime with both is drawn. *)
let new_key ?(rng = Random.secure_rng) ?e numbits =
  if numbits < 32 || numbits land 1 > 0 then raise(Error Wrong_key_size);
  let numbits2 = numbits / 2 in
  (* Generate primes p, q with numbits / 2 digits.
     If fixed exponent e, make sure gcd(p-1,e) = 1 and
     gcd(q-1,e) = 1. *)
  let rec gen_factor nbits =
    let n = Bn.random_prime ~rng:(rng#random_bytes) nbits in
    match e with
      None -> n
    | Some e ->
        if Bn.relative_prime (Z.sub n Z.one) (Z.of_int e)
        then n
        else gen_factor nbits in
  (* Make sure p > q *)
  let rec gen_factors nbits =
    let p = gen_factor nbits
    and q = gen_factor nbits in
    let cmp = Z.compare p q in
    if cmp = 0 then gen_factors nbits else
    if cmp < 0 then (q, p) else (p, q) in
  let (p, q) = gen_factors numbits2 in
  (* p1 = p - 1 and q1 = q - 1 *)
  let p1 = Z.sub p Z.one
  and q1 = Z.sub q Z.one in
  (* If no fixed exponent specified, generate random exponent e such that
     gcd(p-1,e) = 1 and gcd(q-1,e) = 1.  e = 1 is coprime with everything
     but makes encryption the identity, so it is rejected explicitly. *)
  let e =
    match e with
      Some e -> Z.of_int e
    | None ->
        let rec gen_exponent () =
          let n = Bn.random ~rng:(rng#random_bytes) numbits in
          if Z.gt n Z.one
          && Bn.relative_prime n p1 && Bn.relative_prime n q1
          then n
          else gen_exponent () in
        gen_exponent () in
  (* n = pq *)
  let n = Z.mul p q in
  (* d = e^-1 mod (p-1)(q-1) *)
  let d = Z.invert e (Z.mul p1 q1) in
  (* dp = d mod p-1 and dq = d mod q-1 *)
  let dp = Z.erem d p1 and dq = Z.erem d q1 in
  (* qinv = q^-1 mod p *)
  let qinv = Z.invert q p in
  (* Build key *)
  let priv : private_key =
    { size = numbits;
      n = Bn.to_bytes ~numbits:numbits n;
      d = Bn.to_bytes ~numbits:numbits d;
      p = Bn.to_bytes ~numbits:numbits2 p;
      q = Bn.to_bytes ~numbits:numbits2 q;
      dp = Bn.to_bytes ~numbits:numbits2 dp;
      dq = Bn.to_bytes ~numbits:numbits2 dq;
      qinv = Bn.to_bytes ~numbits:numbits2 qinv }
  and pub : public_key =
    { size = numbits;
      n = Bn.to_bytes ~numbits:numbits n;
      e = Bn.to_bytes ~numbits:numbits e } in
  Bn.wipe n; Bn.wipe e; Bn.wipe d;
  Bn.wipe p; Bn.wipe q;
  Bn.wipe p1; Bn.wipe q1;
  Bn.wipe dp; Bn.wipe dq; Bn.wipe qinv;
  (priv, pub)

end

module Paillier = struct

  (* Public key: modulus [n] = pq, [n2] = n^2 and generator [g] as byte
     strings decoded with [Bn.of_bytes]; [size] is the bit size of n. *)
  type public_key =
  { size: int;
    n: string;
    n2: string;
    g: string
  }

  (* Private key: [lambda] = lcm(p-1, q-1) and [mu] = lambda^-1 mod n
     (as built by [new_key]), plus n, n^2 and the factors p > q. *)
  type private_key =
  { size: int;
    n: string;
    n2: string;
    p: string;
    q: string;
    lambda: string;
    mu: string
  }

  (* Overwrite every component of [k] with zero bytes. *)
  let wipe_key (k: private_key) =
    wipe_string k.n;
    wipe_string k.n2;
    wipe_string k.p;
    wipe_string k.q;
    wipe_string k.lambda;
    wipe_string k.mu

  (* c = g^m * r^n mod n^2 for a fresh random r with 0 < r < n and
     gcd(r, n) = 1.
     Raises [Error Message_too_long] if m >= n: decryption only recovers
     m mod n, so larger plaintexts would be silently truncated.  m = 0 is
     handled separately because [Z.powm_sec] (used by [mod_power]) rejects a
     zero exponent, while g^0 = 1.  The random r and the intermediate
     factors are wiped after use. *)
  let encrypt ?(rng = Random.secure_rng) (key: public_key) msg =
    let n = Bn.of_bytes key.n in
    let m = Bn.of_bytes msg in
    if Z.compare m n >= 0 then begin
      Bn.wipe m; Bn.wipe n;
      raise (Error Message_too_long)
    end;
    let rec get_r () =
      let r = Bn.random ~rng:(rng#random_bytes) (key.size-1) in
      if Z.lt r n && Bn.relative_prime r n
      then r
      else get_r () in
    let r = get_r () in
    let rb = Bn.to_bytes r in
    let gm =
      if Z.equal m Z.zero
      then Bn.to_bytes ~numbits:(String.length key.n2 * 8) Z.one
      else mod_power key.g msg key.n2 in
    let rn = mod_power rb key.n key.n2 in
    let c = mod_mult gm rn key.n2 in
    Bn.wipe r; Bn.wipe m; Bn.wipe n;
    wipe_string rb; wipe_string gm; wipe_string rn;
    c

  (* m = L(c^lambda mod n^2) * mu mod n, with L(x) = (x - 1) / n.
     Temporaries are wiped. *)
  let decrypt (key: private_key) c =
    let c = Bn.of_bytes c in
    let n = Bn.of_bytes key.n in
    let n2 = Bn.of_bytes key.n2 in
    let lambda = Bn.of_bytes key.lambda in
    let mu = Bn.of_bytes key.mu in
    let cn = Bn.mod_power c lambda n2 in
    let lx = Z.((cn - one) / n) in
    let m = Bn.mulm lx mu n in
    let msg = Bn.to_bytes m in
    Bn.wipe c; Bn.wipe n; Bn.wipe n2; Bn.wipe lambda;
    Bn.wipe mu; Bn.wipe cn; Bn.wipe lx; Bn.wipe m;
    msg

  (* Homomorphic addition: the product of two ciphertexts mod n^2 decrypts
     to the sum of the plaintexts mod n. *)
  let add (key: public_key) c1 c2 =
    mod_mult c1 c2 key.n2

  (* Generate a [numbits]-bit key pair (numbits even and >= 32, else
     [Error Wrong_key_size]) with g = n + 1. *)
  let new_key ?(rng = Random.secure_rng) numbits =
    if numbits < 32 || numbits land 1 > 0 then raise(Error Wrong_key_size);
    let numbits2 = numbits / 2 in
    (* Draw distinct primes, ordered so that p > q, until
       gcd(pq, (p-1)(q-1)) = 1. *)
    let rec gen_factors nbits =
      let p = Bn.random_prime ~rng:(rng#random_bytes) nbits
      and q = Bn.random_prime ~rng:(rng#random_bytes) nbits in
      let cmp = Z.compare p q in
      if cmp = 0 then gen_factors nbits else begin
        let (p, q) = if cmp < 0 then (q, p) else (p, q) in
        let p1 = Z.(sub p one)
        and q1 = Z.(sub q one) in
        if Bn.relative_prime Z.(p * q) Z.(p1 * q1)
        then (p, q, p1, q1)
        else gen_factors nbits
      end in
    let (p, q, p1, q1) = gen_factors numbits2 in
    (* n = pq *)
    let n = Z.mul p q in
    let n2 = Z.mul n n in
    let g = Z.(add n one) in
    let lambda = Z.lcm p1 q1 in
    let mu = Z.invert lambda n in

    (* Build key *)
    let priv : private_key =
      { size = numbits;
        n = Bn.to_bytes ~numbits:numbits n;
        n2 = Bn.to_bytes n2;
        p = Bn.to_bytes ~numbits:numbits2 p;
        q = Bn.to_bytes ~numbits:numbits2 q;
        lambda = Bn.to_bytes lambda;
        mu = Bn.to_bytes mu}
    and pub : public_key =
      { size = numbits;
        n = Bn.to_bytes ~numbits:numbits n;
        n2 = Bn.to_bytes n2;
        g = Bn.to_bytes g } in
    Bn.wipe n; Bn.wipe n2; Bn.wipe g;
    Bn.wipe p; Bn.wipe q; Bn.wipe p1; Bn.wipe q1;
    Bn.wipe lambda; Bn.wipe mu;
    (priv, pub)

  end

(* Diffie-Hellman key agreement *)

module DH = struct

(* Group parameters: modulus [p] and base [g] as byte strings decoded with
   [Bn.of_bytes], plus the bit length [privlen] of private secrets. *)
type parameters =
  { p: string;
    g: string;
    privlen: int }

(* Draw a prime p with [Bn.random_prime ... numbits] and a random base g > 1
   of at most numbits-1 bits.
   NOTE(review): g is not checked to generate a large prime-order subgroup.
   Raises [Error Wrong_key_size] if numbits < 32 or numbits <= privlen. *)
let new_parameters ?(rng = Random.secure_rng) ?(privlen = 160) numbits =
  if numbits < 32 || numbits <= privlen then raise(Error Wrong_key_size);
  let np = Bn.random_prime ~rng:(rng#random_bytes) numbits in
  let rec find_generator () =
    let g = Bn.random ~rng:(rng#random_bytes) (numbits - 1) in
    if Z.compare g Z.one <= 0 then find_generator() else g in
  let ng = find_generator () in
  { p = Bn.to_bytes ~numbits np;
    g = Bn.to_bytes ~numbits ng;
    privlen = privlen }

type private_secret = Bn.t

(* A random exponent of [params.privlen] bits. *)
let private_secret ?(rng = Random.secure_rng) params =
  Bn.random ~rng:(rng#random_bytes) params.privlen

(* g^privsec mod p, encoded on the byte length of p. *)
let message params privsec =
  Bn.to_bytes ~numbits:(String.length params.p * 8)
    (Z.powm_sec (Bn.of_bytes params.g) privsec (Bn.of_bytes params.p))

(* othermsg^privsec mod p, encoded on the byte length of p.  The private
   secret is wiped afterwards and must not be reused. *)
let shared_secret params privsec othermsg =
  let res =
    Bn.to_bytes ~numbits:(String.length params.p * 8)
      (Z.powm_sec (Bn.of_bytes othermsg) privsec (Bn.of_bytes params.p))
  in Bn.wipe privsec; res

(* Expand [sharedsec] into [numbytes] bytes: concatenation of
   SHA-256(diversification || sharedsec || decimal counter) for counters
   1, 2, ..., truncated to [numbytes].  The hash input is built in a single
   fresh buffer (no intermediate [^] copies of the secret) and wiped along
   with each digest. *)
let derive_key ?(diversification = "") sharedsec numbytes =
  let result = Bytes.create numbytes in
  let rec derive pos counter =
    if pos < numbytes then begin
      let input =
        String.concat "" [diversification; sharedsec; string_of_int counter] in
      let h = hash_string (Hash.sha256()) input in
      wipe_string input;
      String.blit h 0 result pos (min (String.length h) (numbytes - pos));
      wipe_string h;
      derive (pos + String.length h) (counter + 1)
    end in
  derive 0 1;
  Bytes.unsafe_to_string result

end

(* Base64 encoding *)

module Base64 = struct

let base64_conv_table =
  "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

(* Base64 encoder.
   [multiline]: break lines once they reach 72 output characters and
   terminate non-empty output with a newline.
   [padding]: pad a final partial group with '=' (always done when
   [multiline] is set).  [flush] is not supported and raises
   [Error Wrong_data_length]. *)
class encode multiline padding =
  object (self)
    method input_block_size = 1
    method output_block_size = 1

    inherit buffered_output 256 as output_buffer

    val ibuf = Bytes.create 3        (* pending input bytes *)
    val mutable ipos = 0             (* number of pending input bytes *)
    val mutable ocolumn = 0          (* characters on the current output line *)

    method put_char c =
      Bytes.set ibuf ipos c;
      ipos <- ipos + 1;
      if ipos = 3 then begin
        let b0 = Char.code (Bytes.get ibuf 0)
        and b1 = Char.code (Bytes.get ibuf 1)
        and b2 = Char.code (Bytes.get ibuf 2) in
        self#ensure_capacity 4;
        Bytes.set obuf oend     base64_conv_table.[b0 lsr 2];
        Bytes.set obuf (oend+1) base64_conv_table.[(b0 land 3) lsl 4 + (b1 lsr 4)];
        Bytes.set obuf (oend+2) base64_conv_table.[(b1 land 15) lsl 2 + (b2 lsr 6)];
        Bytes.set obuf (oend+3) base64_conv_table.[b2 land 63];
        oend <- oend + 4;
        ipos <- 0;
        ocolumn <- ocolumn + 4;
        if multiline && ocolumn >= 72 then begin
          self#ensure_capacity 1;
          Bytes.set obuf oend '\n';
          oend <- oend + 1;
          ocolumn <- 0
        end
      end

    method put_substring s ofs len =
      for i = ofs to ofs + len - 1 do self#put_char (Bytes.get s i) done

    method put_string s =
      String.iter self#put_char s

    method put_byte b = self#put_char (Char.chr b)

    method flush : unit = raise (Error Wrong_data_length)

    (* Emit the final partial group (if any), its padding, and in multiline
       mode the terminating newline.  The partial group and padding are
       counted in [ocolumn]: otherwise, when the full groups ended exactly
       on a line break, the last line (e.g. "xx==") would be left without
       its newline. *)
    method finish =
      begin match ipos with
        1 ->
          self#ensure_capacity 2;
          let b0 = Char.code (Bytes.get ibuf 0) in
          Bytes.set obuf oend     base64_conv_table.[b0 lsr 2];
          Bytes.set obuf (oend+1) base64_conv_table.[(b0 land 3) lsl 4];
          oend <- oend + 2;
          ocolumn <- ocolumn + 2
      | 2 ->
          self#ensure_capacity 3;
          let b0 = Char.code (Bytes.get ibuf 0)
          and b1 = Char.code (Bytes.get ibuf 1) in
          Bytes.set obuf oend     base64_conv_table.[b0 lsr 2];
          Bytes.set obuf (oend+1) base64_conv_table.[(b0 land 3) lsl 4 + (b1 lsr 4)];
          Bytes.set obuf (oend+2) (base64_conv_table.[(b1 land 15) lsl 2]);
          oend <- oend + 3;
          ocolumn <- ocolumn + 3
      | _ -> ()
      end;
      if multiline || padding then begin
        let num_equals =
          match ipos with 1 -> 2 | 2 -> 1 | _ -> 0 in
        self#ensure_capacity num_equals;
        Bytes.fill obuf oend num_equals '=';
        oend <- oend + num_equals;
        ocolumn <- ocolumn + num_equals
      end;
      if multiline && ocolumn > 0 then begin
        self#ensure_capacity 1;
        Bytes.set obuf oend '\n';
        oend <- oend + 1
      end;
      ocolumn <- 0

    method wipe =
      wipe_bytes ibuf; output_buffer#wipe
  end

let encode_multiline () = new encode true true
let encode_compact () = new  encode false false
let encode_compact_pad () = new encode false true

(* 6-bit value of a base64 character; -1 for space, tab, LF and CR, which
   are skipped.  Raises [Error Bad_encoding] on any other character
   ('=' is handled by the decoder before reaching here). *)
let base64_decode_char c =
  match c with
    'A' .. 'Z' -> Char.code c - 65
  | 'a' .. 'z' -> Char.code c - 97 + 26
  | '0' .. '9' -> Char.code c - 48 + 52
  | '+' -> 62
  | '/' -> 63
  | ' '|'\t'|'\n'|'\r' -> -1
  | _   -> raise (Error Bad_encoding)

(* Base64 decoder.  Whitespace is ignored; '=' marks the end of the data,
   after which any non-whitespace character raises [Error Bad_encoding].
   [flush] is not supported and raises [Error Wrong_data_length]. *)
class decode =
  object (self)
    inherit buffered_output 256 as output_buffer

    method input_block_size = 1
    method output_block_size = 1

    val ibuf = Array.make 4 0        (* pending 6-bit values *)
    val mutable ipos = 0             (* number of pending values *)
    val mutable finished = false     (* '=' or [finish] seen *)

    method put_char c =
      if c = '=' then finished <- true else begin
        let n = base64_decode_char c in
        if n >= 0 then begin
          if finished then raise(Error Bad_encoding);
          ibuf.(ipos) <- n;
          ipos <- ipos + 1;
          if ipos = 4 then begin
            self#ensure_capacity 3;
            Bytes.set obuf oend     (Char.chr(ibuf.(0) lsl 2 + ibuf.(1) lsr 4));
            Bytes.set obuf (oend+1) (Char.chr((ibuf.(1) land 15) lsl 4 + ibuf.(2) lsr 2));
            Bytes.set obuf (oend+2) (Char.chr((ibuf.(2) land 3) lsl 6 + ibuf.(3)));
            oend <- oend + 3;
            ipos <- 0
          end
        end
      end

    method put_substring s ofs len =
      for i = ofs to ofs + len - 1 do self#put_char (Bytes.get s i) done

    method put_string s =
      String.iter self#put_char s

    method put_byte b = self#put_char (Char.chr b)

    method flush : unit = raise (Error Wrong_data_length)

    (* Decode a trailing group of 2 or 3 characters; a lone trailing
       character cannot encode a full byte and is rejected. *)
    method finish =
      finished <- true;
      match ipos with
      | 1 -> raise(Error Bad_encoding)
      | 2 ->
          self#ensure_capacity 1;
          Bytes.set obuf oend     (Char.chr(ibuf.(0) lsl 2 + ibuf.(1) lsr 4));
          oend <- oend + 1
      | 3 ->
          self#ensure_capacity 2;
          Bytes.set obuf oend     (Char.chr(ibuf.(0) lsl 2 + ibuf.(1) lsr 4));
          Bytes.set obuf (oend+1) (Char.chr((ibuf.(1) land 15) lsl 4 + ibuf.(2) lsr 2));
          oend <- oend + 2
      | _ -> ()

    method wipe =
      Array.fill ibuf 0 4 0; output_buffer#wipe
  end

let decode () = new decode

end

(* Hexadecimal encoding *)

module Hexa = struct

let hex_conv_table = "0123456789abcdef"

(* Encoder: each input byte becomes two lowercase hexadecimal digits, high
   nibble first.  There is no framing, so [flush] and [finish] do nothing. *)
class encode =
  object (self)
    method input_block_size = 1
    method output_block_size = 1

    inherit buffered_output 256 as output_buffer

    method put_byte b =
      self#ensure_capacity 2;
      let digit k = hex_conv_table.[k] in
      Bytes.set obuf oend (digit (b lsr 4));
      Bytes.set obuf (oend + 1) (digit (b land 0xF));
      oend <- oend + 2

    method put_char c = self#put_byte (Char.code c)

    method put_substring s ofs len =
      for k = 0 to len - 1 do
        self#put_char (Bytes.get s (ofs + k))
      done

    method put_string s =
      String.iter (fun c -> self#put_char c) s

    method flush = ()
    method finish = ()

    method wipe = output_buffer#wipe
  end

let encode () = new encode

(* Value of a hexadecimal digit (either case); -1 for space, tab, LF and CR,
   which decoders skip.  Raises [Error Bad_encoding] on anything else. *)
let hex_decode_char c =
  match c with
  | '0' .. '9' -> Char.code c - Char.code '0'
  | 'a' .. 'f' -> 10 + Char.code c - Char.code 'a'
  | 'A' .. 'F' -> 10 + Char.code c - Char.code 'A'
  | ' ' | '\t' | '\n' | '\r' -> -1
  | _ -> raise (Error Bad_encoding)

(* Decoder: every two hex digits (whitespace ignored) yield one byte.
   A dangling digit makes [flush] raise [Error Wrong_data_length] and
   [finish] raise [Error Bad_encoding]. *)
class decode =
  object (self)
    inherit buffered_output 256 as output_buffer

    method input_block_size = 1
    method output_block_size = 1

    val ibuf = Array.make 2 0     (* pending nibble values *)
    val mutable ipos = 0          (* pending count, 0 or 1 between calls *)

    method put_char c =
      let nibble = hex_decode_char c in
      if nibble < 0 then () else begin
        ibuf.(ipos) <- nibble;
        ipos <- ipos + 1;
        if ipos = 2 then begin
          self#ensure_capacity 1;
          Bytes.set obuf oend (Char.chr ((ibuf.(0) lsl 4) lor ibuf.(1)));
          oend <- oend + 1;
          ipos <- 0
        end
      end

    method put_substring s ofs len =
      for k = 0 to len - 1 do
        self#put_char (Bytes.get s (ofs + k))
      done

    method put_string s =
      String.iter (fun c -> self#put_char c) s

    method put_byte b = self#put_char (Char.chr b)

    method flush =
      if ipos > 0 then raise (Error Wrong_data_length)

    method finish =
      if ipos > 0 then raise (Error Bad_encoding)

    method wipe =
      Array.fill ibuf 0 (Array.length ibuf) 0;
      output_buffer#wipe
  end

let decode () = new decode

end

(* Compression *)

module Zlib = struct

type stream

(* Flush modes passed to the C stubs.  NOTE(review): the constructor order
   presumably mirrors a table in the C stubs; do not reorder. *)
type flush_command =
    Z_NO_FLUSH
  | Z_SYNC_FLUSH
  | Z_FULL_FLUSH
  | Z_FINISH

(* [deflate zs src ofs len dst dofs dlen flush] / [inflate ...] return
   (finished, input bytes consumed, output bytes produced). *)
external deflate_init: int -> bool -> stream = "caml_zlib_deflateInit"
external deflate:
  stream -> bytes -> int -> int -> bytes -> int -> int -> flush_command
         -> bool * int * int
  = "caml_zlib_deflate_bytecode" "caml_zlib_deflate"
external deflate_end: stream -> unit = "caml_zlib_deflateEnd"

external inflate_init: bool -> stream = "caml_zlib_inflateInit"
external inflate:
  stream -> bytes -> int -> int -> bytes -> int -> int -> flush_command
         -> bool * int * int
  = "caml_zlib_inflate_bytecode" "caml_zlib_inflate"
external inflate_end: stream -> unit = "caml_zlib_inflateEnd"

(* Compressing transform at the given [level]; [write_zlib_header] selects
   zlib framing instead of raw deflate. *)
class compress level write_zlib_header =
  object(self)
    val zs = deflate_init level write_zlib_header
    
    inherit buffered_output 512 as output_buffer

    method input_block_size = 1
    method output_block_size = 1

    method put_substring src ofs len =
      if len > 0 then begin
        self#ensure_capacity 256;
        let (_, used_in, used_out) =
          deflate zs
                  src ofs len
                  obuf oend (Bytes.length obuf - oend)
                  Z_NO_FLUSH in
        oend <- oend + used_out;
        if used_in < len
        then self#put_substring src (ofs + used_in) (len - used_in)
      end

    method put_string s =
      self#put_substring (Bytes.unsafe_of_string s) 0 (String.length s)

    method put_char c = self#put_string (String.make 1 c)

    method put_byte b = self#put_char (Char.chr b)

    (* Sync-flush; repeat while the output buffer came back completely full,
       since more pending output may remain. *)
    method flush =
      self#ensure_capacity 256;
      let (_, _, used_out) =
         deflate zs
                 (Bytes.unsafe_of_string "") 0 0
                 obuf oend (Bytes.length obuf - oend)
                 Z_SYNC_FLUSH in
      oend <- oend + used_out;
      if oend = Bytes.length obuf then self#flush

    method finish =
      self#ensure_capacity 256;
      let (finished, _, used_out) =
         deflate zs
                 (Bytes.unsafe_of_string "") 0 0
                 obuf oend (Bytes.length obuf - oend)
                 Z_FINISH in
      oend <- oend + used_out;
      if finished then deflate_end zs else self#finish

    method wipe =
      output_buffer#wipe
end

let compress ?(level = 6) ?(write_zlib_header = false) () = new compress level write_zlib_header 

(* Decompressing transform; [expect_zlib_header] selects zlib framing
   instead of raw deflate.  Input past the end of the compressed stream
   raises [Error (Compression_error ...)]. *)
class uncompress expect_zlib_header =
  object(self)
    val zs = inflate_init expect_zlib_header
    
    inherit buffered_output 512 as output_buffer

    method input_block_size = 1
    method output_block_size = 1

    method put_substring src ofs len =
      if len > 0 then begin
        self#ensure_capacity 256;
        let (finished, used_in, used_out) =
          inflate zs
                  src ofs len
                  obuf oend (Bytes.length obuf - oend)
                  Z_SYNC_FLUSH in
        oend <- oend + used_out;
        if used_in < len then begin
          if finished then
            raise(Error(Compression_error("Zlib.uncompress",
               "garbage at end of compressed data")));
          self#put_substring src (ofs + used_in) (len - used_in)
        end
      end

    method put_string s =
      self#put_substring (Bytes.unsafe_of_string s) 0 (String.length s)

    method put_char c = self#put_string (String.make 1 c)

    method put_byte b = self#put_char (Char.chr b)

    method flush = ()

    (* Drain the decompressor until the end of the stream.  The first call
       passes one dummy byte of input; later calls pass none.  A later call
       that produces no output, despite >= 256 bytes of free space, cannot be
       followed by a more productive one (no new input will arrive), so the
       stream is reported as truncated instead of looping forever. *)
    method finish =
      let rec do_finish first_finish =
        self#ensure_capacity 256;
        let (finished, _, used_out) =
           inflate zs
                   (Bytes.unsafe_of_string " ") 0 (if first_finish then 1 else 0)
                   obuf oend (Bytes.length obuf - oend)
                   Z_SYNC_FLUSH in
        oend <- oend + used_out;
        if not finished then begin
          if used_out = 0 && not first_finish then
            raise(Error(Compression_error("Zlib.uncompress",
               "truncated compressed data")));
          do_finish false
        end in
      do_finish true; inflate_end zs

    method wipe =
      output_buffer#wipe
end

let uncompress ?(expect_zlib_header = false) () = new uncompress expect_zlib_header

end

(* Utilities *)

(* Bounds-checked wrapper around the previously defined [xor_bytes]
   (presumably dst.[dst_ofs+i] <- dst.[dst_ofs+i] xor src.[src_ofs+i] for
   i < len).  Raises [Invalid_argument "xor_bytes"] if [len] is negative or
   either range falls outside its buffer. *)
let xor_bytes src src_ofs dst dst_ofs len =
  let fits ofs total = ofs >= 0 && ofs <= total - len in
  if len >= 0
     && fits src_ofs (Bytes.length src)
     && fits dst_ofs (Bytes.length dst)
  then xor_bytes src src_ofs dst dst_ofs len
  else invalid_arg "xor_bytes"
  
(* Same as [xor_bytes] above but with a string source: bounds-checked
   wrapper around the previously defined [xor_string].  Raises
   [Invalid_argument "xor_string"] if [len] is negative or either range
   falls outside its buffer. *)
let xor_string src src_ofs dst dst_ofs len =
  let fits ofs total = ofs >= 0 && ofs <= total - len in
  if len >= 0
     && fits src_ofs (String.length src)
     && fits dst_ofs (Bytes.length dst)
  then xor_string src src_ofs dst dst_ofs len
  else invalid_arg "xor_string"
 
(* Elliptic curves *)

(* Description of an elliptic curve in short Weierstrass form over the
   prime field of modulus [p].  [size] is the bit length used when point
   coordinates (and ECDSA signature components) are serialized. *)
module type CURVE_PARAMETERS = sig
  (* Weierstrass form  y^2 = x^3 + a x + b *)
  val name: string                       (* curve name *)
  val size: int                          (* bit size *)
  val a: Z.t                             (* curve parameter a *)
  val b: Z.t                             (* curve parameter b *)
  val p: Z.t                             (* curve field (modulus) *)
  val order: Z.t                         (* curve order *)
  val generator: Z.t * Z.t               (* curve generator (affine x, y) *)
end

(* Point arithmetic on the curve described by [Params]. *)
module type ELLIPTIC_CURVE = sig
  module Params: CURVE_PARAMETERS
  (* A curve point (affine coordinates) or the point at infinity. *)
  type point
  (* Affine coordinates of a point. *)
  val x: point -> Z.t
  val y: point -> Z.t
  (* The point at infinity, neutral element of [add]. *)
  val zero: point
  (* The base point [Params.generator]. *)
  val generator: point
  (* Build a point from affine coordinates; raises [Error Invalid_point]
     when the pair is not accepted as a point of the curve. *)
  val make_point: Z.t * Z.t -> point
  (* Serialize as 0x04 || X || Y, or with [~compressed:true] (default false)
     as 0x02/0x03 (parity of Y) || X, coordinates on [Params.size] bits. *)
  val encode_point: ?compressed:bool -> point -> string
  (* Inverse of [encode_point] (both forms); raises [Error Bad_encoding] or
     [Error Invalid_point]. *)
  val decode_point: string -> point
  (* Group law: sum, opposite and doubling. *)
  val add: point -> point -> point
  val neg: point -> point
  val dbl: point -> point
  (* [mul k p] = k.p for k >= 0. *)
  val mul: Z.t -> point -> point
  (* [muladd n p m q] = n.p + m.q for n, m >= 0. *)
  val muladd: Z.t -> point -> Z.t -> point -> point
end

module EC (C: CURVE_PARAMETERS): ELLIPTIC_CURVE = struct

module Params = C

(* Field arithmetic modulo C.p; results are normalized to [0, C.p). *)
let ( +^ ) a b = Z.(erem (a + b) C.p)
let ( -^ ) a b = Z.(erem (a - b) C.p)
let ( *^ ) a b = Z.(erem (a * b) C.p)
let sqrm a = Z.(erem (a * a) C.p)
let invm a = Z.invert a C.p

(* Affine coordinates (x, y). *)
type point = Z.t * Z.t

(* y^2 = x^3 + a x + b (mod p)? *)
let is_on_curve (x, y) =
  y *^ y = (x *^ x +^ C.a) *^ x +^ C.b

let zero = (Z.zero, Z.zero)   (* Point at infinity *)

(* Accept [zero], or a pair of canonical field elements (in [0, p)) lying on
   the curve.  Out-of-range coordinates are rejected even when congruent to
   a curve point: they would give non-unique encodings and defeat the
   structural [p = q] test in [add]. *)
let make_point ((x, y) as p) =
  let in_field v = Z.sign v >= 0 && Z.lt v C.p in
  if p = zero || (in_field x && in_field y && is_on_curve p) then p
  else raise (Error Invalid_point)

let generator = make_point C.generator

(* The y in [0, p) with y^2 = x^3 + a x + b whose low bit equals [sign];
   raises [Error Invalid_point] if no square root exists. *)
let y_recover ~x ~sign =
  let y2 = (x *^ x +^ C.a) *^ x +^ C.b in
  match Bn.sqrtm y2 C.p with
  | None -> raise (Error Invalid_point)
  | Some y -> if Z.testbit y 0 = sign then y else C.p -^ y

(* 0x04 || X || Y, or compressed 0x02/0x03 (parity of Y) || X. *)
let encode_point ?(compressed = false) (x, y) =
  if not compressed then
    "\004" ^ Bn.to_bytes ~numbits:C.size x ^ Bn.to_bytes ~numbits:C.size y
  else
    (if Z.testbit y 0 then "\003" else "\002") ^ Bn.to_bytes ~numbits:C.size x

let decode_point p =
  let nbytes = (C.size + 7) / 8 in
  let l = String.length p in
  if l = 0 then raise (Error Bad_encoding);
  match p.[0] with
  | '\002' | '\003' ->
      if l <> 1 + nbytes then raise (Error Bad_encoding);
      let x = Bn.of_bytes (String.sub p 1 nbytes) in
      let y = y_recover ~x ~sign:(p.[0] = '\003') in
      make_point (x, y)
  | '\004' ->
      if l <> 1 + 2 * nbytes then raise (Error Bad_encoding);
      let x = Bn.of_bytes (String.sub p 1 nbytes)
      and y = Bn.of_bytes (String.sub p (1 + nbytes) nbytes) in
      make_point (x, y)
  | _ ->
      raise (Error Bad_encoding)

(* Jacobian coordinates (X, Y, Z) represent the affine point
   (X/Z^2, Y/Z^3); Z = 0 represents the point at infinity. *)
let aff2jac (x, y) = (x, y, Z.one)

let jac2aff (x, y, z) =
  if z = Z.zero then zero else begin
    let z_inv = invm z in
    let z_inv2 = z_inv *^ z_inv in
    let z_inv3 = z_inv2 *^ z_inv in
    (x *^ z_inv2, y *^ z_inv3)
  end

(* Doubling with general a (formulas "dbl-2007-bl"); maps Z = 0 to Z = 0. *)
let dbl_jac (x, y, z) =
  let xx = x *^ x
  and yy = y *^ y
  and zz = z *^ z in
  let yyyy = yy *^ yy in
  let s = Z.(erem (~$2 * ((x + yy) * (x + yy) - xx - yyyy)) C.p) in
  let m = Z.(erem (~$3 * xx + C.a * zz * zz) C.p) in
  let t = Z.(erem (m * m - ~$2 * s) C.p) in
  let x' = t in
  let y' = Z.(erem (m * (s - t) - ~$8 * yyyy) C.p) in
  let z' = Z.(erem ((y + z) * (y + z) - yy - zz) C.p) in
  (x', y', z')

(* Addition (formulas "add-2007-bl").  These formulas alone are wrong when
   an input is at infinity or when both inputs are the same point; those
   cases are dispatched explicitly (so [mul] / [muladd] stay correct, e.g.
   for q = generator or scalars >= the order).  Opposite points give
   the point at infinity. *)
let add_jac ((x1, y1, z1) as p1) ((x2, y2, z2) as p2) =
  if Z.equal z1 Z.zero then p2 else
  if Z.equal z2 Z.zero then p1 else begin
    let z1z1 = z1 *^ z1
    and z2z2 = z2 *^ z2 in
    let u1 = x1 *^ z2z2
    and u2 = x2 *^ z1z1
    and s1 = y1 *^ z2 *^ z2z2
    and s2 = y2 *^ z1 *^ z1z1 in
    let h = u2 -^ u1 in
    let r = Z.(erem (~$2 * (s2 - s1)) C.p) in
    if Z.equal h Z.zero then begin
      if Z.equal r Z.zero then dbl_jac p1      (* same point *)
      else (Z.one, Z.one, Z.zero)              (* p1 = -p2 *)
    end else begin
      let i = sqrm Z.(~$2 * h) in
      let j = h *^ i in
      let v = u1 *^ i in
      let x3 = Z.(erem (r * r - j - ~$2 * v) C.p) in
      let y3 = Z.(erem (r * (v - x3) - ~$2 * s1 * j) C.p) in
      let z3 = Z.(erem (sqrm (z1 + z2) - z1z1 - z2z2) C.p) *^ h in
      (x3, y3, z3)
    end
  end

let neg (x, y) = (x, C.p -^ y)

let add p q =
  if p = zero then q else
  if q = zero then p else
  if p = q
  then jac2aff (dbl_jac (aff2jac p))
  else jac2aff (add_jac (aff2jac p) (aff2jac q))
  
let dbl p =
  if p = zero then zero else jac2aff (dbl_jac (aff2jac p))

(* Left-to-right double-and-add; n must be non-negative. *)
let mul n p =
  assert (Z.sign n >= 0);
  if p = zero || n = Z.zero then zero else begin
    let x = aff2jac p in
    let rec mul i r =
      if i < 0 then r else
        mul (i - 1)
            (if Z.testbit n i
             then add_jac (dbl_jac r) x
             else dbl_jac r) in
    jac2aff (mul (Z.numbits n - 2) x)
  end

(* n.p + m.q by simultaneous double-and-add (Shamir's trick), using the
   precomputed sum p + q; n and m must be non-negative. *)
let muladd n p m q =
  assert (Z.sign n >= 0 && Z.sign m >= 0);
  if p = zero || n = Z.zero then mul m q else
  if q = zero || m = Z.zero then mul n p else begin
    let x = aff2jac p
    and y = aff2jac q in
    let xy = add_jac x y in
    let rec mul i r =
      if i < 0 then r else begin
        let r = dbl_jac r in
        let r =
          match Z.testbit n i, Z.testbit m i with
          | false, false -> r
          | true, false  -> add_jac r x
          | false, true  -> add_jac r y
          | true, true   -> add_jac r xy in
        mul (i - 1) r
      end in
    let i = Int.max (Z.numbits n) (Z.numbits m) - 1 in
    let r =
      match Z.testbit n i, Z.testbit m i with
      | false, false -> assert false
      | true, false  -> x
      | false, true  -> y
      | true, true   -> xy in
    jac2aff (mul (i - 1) r)
  end

let x (x, _) = x
let y (_, y) = y

end

(* Curve secp192r1 (NIST P-192): field modulus p, Weierstrass coefficients
   a and b, generator and order, given as hexadecimal literals. *)
module P192 = EC(struct
  let name = "secp192r1"
  let size = 192
  let p = Z.of_string "0xfffffffffffffffffffffffffffffffeffffffffffffffff"
  let a = Z.of_string "0xfffffffffffffffffffffffffffffffefffffffffffffffc"
  let b = Z.of_string "0x64210519e59c80e70fa7e9ab72243049feb8deecc146b9b1"
  let generator = (Z.of_string "0x188da80eb03090f67cbf20eb43a18800f4ff0afd82ff1012", Z.of_string "0x07192b95ffc8da78631011ed6b24cdd573f977a11e794811")
  let order = Z.of_string "0xffffffffffffffffffffffff99def836146bc9b1b4d22831"
end)

(* Curve secp224r1 (NIST P-224): field modulus p, Weierstrass coefficients
   a and b, generator and order, given as hexadecimal literals. *)
module P224 = EC(struct
  let name = "secp224r1"
  let size = 224
  let p = Z.of_string "0xffffffffffffffffffffffffffffffff000000000000000000000001"
  let a = Z.of_string "0xfffffffffffffffffffffffffffffffefffffffffffffffffffffffe"
  let b = Z.of_string "0xb4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4"
  let generator = (Z.of_string "0xb70e0cbd6bb4bf7f321390b94a03c1d356c21122343280d6115c1d21", Z.of_string "0xbd376388b5f723fb4c22dfe6cd4375a05a07476444d5819985007e34")
  let order = Z.of_string "0xffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d"
end)

(* Curve secp256r1 (NIST P-256): field modulus p, Weierstrass coefficients
   a and b, generator and order, given as hexadecimal literals. *)
module P256 = EC(struct
  let name = "secp256r1"
  let size = 256
  let p = Z.of_string "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
  let a = Z.of_string "0xffffffff00000001000000000000000000000000fffffffffffffffffffffffc"
  let b = Z.of_string "0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b"
  let generator = (Z.of_string "0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", Z.of_string "0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5")
  let order = Z.of_string "0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
end)

(* Curve secp384r1 (NIST P-384): field modulus p, Weierstrass coefficients
   a and b, generator and order, given as hexadecimal literals. *)
module P384 = EC(struct
  let name = "secp384r1"
  let size = 384
  let p = Z.of_string "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff"
  let a = Z.of_string "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000fffffffc"
  let b = Z.of_string "0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef"
  let generator = (Z.of_string "0xaa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab7", Z.of_string "0x3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5f")
  let order = Z.of_string "0xffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52973"
end)

(* Curve secp521r1 (NIST P-521): field modulus p, Weierstrass coefficients
   a and b, generator and order, given as hexadecimal literals.  The bit
   size 521 is not a multiple of 8, so encodings use 66-byte coordinates. *)
module P521 = EC(struct
  let name = "secp521r1"
  let size = 521
  let p = Z.of_string "0x01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
  let a = Z.of_string "0x01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffc"
  let b = Z.of_string "0x0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00"
  let generator = (Z.of_string "0x00c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66", Z.of_string "0x011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16650")
  let order = Z.of_string "0x01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409"
end)

module ECDSA (C: ELLIPTIC_CURVE) = struct

(* Order of the curve group. *)
let n = C.Params.order

type private_key = Z.t

type public_key = C.point

let wipe_key = Bn.wipe

(* Private scalar drawn with [Bn.random_upto] against the order; public key
   is priv.G. *)
let new_key ?(rng = Random.secure_rng) () =
  let priv = Bn.random_upto ~rng: rng#random_bytes n in
  let pub = C.mul priv C.generator in
  (priv, pub)

(* Sign an already-hashed [msg] of at most [C.Params.size] bits (else
   [Error Message_too_long]).  Returns (x, y) encoded on [C.Params.size]
   bits, with x = (k.G).x mod n and y = (h + s.x) / k mod n, for a fresh
   nonce k; retries if x or y is zero.  The nonce reveals the private key
   if leaked, so it is wiped as soon as it is no longer needed. *)
let rec sign ?(rng = Random.secure_rng) (s: private_key) msg =
  if String.length msg * 8 > C.Params.size then raise (Error Message_too_long);
  let h = Bn.of_bytes msg in
  let k = Bn.random_upto ~rng: rng#random_bytes n in
  let x = Z.erem (C.x (C.mul k C.generator)) n in
  if x = Z.zero then begin
    Bn.wipe k;
    sign ~rng s msg
  end else begin
    let y = Bn.(divm (addm h (mulm s x n) n) k n) in
    Bn.wipe k;
    if y = Z.zero
    then sign ~rng s msg
    else (Bn.to_bytes ~numbits:C.Params.size x,
          Bn.to_bytes ~numbits:C.Params.size y)
  end

(* Check signature (x, y) of the already-hashed [msg] under [q]: q must be
   a non-zero point killed by n, 0 < x, y < n, and x must equal
   ((h/y).G + (x/y).q).x mod n.  Raises [Error Message_too_long] like
   [sign]. *)
let verify (q: public_key) (x, y) msg =
  if String.length msg * 8 > C.Params.size then raise (Error Message_too_long);
  let x = Bn.of_bytes x
  and y = Bn.of_bytes y
  and h = Bn.of_bytes msg in
  q <> C.zero && C.mul n q = C.zero &&
  Z.lt Z.zero x && Z.lt x n && Z.lt Z.zero y && Z.lt y n &&
  begin
    let p = C.muladd (Bn.divm h y n) C.generator
                     (Bn.divm x y n) q in
    x = Z.erem (C.x p) n
  end

end

module ECDH (C: ELLIPTIC_CURVE) = struct

type private_secret = Z.t

(* Private scalar drawn with [Bn.random_upto] against the group order, as
   [ECDSA.new_key] does.  Drawing [C.Params.size] random bits instead could
   exceed the order (with probability about 2^-32 on P-256). *)
let private_secret ?(rng = Random.secure_rng) () =
  Bn.random_upto ~rng:(rng#random_bytes) C.Params.order

(* Public message: encoding of privsec.G. *)
let message privsec =
  C.encode_point (C.mul privsec C.generator)

(* Encoding of privsec.P, where P is the peer's decoded point.  A peer
   message decoding to the point at infinity (e.g. 0x04 followed by zeros)
   would make the result a known constant, so a zero result raises
   [Error Invalid_point].  The private secret is wiped on every path,
   including decoding errors, and must not be reused. *)
let shared_secret privsec othermsg =
  match C.mul privsec (C.decode_point othermsg) with
  | exception e ->
      Bn.wipe privsec;
      raise e
  | shared ->
      Bn.wipe privsec;
      if shared = C.zero then raise (Error Invalid_point);
      C.encode_point shared

end
