//===========================================================================
// @(#) $Name:$
// @(#) $Id:$
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2020
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file TestXChaCha20Streams.cc
//!  \brief Unit tests for XChaCha20Poly1305::Istream and
//!    XChaCha20Poly1305::Ostream
//---------------------------------------------------------------------------

#include <atomic>
#include <chrono>
#include <exception>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <thread>

#include <boost/asio.hpp>

#include "DwmStreamIO.hh"
#include "DwmUnitAssert.hh"
#include "DwmXChaCha20Poly1305Istream.hh"
#include "DwmXChaCha20Poly1305Ostream.hh"

//  Set by ServerThread() once its listener is ready (or has failed);
//  main() waits on it before connecting.  Brace-initialized because
//  copy-initializing std::atomic from a bool is ill-formed before C++17.
static std::atomic<bool>  g_serverStarted{false};
//  Cleared by main() to tell ServerThread() to stop waiting for a client.
static std::atomic<bool>  g_serverShouldRun{true};

using namespace Dwm;

//----------------------------------------------------------------------------
//!  Server side of the network round-trip test.  Listens on
//!  127.0.0.1:7789 and polls (every 200 ms) for one client connection
//!  until @c g_serverShouldRun is cleared.  On connect, writes
//!  @c plaintext ten times through an XChaCha20Poly1305::Ostream keyed
//!  with @c key, flushing after each write.
//!
//!  Sets @c g_serverStarted once the listener is ready.  It is also set
//!  if setup throws, so main() never spins on a thread that has died.
//!  Exceptions are caught and recorded as a failed assertion instead of
//!  escaping the thread (which would call std::terminate).
//----------------------------------------------------------------------------
void ServerThread(const std::string & key, const std::string & plaintext)
{
  using namespace  boost::asio;

  try {
    io_context  ioContext;
    boost::system::error_code  ec;
    ip::tcp::endpoint  endPoint(ip::make_address("127.0.0.1"), 7789);
    //  This constructor opens, enables SO_REUSEADDR (its reuse_addr
    //  argument defaults to true), binds and listens.  Setting
    //  reuse_address after the bind would have no effect.
    ip::tcp::acceptor  acc(ioContext, endPoint);
    //  Non-blocking so the accept loop can observe g_serverShouldRun
    //  being cleared when no client ever connects.
    acc.non_blocking(true, ec);

    ip::tcp::socket    socket(ioContext);
    ip::tcp::endpoint  client;
    g_serverStarted = true;
    while (g_serverShouldRun) {
      acc.accept(socket, client, ec);
      if (ec != boost::asio::error::would_block) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(200));
    }
    //  socket.is_open() guards the case where the loop never ran and ec
    //  was never set by accept().
    if ((! ec) && socket.is_open()) {
      ip::tcp::iostream  stream(std::move(socket));
      Dwm::XChaCha20Poly1305::Ostream  os(stream, key);
      for (int i = 0; i < 10; ++i) {
        if (! UnitAssert(Dwm::StreamIO::Write(os, plaintext))) {
          break;
        }
        UnitAssert(os.flush());
      }
      stream.close();
    }
  }
  catch (const std::exception & ex) {
    std::cerr << "ServerThread: " << ex.what() << '\n';
    UnitAssert(false);
    //  Release main() from its wait loop even though setup failed.
    g_serverStarted = true;
  }
  return;
}

//----------------------------------------------------------------------------
//!  Round-trips ./random4meg.dat through XChaCha20Poly1305::Ostream with
//!  operator<< and back through XChaCha20Poly1305::Istream by inserting
//!  its rdbuf() into a stringstream.  Asserts the recovered bytes equal
//!  the original and reports encrypt/decrypt throughput on std::cerr.
//----------------------------------------------------------------------------
void TestInsertion()
{
  using namespace std;

  const string  key = {
    (char)0x80, (char)0x81, (char)0x82, (char)0x83,
    (char)0x84, (char)0x85, (char)0x86, (char)0x87,
    (char)0x88, (char)0x89, (char)0x8a, (char)0x8b,
    (char)0x8c, (char)0x8d, (char)0x8e, (char)0x8f,
    (char)0x90, (char)0x91, (char)0x92, (char)0x93,
    (char)0x94, (char)0x95, (char)0x96, (char)0x97,
    (char)0x98, (char)0x99, (char)0x9a, (char)0x9b,
    (char)0x9c, (char)0x9d, (char)0x9e, (char)0x9f
  };

  //  steady_clock is monotonic; system_clock can jump if the wall clock
  //  is adjusted mid-measurement.
  using  Clock = chrono::steady_clock;

  //  Throughput in MiB/sec (labelled "MB/sec" in the output).  Uses the
  //  clock's full resolution and yields 0 for an empty interval instead
  //  of dividing by zero.
  auto  mbPerSec = [] (double bytes, Clock::duration elapsed) {
    const double  secs = chrono::duration<double>(elapsed).count();
    return ((secs > 0.0) ? (bytes / secs) : 0.0) / (1024 * 1024);
  };

  ifstream  plainifs("./random4meg.dat", ios::in | ios::binary);
  if (UnitAssert(plainifs)) {
    plainifs.seekg(0, ios::end);
    //  tellg() returns -1 on failure; don't turn that into a huge size.
    const streamoff  fileSize = plainifs.tellg();
    string  plainstr((fileSize > 0) ? static_cast<size_t>(fileSize) : 0,
                     '\0');
    plainifs.seekg(0, ios::beg);
    plainifs.read(&plainstr[0], plainstr.size());
    plainifs.close();

    if (UnitAssert(! plainstr.empty())) {
      stringstream  ss;
      XChaCha20Poly1305::Ostream  xos(ss, key);
      auto  startTime = Clock::now();
      if (UnitAssert(xos << plainstr << flush)) {
        auto  elapsed = Clock::now() - startTime;
        cerr << "Encrypted " << plainstr.size() << " bytes in "
             << chrono::duration_cast<chrono::milliseconds>(elapsed).count()
             << " milliseconds ("
             << mbPerSec(static_cast<double>(plainstr.size()), elapsed)
             << " MB/sec) " << __FILE__ << ':' << __LINE__ << '\n';
        
        XChaCha20Poly1305::Istream  xis(ss, key);
        stringstream  recoveredss;
        startTime = Clock::now();
        if (UnitAssert(recoveredss << xis.rdbuf())) {
          elapsed = Clock::now() - startTime;
          const streamoff  recoveredSize = recoveredss.tellp();
          cerr << "Decrypted " << recoveredSize << " bytes in "
               << chrono::duration_cast<chrono::milliseconds>(elapsed).count()
               << " milliseconds  ("
               << mbPerSec(static_cast<double>(recoveredSize), elapsed)
               << " MB/sec) " << __FILE__ << ":" << __LINE__ << '\n';
          UnitAssert(recoveredss.str() == plainstr);
        }
      }
    }
  }
}

//----------------------------------------------------------------------------
//!  Round-trips ./random4meg.dat through XChaCha20Poly1305::Ostream and
//!  XChaCha20Poly1305::Istream using Dwm::StreamIO::Write() and
//!  Dwm::StreamIO::Read().  Asserts the recovered string equals the
//!  original and reports encrypt/decrypt throughput on std::cerr.
//----------------------------------------------------------------------------
void TestReadWrite4M()
{
  using namespace std;

  const string  key = {
    (char)0x80, (char)0x81, (char)0x82, (char)0x83,
    (char)0x84, (char)0x85, (char)0x86, (char)0x87,
    (char)0x88, (char)0x89, (char)0x8a, (char)0x8b,
    (char)0x8c, (char)0x8d, (char)0x8e, (char)0x8f,
    (char)0x90, (char)0x91, (char)0x92, (char)0x93,
    (char)0x94, (char)0x95, (char)0x96, (char)0x97,
    (char)0x98, (char)0x99, (char)0x9a, (char)0x9b,
    (char)0x9c, (char)0x9d, (char)0x9e, (char)0x9f
  };

  //  steady_clock is monotonic; system_clock can jump if the wall clock
  //  is adjusted mid-measurement.
  using  Clock = chrono::steady_clock;

  //  Throughput in MiB/sec (labelled "MB/sec" in the output).  Uses the
  //  clock's full resolution and yields 0 for an empty interval instead
  //  of dividing by zero.
  auto  mbPerSec = [] (double bytes, Clock::duration elapsed) {
    const double  secs = chrono::duration<double>(elapsed).count();
    return ((secs > 0.0) ? (bytes / secs) : 0.0) / (1024 * 1024);
  };

  ifstream  plainifs("./random4meg.dat", ios::in | ios::binary);
  if (UnitAssert(plainifs)) {
    plainifs.seekg(0, ios::end);
    //  tellg() returns -1 on failure; don't turn that into a huge size.
    const streamoff  fileSize = plainifs.tellg();
    string  plainstr((fileSize > 0) ? static_cast<size_t>(fileSize) : 0,
                     '\0');
    plainifs.seekg(0, ios::beg);
    plainifs.read(&plainstr[0], plainstr.size());
    plainifs.close();
    if (UnitAssert(! plainstr.empty())) {
      stringstream  ss;
      XChaCha20Poly1305::Ostream  xos(ss, key);
      auto  startTime = Clock::now();
      UnitAssert(Dwm::StreamIO::Write(xos, plainstr));
      UnitAssert(xos.flush());
      auto  elapsed = Clock::now() - startTime;
      cerr << "Encrypted " << plainstr.size() << " bytes in "
           << chrono::duration_cast<chrono::milliseconds>(elapsed).count()
           << " milliseconds ("
           << mbPerSec(static_cast<double>(plainstr.size()), elapsed)
           << " MB/sec)\n";
      
      XChaCha20Poly1305::Istream  xis(ss, key);
      string  recoveredstr;
      startTime = Clock::now();
      if (UnitAssert(Dwm::StreamIO::Read(xis, recoveredstr))) {
        elapsed = Clock::now() - startTime;
        cerr << "Decrypted " << recoveredstr.size() << " bytes in "
             << chrono::duration_cast<chrono::milliseconds>(elapsed).count()
             << " milliseconds ("
             << mbPerSec(static_cast<double>(recoveredstr.size()), elapsed)
             << " MB/sec)\n";
        UnitAssert(recoveredstr == plainstr);
      }
    }
  }
}

//----------------------------------------------------------------------------
//!  Test driver.
//!
//!  Reads this source file, then performs five network round trips.  In
//!  each, a ServerThread() writes the file contents ten times over an
//!  encrypted TCP stream, and the client here decrypts and compares each
//!  copy.  It then runs TestInsertion() and TestReadWrite4M().
//!
//!  Prints the assertion summary.  Returns 0 if every assertion passed,
//!  1 otherwise.
//----------------------------------------------------------------------------
int main()
{
  std::string    fileContents;
  std::ifstream  is("TestXChaCha20Streams.cc", std::ios::in | std::ios::binary);
  if (UnitAssert(is)) {
    is.seekg(0, std::ios::end);
    //  tellg() returns -1 on failure; don't turn that into a huge size.
    const std::streamoff  fileSize = is.tellg();
    fileContents.resize((fileSize > 0) ? static_cast<std::size_t>(fileSize)
                                       : 0);
    is.seekg(0, std::ios::beg);
    is.read(&fileContents[0], fileContents.size());
    is.close();
  }

  const std::string  key = {
    (char)0x80, (char)0x81, (char)0x82, (char)0x83,
    (char)0x84, (char)0x85, (char)0x86, (char)0x87,
    (char)0x88, (char)0x89, (char)0x8a, (char)0x8b,
    (char)0x8c, (char)0x8d, (char)0x8e, (char)0x8f,
    (char)0x90, (char)0x91, (char)0x92, (char)0x93,
    (char)0x94, (char)0x95, (char)0x96, (char)0x97,
    (char)0x98, (char)0x99, (char)0x9a, (char)0x9b,
    (char)0x9c, (char)0x9d, (char)0x9e, (char)0x9f
  };

  if (UnitAssert(! fileContents.empty())) {
    for (int test = 0; test < 5; ++test) {
      g_serverStarted = false;
      g_serverShouldRun = true;
      std::thread  serverThread(ServerThread, key, fileContents);
      //  Wait for the listener to be ready; yield instead of burning a
      //  core in an empty spin loop.
      while (! g_serverStarted) {
        std::this_thread::yield();
      }
      
      boost::asio::ip::tcp::iostream  tcpStream("127.0.0.1", "7789");
      if (UnitAssert(tcpStream.socket().is_open())) {
        Dwm::XChaCha20Poly1305::Istream  xis(tcpStream, key);
        std::string  recoveredString;
        for (int i = 0; i < 10; ++i) {
          if (UnitAssert(Dwm::StreamIO::Read(xis, recoveredString))) {
            UnitAssert(recoveredString == fileContents);
          }
        }
        tcpStream.close();
      }
      //  Lets the server's accept loop exit if no connection was made.
      g_serverShouldRun = false;
      serverThread.join();
    }
  }

  TestInsertion();
  TestReadWrite4M();
  
  int  rc;
  if (Assertions::Total().Failed()) {
    Assertions::Print(std::cerr, true);
    rc = 1;
  }
  else {
    std::cout << Assertions::Total() << " passed" << std::endl;
    rc = 0;
  }
  return rc;
}
