//===========================================================================
// @(#) $Name:$
// @(#) $Id: TestDwmAuthPeerAuthenticator.cc 10973 2020-08-25 02:48:22Z dwm $
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2017
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file TestDwmAuthPeerAuthenticator.cc
//!  \brief Unit tests for Dwm::Auth::PeerAuthenticator
//---------------------------------------------------------------------------

extern "C" {
  #include <sys/types.h>
  #include <sys/socket.h>
  #include <netinet/in.h>
  #include <netinet/tcp.h>
  #include <arpa/inet.h>
}

#include <thread>
#include <cryptopp/osrng.h>
#include <cryptopp/eccrypto.h>
#include <cryptopp/oids.h>

#include "DwmSocket.hh"
#include "DwmSvnTag.hh"
#include "DwmSysLogger.hh"
#include "DwmTimeValue64.hh"
#include "DwmUnitAssert.hh"
#include "DwmAuth.hh"
#include "DwmAuthSymCrypto.hh"
#include "DwmAuthCountedString.hh"
#include "DwmAuthECDHAgreement.hh"
#include "DwmAuthPeerAuthenticator.hh"

static const Dwm::SvnTag  svntag("$DwmPath: dwm/libDwmAuth/tags/libDwmAuth-0.3.5/tests/TestDwmAuthPeerAuthenticator.cc 10973 $");

using namespace std;
using namespace Dwm;

//----------------------------------------------------------------------------
//!  
//----------------------------------------------------------------------------
void ClientThread(bool preloadKeys)
{
  Socket  s;
  if (UnitAssert(s.Open(PF_INET, SOCK_STREAM, 0))) {
    int  on = 1;
    struct sockaddr_in  sockAddr;
    memset(&sockAddr, 0, sizeof(sockAddr));
#ifndef __linux__
    sockAddr.sin_len = sizeof(sockAddr);
#endif
    sockAddr.sin_family = AF_INET;
    sockAddr.sin_port = htons(2525);
    sockAddr.sin_addr.s_addr = inet_addr("127.0.0.1");
    if (UnitAssert(s.Connect(sockAddr))) {
      if (preloadKeys) {
        string                   myId;
        Auth::Ed25519::KeyPair   myKeys;
        if (UnitAssert(myKeys.Load("./id_ed25519", myId))) {
          Auth::PeerAuthenticator  peerAuth(myId, myKeys,
                                            "./known_services");
          string                   theirid, agreedKey;
          UnitAssert(peerAuth.Authenticate(s, theirid, agreedKey));
          UnitAssert(! theirid.empty());
          UnitAssert(! agreedKey.empty());
        }
      }
      else {
        Auth::PeerAuthenticator  peerAuth("./id_ed25519", "./known_services");
        string                   theirid, agreedKey;
        UnitAssert(peerAuth.Authenticate(s, theirid, agreedKey));
        UnitAssert(! theirid.empty());
        UnitAssert(! agreedKey.empty());
      }
    }
    UnitAssert(s.Close());
  }
  return;
}

//----------------------------------------------------------------------------
//!  
//----------------------------------------------------------------------------
void ClientThreadASIO(bool preloadKeys)
{
  namespace ip = boost::asio::ip;
  using boost::asio::io_context;

  io_context         ioContext;
  ip::tcp::socket    sck(ioContext);
  ip::tcp::endpoint  endPoint(ip::address::from_string("127.0.0.1"), 2525);
  boost::system::error_code  ec;
  sck.connect(endPoint, ec);
  if (UnitAssert(! ec)) {
    if (preloadKeys) {
      string                   myId;
      Auth::Ed25519::KeyPair   myKeys;
      if (UnitAssert(myKeys.Load("./id_ed25519", myId))) {
        Auth::PeerAuthenticator  peerAuth(myId, myKeys,
                                          "./known_services");
        string                   theirid, agreedKey;
        UnitAssert(peerAuth.Authenticate(sck, theirid, agreedKey));
        UnitAssert(! theirid.empty());
        UnitAssert(! agreedKey.empty());
      }
    }
    else {
      Auth::PeerAuthenticator  peerAuth("./id_ed25519", "./known_services");
      string                   theirid, agreedKey;
      UnitAssert(peerAuth.Authenticate(sck, theirid, agreedKey));
      UnitAssert(! theirid.empty());
      UnitAssert(! agreedKey.empty());
    }
    sck.close();
  }
  return;
}

//----------------------------------------------------------------------------
//!  
//----------------------------------------------------------------------------
void ServerThread(bool preloadKeys, std::atomic<bool> & ready)
{
  Socket  s;
  if (UnitAssert(s.Open(PF_INET, SOCK_STREAM, 0))) {
    int  on = 1;
    s.Setsockopt(SOL_SOCKET, SO_REUSEPORT, &on, sizeof(on));
    s.Setsockopt(SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
    struct sockaddr_in  sockAddr;
    memset(&sockAddr, 0, sizeof(sockAddr));
#ifndef __linux__
    sockAddr.sin_len = sizeof(sockAddr);
#endif
    sockAddr.sin_family = AF_INET;
    sockAddr.sin_port = htons(2525);
    sockAddr.sin_addr.s_addr = INADDR_ANY;
    if (UnitAssert(s.Bind(sockAddr))) {
      if (UnitAssert(s.Listen(10))) {
        ready = true;
        Socket  as;
        struct sockaddr_in  clientAddr;
        if (UnitAssert(s.Accept(as, clientAddr))) {
          if (preloadKeys) {
            string                   myId;
            Auth::Ed25519::KeyPair  myKeys;
            if (UnitAssert(myKeys.Load("./id_ed25519", myId))) {
              Auth::PeerAuthenticator  peerAuth(myId, myKeys,
                                                "./authorized_keys");
              string                   theirid, agreedKey;
              UnitAssert(peerAuth.Authenticate(as, theirid, agreedKey));
              UnitAssert(! theirid.empty());
              UnitAssert(! agreedKey.empty());
              UnitAssert(as.Close());
            }
          }
          else {
            Auth::PeerAuthenticator  peerAuth("./id_ed25519", "./authorized_keys");
            string                   theirid, agreedKey;
            UnitAssert(peerAuth.Authenticate(as, theirid, agreedKey));
            UnitAssert(! theirid.empty());
            UnitAssert(! agreedKey.empty());
            UnitAssert(as.Close());
          }
        }
      }
    }
    UnitAssert(s.Close());
  }
  ready = false;
  return;
}

//----------------------------------------------------------------------------
//!  
//----------------------------------------------------------------------------
void ServerThreadASIO(bool preloadKeys, std::atomic<bool> & ready)
{
  namespace ip = boost::asio::ip;
  
  boost::asio::io_context  ioContext;
  ip::tcp::acceptor
    acc(ioContext,
        ip::tcp::endpoint(ip::address::from_string("127.0.0.1"), 2525),
        true);
  boost::asio::socket_base::reuse_address option(true);
  acc.set_option(option);
  boost::system::error_code  ec;
  ip::tcp::socket            sck(ioContext);
  ip::tcp::endpoint          endPoint;
  ready = true;
  acc.accept(sck, endPoint, ec);
  if (UnitAssert(! ec)) {
    if (preloadKeys) {
      string                   myId;
      Auth::Ed25519::KeyPair  myKeys;
      if (UnitAssert(myKeys.Load("./id_ed25519", myId))) {
        Auth::PeerAuthenticator  peerAuth(myId, myKeys,
                                          "./authorized_keys");
        string                   theirid, agreedKey;
        UnitAssert(peerAuth.Authenticate(sck, theirid, agreedKey));
        UnitAssert(! theirid.empty());
        UnitAssert(! agreedKey.empty());
        sck.close();
      }
    }
    else {
      Auth::PeerAuthenticator  peerAuth("./id_ed25519", "./authorized_keys");
      string                   theirid, agreedKey;
      UnitAssert(peerAuth.Authenticate(sck, theirid, agreedKey));
      UnitAssert(! theirid.empty());
      UnitAssert(! agreedKey.empty());
      sck.close();
    }
  }
  return;
}

//----------------------------------------------------------------------------
//!  
//----------------------------------------------------------------------------
int main(int argc, char *argv[])
{
  // SysLogger::Open("TestDwmAuthPeerAuthenticator", LOG_PERROR, LOG_USER);

  vector<pair<bool,bool>> preloadKeys = {
    { false, false },
    { false, true },
    { true, true },
    { true, false }
  };

  std::atomic<bool>  serverReady = false;
  for (auto preload : preloadKeys) {
    thread  serverThread(ServerThread, preload.first, std::ref(serverReady));
    while (! serverReady) {}
    thread  clientThread(ClientThread, preload.second);
    clientThread.join();
    serverThread.join();

    serverReady = false;
    serverThread = std::thread(ServerThread, preload.first,
                               std::ref(serverReady));
    while (! serverReady) {}
    clientThread = std::thread(ClientThreadASIO, preload.second);
    clientThread.join();
    serverThread.join();

    serverReady = false;
    serverThread = std::thread(ServerThreadASIO, preload.first,
                               std::ref(serverReady));
    while (! serverReady) {}
    clientThread = std::thread(ClientThreadASIO, preload.second);
    clientThread.join();
    serverThread.join();

    serverReady = false;
    serverThread = std::thread(ServerThreadASIO, preload.first,
                               std::ref(serverReady));
    while (! serverReady) {}
    clientThread = std::thread(ClientThread, preload.second);
    clientThread.join();
    serverThread.join();
  }
  
  int  rc;
  if (Assertions::Total().Failed()) {
    Assertions::Print(std::cerr, true);
    rc = 1;
  }
  else {
    std::cout << Assertions::Total() << " passed\n" << std::flush;
    rc = 0;
  }
  return rc;
}

        
  
