;; This file is part of eris-cl.
;; Copyright (C) 2022 Piotr Szarmański

;; eris-cl is free software: you can redistribute it and/or modify it under the
;; terms of the GNU Lesser General Public License as published by the Free
;; Software Foundation, either version 3 of the License, or (at your option) any
;; later version.

;; eris-cl is distributed in the hope that it will be useful, but WITHOUT ANY
;; WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
;; A PARTICULAR PURPOSE. See the GNU General Public License for more details.

;; You should have received a copy of the GNU General Public License along with
;; eris-cl. If not, see <https://www.gnu.org/licenses/>.

(in-package :eris)

(defun split-list-equally (list parts)
  "Partition LIST into PARTS consecutive sublists whose lengths differ by at
most one, with the longer sublists placed last.  When LIST has fewer than
PARTS elements, return one singleton list per element instead (so fewer than
PARTS sublists are returned).  PARTS must be non-zero."
  (let ((len (length list)))
    (multiple-value-bind (base extra) (floor len parts)
      (if (< len parts)
          (mapcar #'list list)
          (let ((start 0)
                (chunks '()))
            ;; The first (- PARTS EXTRA) chunks hold BASE elements, the
            ;; remaining EXTRA chunks hold one more.
            (dotimes (k parts (nreverse chunks))
              (let ((end (+ start base (if (< k (- parts extra)) 0 1))))
                (push (subseq list start end) chunks)
                (setf start end))))))))

(defun mem-write-vector (vector ptr &optional (offset 0) (count (length vector)))
  "Copy the first COUNT octets of VECTOR into foreign memory at PTR, placing
element I at byte offset (+ OFFSET I).  Returns NIL."
  (declare (type (simple-array (unsigned-byte 8)) vector)
           (type fixnum offset count))
  (declare (optimize ;; (speed 3) (safety 0) (space 0)
                     (debug 3)))
  (dotimes (index count)
    (setf (cffi:mem-ref ptr :unsigned-char (+ offset index))
          (aref vector index))))

;; A REFERENCE-PAIR that also records INDEX, its position among the key
;; references of the block it was read from.  The root's children are turned
;; into this class with CHANGE-CLASS, and DECODE-BLOCKS uses INDEX as the block
;; id of the subtree it starts descending from.
(defclass reference-pair+ (reference-pair)
  ((index :type (integer 0 32768)
          :initarg :index
          :accessor index)))

(defun map-over-key-references (function block)
  "Call FUNCTION on each key reference stored in BLOCK, together with its
zero-based position.  BLOCK is read as consecutive 64-octet entries; the walk
stops at the first entry for which KEY-REFERENCE-NULL? is true or when no
complete 64-octet entry remains.  Returns NIL."
  (let ((entries (floor (length block) 64)))
    (dotimes (position entries)
      (let ((pair (octets-to-reference-pair (subseq-shared block (* 64 position)))))
        (when (key-reference-null? pair)
          (return))
        (funcall function pair position)))))

(defun decode-blocks (reference-pair-list level block-capacity fetch-function output-file cache-capacity last-block)
  "Return a thunk that decodes the subtrees rooted at REFERENCE-PAIR-LIST into
OUTPUT-FILE through a shared, writable mmap of that file.

Each element of REFERENCE-PAIR-LIST must support INDEX (see REFERENCE-PAIR+),
which is used as the block id of its subtree at LEVEL.  BLOCK-CAPACITY is the
number of 64-octet references per block.  Leaf block N is written at byte
offset (* 64 BLOCK-CAPACITY N).  The leaf whose id equals LAST-BLOCK is written
only up to the position returned by UNPAD-BLOCK.

Each call of the thunk uses its own LRU cache of CACHE-CAPACITY fetched blocks.
A block that cannot be fetched signals MISSING-BLOCK.  Every block is verified
with HASH-CHECK against its reference before it is decrypted."
  (lambda ()
    (mmap:with-mmap (addr fd size output-file :open :write :protection :write :mmap :shared)
      (let ((get-block (cached-lambda (:cache-class 'lru-cache
                                       :capacity cache-capacity
                                       :table (make-hash-table :size (1+ cache-capacity) :test #'equalp))
                           (reference key &optional nonce)
                         (let ((encrypted (execute-fetch-function fetch-function reference)))
                           (unless encrypted (error 'missing-block :reference reference))
                           ;; Verify before DECRYPT-BLOCK, which works in place.
                           (hash-check encrypted reference)
                           (decrypt-block encrypted key nonce))))
            (nonce-array (initialize-nonce-array level)))
        (labels ((descend (depth pair node-id)
                   (let ((plaintext (funcall get-block
                                             (reference pair)
                                             (key pair)
                                             (aref nonce-array depth)))
                         (offset (* 64 block-capacity node-id)))
                     (cond ((not (zerop depth))
                            ;; Reference node: child SLOT of node NODE-ID one
                            ;; level down has id (+ SLOT (* BLOCK-CAPACITY NODE-ID)).
                            (map-over-key-references
                             (lambda (child slot)
                               (descend (1- depth) child (+ slot (* block-capacity node-id))))
                             plaintext))
                           ((= last-block node-id)
                            ;; Final leaf: write only the unpadded prefix.
                            (mem-write-vector plaintext addr offset (unpad-block plaintext)))
                           (t
                            (mem-write-vector plaintext addr offset))))))
          (mapc (lambda (pair)
                  (descend level pair (index pair)))
                reference-pair-list))))))

(defun eris-decode-parallel (read-capability fetch-function output-file
                             &key (cache-capacity 4096) (threads 4) (initial-bindings bordeaux-threads:*default-special-bindings*))
  "Decode an ERIS READ-CAPABILITY in parallel using THREADS threads into a file
designated by OUTPUT-FILE. An existing file at OUTPUT-FILE is overwritten.

Fetch-function must be a function with one argument, the reference octets, which
returns a (simple-array (unsigned-byte 8)) containing the block. The block will
be destructively modified, so you MUST provide a fresh array every time. In
addition, the function MUST be thread-safe.

Every fetched block, including the root, is verified against its reference
before it is decrypted. If fetching a block yields NIL, MISSING-BLOCK is
signaled.

CACHE-CAPACITY indicates the total amount of blocks stored for all threads. Each
thread has its own cache.

THREADS must be a positive integer. At most one thread is started per key
reference in the root node.

INITIAL-BINDINGS is passed to make-thread. This is only useful if you are
locally binding a special variable to some value."
  (declare (type read-capability read-capability)
           (type function fetch-function)
           (type integer cache-capacity))
  (with-slots (level block-size root-reference-pair) read-capability
    (flet ((fetch-verify-decrypt (reference key nonce)
             ;; Fetch a block, check it against REFERENCE while it is still
             ;; encrypted (DECRYPT-BLOCK works in place), then decrypt it.
             (let ((block (execute-fetch-function fetch-function reference)))
               (unless block (error 'missing-block :reference reference))
               (hash-check block reference)
               (decrypt-block block key nonce))))
      (let ((root (fetch-verify-decrypt (reference root-reference-pair)
                                        (key root-reference-pair)
                                        (make-nonce level))))
        ;; A root holding references (LEVEL > 0) is also checked against its key.
        (when (> level 0) (hash-check root (key root-reference-pair)))
        (if (zerop level)
            ;; The root is the only data block: strip the padding and write it.
            (with-open-file (file output-file :direction :output
                                              :element-type '(unsigned-byte 8)
                                              :if-exists :supersede)
              (write-sequence root file :end (unpad-block root)))
            (let* ((block-capacity (/ block-size 64))
                   ;; Key references of the root, tagged with their position so
                   ;; workers can compute block ids of their subtrees.  BELOW
                   ;; (not TO) keeps a completely full root from being read
                   ;; past its end.
                   (root-references
                     (loop for i from 0 below block-capacity
                           for key-ref = (octets-to-reference-pair (subseq-shared root (* 64 i)))
                           until (key-reference-null? key-ref)
                           collect (change-class key-ref 'reference-pair+ :index i)))
                   (partitions (split-list-equally root-references threads))
                   (eof (find-eof root #'fetch-verify-decrypt block-size level)))
              ;; Create/truncate the output file and reserve EOF bytes so every
              ;; worker can mmap it; release the descriptor even on error.
              (let ((fd (osicat-posix:creat output-file #o666)))
                (unwind-protect (osicat-posix:posix-fallocate fd 0 eof)
                  (osicat-posix:close fd)))
              (map nil #'bordeaux-threads:join-thread
                   (mapcar (lambda (reference-pairs)
                             (bordeaux-threads:make-thread
                              (decode-blocks reference-pairs
                                             (1- level)
                                             block-capacity
                                             fetch-function
                                             output-file
                                             (truncate cache-capacity threads)
                                             (truncate eof block-size))
                              :initial-bindings initial-bindings))
                           partitions))))))))