//===========================================================================
// @(#) $Name:$
// @(#) $Id: DwmAuthPeerAuthenticator.cc 11414 2020-12-11 08:36:51Z dwm $
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2017, 2020
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file DwmAuthPeerAuthenticator.cc
//!  \brief Dwm::Auth::PeerAuthenticator class implementation
//---------------------------------------------------------------------------

extern "C" {
  #include <sys/types.h>
  #include <sys/socket.h>
  #include <sys/un.h>
  #include <netinet/in.h>
  #include <netinet/tcp.h>
  #include <arpa/inet.h>
  #include <pwd.h>
  #include <unistd.h>
}

#include <cassert>
#include <sstream>
#include <cryptopp/cryptlib.h>
#include <cryptopp/rsa.h>
#include <cryptopp/pssr.h>
#include <cryptopp/secblock.h>

#include "DwmDescriptorIO.hh"
#include "DwmIpAddress.hh"
#include "DwmStreamIO.hh"
#include "DwmIOUtils.hh"
#include "DwmSvnTag.hh"
#include "DwmSysLogger.hh"
#include "DwmAuthPeerAuthenticator.hh"
#include "DwmAuthSymCrypto.hh"
#include "DwmAuthUtils.hh"

//  Version-control identification string compiled into this object file.
//  The "@(#)" prefix is the what(1) marker, so presumably this lets the
//  library revision be recovered from a binary (TODO confirm intent).
static const Dwm::SvnTag svntag("@(#) $DwmPath: dwm/libDwmAuth/tags/libDwmAuth-0.3.14/src/DwmAuthPeerAuthenticator.cc 11414 $");

using namespace std;
using CryptoPP::PSS;
using CryptoPP::SHA256;
using CryptoPP::RSASS;
using CryptoPP::StringSource;
using CryptoPP::StringSink;
using CryptoPP::SignerFilter;

namespace Dwm {

  namespace Auth {

    //------------------------------------------------------------------------
    //!  Fetches the remote address and port of socket @c fd via
    //!  getpeername().  @c addr and @c port are reset to default values
    //!  first.  Returns true only for AF_INET and AF_INET6 peers; any other
    //!  address family, or a negative @c fd, yields false.  A getpeername()
    //!  failure is logged.
    //------------------------------------------------------------------------
    static bool GetPeerName(int fd, IpAddress & addr, uint16_t & port)
    {
      addr = IpAddress();
      port = 0;
      if (fd < 0) {
        return false;
      }

      struct sockaddr_storage  peer;
      socklen_t                peerLen = sizeof(peer);
      if (getpeername(fd, reinterpret_cast<struct sockaddr *>(&peer),
                      &peerLen) != 0) {
        Syslog(LOG_ERR, "getpeername(%d, %p, %p) failed: %m",
               fd, &peer, &peerLen);
        return false;
      }

      switch (peer.ss_family) {
        case AF_INET: {
          auto  *v4 = reinterpret_cast<struct sockaddr_in *>(&peer);
          addr = Ipv4Address(v4->sin_addr.s_addr);
          port = ntohs(v4->sin_port);
          return true;
        }
        case AF_INET6: {
          auto  *v6 = reinterpret_cast<struct sockaddr_in6 *>(&peer);
          addr = Ipv6Address(v6->sin6_addr);
          port = ntohs(v6->sin6_port);
          return true;
        }
        default:
          break;
      }
      //  Unsupported address family (e.g. AF_UNIX): no IP peer to report.
      return false;
    }

    //------------------------------------------------------------------------
    //!  Returns the peer of socket @c fd formatted as "address:port", or an
    //!  empty string if the peer's IP address and port could not be
    //!  determined.
    //------------------------------------------------------------------------
    static string GetPeerName(int fd)
    {
      IpAddress  peerAddr;
      uint16_t   peerPort;
      if (! GetPeerName(fd, peerAddr, peerPort)) {
        return string();
      }
      return static_cast<string>(peerAddr) + ":" + to_string(peerPort);
    }

    //------------------------------------------------------------------------
    //!  Copies the raw bytes held in @c sbb into a std::string.  The copy is
    //!  binary-safe: embedded NUL bytes are preserved.
    //------------------------------------------------------------------------
    static string to_str(const CryptoPP::SecByteBlock & sbb)
    {
      const char  *first = reinterpret_cast<const char *>(sbb.BytePtr());
      return string(first, first + sbb.SizeInBytes());
    }

    //------------------------------------------------------------------------
    //!  Copies every byte of @c s (including embedded NULs) into a new
    //!  CryptoPP::SecByteBlock.
    //------------------------------------------------------------------------
    static CryptoPP::SecByteBlock to_SecByteBlock(const string & s)
    {
      const auto  *first = reinterpret_cast<const uint8_t *>(s.data());
      return CryptoPP::SecByteBlock(first, s.size());
    }

    //------------------------------------------------------------------------
    //!  Constructs an authenticator that loads its own identity and Ed25519
    //!  private key from @c myPrivKeyPath and verifies peers against the
    //!  public keys in @c pubKeysPath.
    //!
    //!  Defaults, where ~ is the home directory of the real user (getuid()):
    //!  - an empty @c myPrivKeyPath becomes ~/.dwmauth/id_ed25519
    //!  - an empty @c pubKeysPath becomes ~/.dwmauth/known_services
    //!  If the home directory can't be determined, the path is left empty.
    //!
    //!  A failure to load the private key is logged.  In builds without
    //!  NDEBUG it also triggers an assertion failure.
    //------------------------------------------------------------------------
    PeerAuthenticator::PeerAuthenticator(const string & myPrivKeyPath,
                                         const string & pubKeysPath)
        : _pubKeysPath(pubKeysPath)
    {
      std::string  privKeyPath = myPrivKeyPath;

      //  Only look up the password entry if a default path is needed.
      struct passwd  *pw = nullptr;
      if (privKeyPath.empty() || _pubKeysPath.empty()) {
        uid_t   uid = getuid();
        pw = getpwuid(uid);
      }
      if (privKeyPath.empty()) {
        if (pw) {
          privKeyPath = pw->pw_dir;
          privKeyPath += "/.dwmauth/id_ed25519";
        }
      }
      if (_pubKeysPath.empty()) {
        if (pw) {
          _pubKeysPath = pw->pw_dir;
          _pubKeysPath += "/.dwmauth/known_services";
        }
      }
      //  The key load must happen outside of assert().  When NDEBUG is
      //  defined the assert() expression is not evaluated at all, so the
      //  key would never be loaded in release builds.
      const bool  loaded = _myKeys.Load(privKeyPath, _myId);
      if (! loaded) {
        Syslog(LOG_ERR, "Failed to load private key from '%s'",
               privKeyPath.c_str());
      }
      assert(loaded);
    }

    //------------------------------------------------------------------------
    //!  Constructs an authenticator from an already-loaded identity
    //!  (@c myId) and Ed25519 key pair.  Peers are verified against the
    //!  public keys in @c pubKeysPath.  If that is empty, it defaults to
    //!  ~/.dwmauth/known_services of the real user (getuid()), provided the
    //!  user's home directory can be looked up.
    //------------------------------------------------------------------------
    PeerAuthenticator::PeerAuthenticator(const std::string & myId,
                                         const Ed25519::KeyPair & keyPair,
                                         const std::string & pubKeysPath)
        : _myId(myId), _myKeys(keyPair), _pubKeysPath(pubKeysPath)
    {
      if (! _pubKeysPath.empty()) {
        return;
      }
      const struct passwd  *userEntry = getpwuid(getuid());
      if (userEntry != nullptr) {
        _pubKeysPath = std::string(userEntry->pw_dir)
          + "/.dwmauth/known_services";
      }
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool PeerAuthenticator::ECExchange(int fd, const std::string & peerName,
                                       Auth::ECDHAgreement & agreement)
    {
      bool  rc = false;
      if (fd >= 0) {
        string  mypubEC = to_str(agreement.Public());
        if (DescriptorIO::Write(fd, mypubEC)
            == IOUtils::StreamedLength(mypubEC)) {
          string   theirPubECDH;
          if (Utils::WaitForBytesReady(fd, 36, chrono::milliseconds(5000))) {
            ssize_t  bytesRead =
              Utils::ReadLengthRestrictedString(fd, theirPubECDH, 32);
            if (bytesRead > 0) {
              agreement.TheirPublic(to_SecByteBlock(theirPubECDH));
              if (agreement.Agree()) {
                rc = true;
              }
              else {
                Syslog(LOG_ERR, "Failed to reach agreement with %s on fd %d",
                       peerName.c_str(), fd);
              }
            }
            else {
              Syslog(LOG_ERR, "Read from %s on fd %d failed: %m",
                     peerName.c_str(), fd);
            }
          }
          else {
            Syslog(LOG_ERR, "Peer %s on fd %d took to long to send EC exchange",
                   peerName.c_str(), fd);
          }
        }
        else {
          Syslog(LOG_ERR, "Write to %s on fd %d failed: %m",
                 peerName.c_str(), fd);
        }
      }
      return rc;
    }

    //------------------------------------------------------------------------
    //!  Writes @c str to @c sck as a 4-byte length in network byte order
    //!  followed by the string's bytes.  Returns true if the whole frame was
    //!  written without error.  If @c str is too long for its length to fit
    //!  in 32 bits, the error is logged, nothing is written, and false is
    //!  returned.
    //------------------------------------------------------------------------
    static bool WriteString(boost::asio::ip::tcp::socket & sck,
                            const std::string & str)
    {
      using boost::asio::write, boost::asio::buffer,
        boost::system::error_code;

      const uint32_t  len = static_cast<uint32_t>(str.size());
      if (len != str.size()) {
        //  The length prefix is only 32 bits wide.  Sending a truncated
        //  length would silently drop the tail of the string.
        Syslog(LOG_ERR, "String of %zu bytes is too long to send",
               str.size());
        return false;
      }

      bool        rc = false;
      uint32_t    nwlen = htonl(len);
      error_code  ec;
      if ((write(sck, buffer(&nwlen, sizeof(nwlen)), ec) == sizeof(nwlen))
          && (! ec)) {
        if (len) {
          if ((write(sck, buffer(str.data(), len), ec) == len) && (! ec)) {
            rc = true;
          }
        }
        else {
          rc = true;
        }
      }
      return rc;
    }

    //------------------------------------------------------------------------
    //!  Reads a string framed as WriteString() writes it: a 4-byte length in
    //!  network byte order followed by that many bytes.  @c str is cleared
    //!  on entry and left empty on any failure.  Returns true on success
    //!  (including a zero-length string).
    //!
    //!  NOTE(review): the length comes from the peer and is not bounded
    //!  here, so a hostile peer can force an allocation of up to 4 GiB.
    //!  Consider a maximum length before using this on untrusted input.
    //------------------------------------------------------------------------
    static bool ReadString(boost::asio::ip::tcp::socket & sck,
                           std::string & str)
    {
      using boost::asio::read, boost::system::error_code, boost::asio::buffer;

      str.clear();
      bool        rc = false;
      uint32_t    nwlen;
      error_code  ec;
      if ((read(sck, buffer(&nwlen, sizeof(nwlen)), ec) == sizeof(nwlen))
          && (! ec)) {
        //  The length arrives in network byte order; convert to host order.
        uint32_t  len = ntohl(nwlen);
        if (len) {
          try {
            str.resize(len);
            if ((read(sck, buffer(str.data(), str.size()), ec)
                 == str.size())
                && (! ec)) {
              rc = true;
            }
            else {
              //  Don't hand back a partially-filled, zero-padded buffer.
              str.clear();
            }
          }
          catch (...) {
            Syslog(LOG_ERR, "Exception resizing string to %u characters",
                   len);
            str.clear();
          }
        }
        else {
          rc = true;
        }
      }
      return rc;
    }
    
    //------------------------------------------------------------------------
    //!  Boost.Asio socket variant of the ECDH public-value exchange:
    //!  1. sends our public value from @c agreement via WriteString()
    //!     (4-byte length prefix + bytes);
    //!  2. waits up to 5 seconds for the peer's reply;
    //!  3. reads the peer's length-restricted public value (at most 32
    //!     bytes) and stores it in @c agreement;
    //!  4. computes the shared key.
    //!  Returns true on success.  Each failure is logged.
    //------------------------------------------------------------------------
    bool PeerAuthenticator::ECExchange(boost::asio::ip::tcp::socket & s,
                                       Auth::ECDHAgreement & agreement)
    {
      string  ourPublic = to_str(agreement.Public());
      if (! WriteString(s, ourPublic)) {
        Syslog(LOG_ERR, "Write() failed");
        return false;
      }

      if (! Utils::WaitForBytesReady(s, 36, chrono::milliseconds(5000))) {
        Syslog(LOG_ERR, "Peer took to long to send EC exchange");
        return false;
      }

      string  peerPublic;
      if (! (Utils::ReadLengthRestrictedString(s, peerPublic, 32) > 0)) {
        Syslog(LOG_ERR, "Read() failed");
        return false;
      }

      agreement.TheirPublic(to_SecByteBlock(peerPublic));
      if (! agreement.Agree()) {
        Syslog(LOG_ERR, "Failed to reach agreement");
        return false;
      }
      return true;
    }
    
    //------------------------------------------------------------------------
    //!  Builds the message sent to the peer after the ECDH exchange.
    //!
    //!  Our ID (_myId) is encrypted with SymCrypto::Encrypt, using the first
    //!  16 bytes of the agreed key as the key and the next 16 bytes as the
    //!  IV.  The message signed with our Ed25519 private key is our ECDH
    //!  public value concatenated with that encrypted ID.  The encrypted ID
    //!  and the signature are each serialized with StreamIO::Write, in that
    //!  order, and the concatenation is returned.
    //------------------------------------------------------------------------
    string PeerAuthenticator
    ::BuildIDAndSignature(const ECDHAgreement & agreement)
    {
      //  Derive the symmetric key and IV from the agreed key.
      string  sharedKey(to_str(agreement.AgreedKey()));
      string  cipherKey(sharedKey.substr(0, 16));
      string  cipherIv(sharedKey.substr(16, 16));
      string  encryptedId = SymCrypto::Encrypt(cipherKey, cipherIv, _myId);

      //  Load our private key into a Crypto++ Ed25519 signer.
      CryptoPP::ByteQueue  keyQueue;
      _myKeys.Private().Save(keyQueue, true);
      CryptoPP::ed25519::Signer  signer(keyQueue);

      //  Sign (our public value || encrypted ID).
      string  signedMsg(to_str(agreement.Public()) + encryptedId);
      string  signature;
      CryptoPP::AutoSeededRandomPool  rng;
      StringSource  source(signedMsg, true,
                           new SignerFilter(rng, signer,
                                            new StringSink(signature)));

      ostringstream  out;
      StreamIO::Write(out, encryptedId);
      StreamIO::Write(out, signature);
      return out.str();
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool
    PeerAuthenticator::IDAndSigExchange(int fd,
                                        const std::string & peerName,
                                        const ECDHAgreement & agreement,
                                        string & theirId,
                                        string & theirIdCrypted,
                                        string & theirSig)
    {
      bool  rc = false;
      if (fd >= 0) {
        string idAndSig = BuildIDAndSignature(agreement);
        if (DescriptorIO::Write(fd, idAndSig)) {
          string  s;
          auto  endTime = chrono::system_clock::now() + chrono::milliseconds(5000);
          if (Utils::ReadLengthRestrictedString(fd, s, 2048, endTime) > 0) {
            string  keystr((to_str(agreement.AgreedKey())).substr(0, 16));
            string  ivstr((to_str(agreement.AgreedKey())).substr(16, 16));
            istringstream  is(s);
            if (Utils::ReadLengthRestrictedString(is, theirIdCrypted, 2044)) {
              theirId = SymCrypto::Decrypt(keystr, ivstr, theirIdCrypted);
              if (Utils::ReadLengthRestrictedString(is, theirSig, 2044)) {
                rc = true;
              }
              else {
                Syslog(LOG_ERR, "Failed to read signature from %s on fd %d",
                       peerName.c_str(), fd);
              }
            }
            else {
              Syslog(LOG_ERR, "Failed to read encrypted ID from %s on fd %d",
                     peerName.c_str(), fd);
            }
          }
          else {
            Syslog(LOG_ERR, "Failed to read ID and signature from %s on fd %d",
                   peerName.c_str(), fd);
          }
        }
        else {
          Syslog(LOG_ERR, "Failed to write ID and signature to %s on fd %d",
                 peerName.c_str(), fd);
        }
      }
      return rc;
    }

    //------------------------------------------------------------------------
    //!  Boost.Asio socket variant of the ID/signature exchange.
    //!
    //!  Sends our encrypted ID and signature (see BuildIDAndSignature())
    //!  via WriteString(), then waits up to 5 seconds for the peer's reply
    //!  (at most 2048 bytes).  The reply is parsed with StreamIO::Read into
    //!  @c theirIdCrypted and @c theirSig, and @c theirId receives the
    //!  decrypted ID.
    //!
    //!  The signature is NOT verified here; see Verify().  Returns true if
    //!  both fields were received.  Each failure is logged.
    //------------------------------------------------------------------------
    bool PeerAuthenticator
    ::IDAndSigExchange(boost::asio::ip::tcp::socket & sck,
                       const ECDHAgreement & agreement,
                       string & theirId,
                       string & theirIdCrypted,
                       string & theirSig)
    {
      string  ourIdAndSig = BuildIDAndSignature(agreement);
      if (! WriteString(sck, ourIdAndSig)) {
        Syslog(LOG_ERR, "Failed to write ID and signature");
        return false;
      }

      string  reply;
      auto  deadline = chrono::system_clock::now() + chrono::milliseconds(5000);
      if (! (Utils::ReadLengthRestrictedString(sck, reply, 2048, deadline) > 0)) {
        Syslog(LOG_ERR, "Failed to read ID and signature");
        return false;
      }

      //  Same key/IV split as BuildIDAndSignature().
      string  sharedKey(to_str(agreement.AgreedKey()));
      string  cipherKey(sharedKey.substr(0, 16));
      string  cipherIv(sharedKey.substr(16, 16));
      istringstream  replyStream(reply);

      if (! StreamIO::Read(replyStream, theirIdCrypted)) {
        Syslog(LOG_ERR, "Failed to read encrypted ID");
        return false;
      }
      theirId = SymCrypto::Decrypt(cipherKey, cipherIv, theirIdCrypted);

      if (! StreamIO::Read(replyStream, theirSig)) {
        Syslog(LOG_ERR, "Failed to read signature");
        return false;
      }
      return true;
    }

    //------------------------------------------------------------------------
    //!  Authenticates the peer on socket descriptor @c fd in three steps:
    //!  1. ECDH exchange;
    //!  2. exchange of encrypted IDs and signatures;
    //!  3. verification of the peer's signature against its public key in
    //!     the known-keys file.
    //!
    //!  TCP_NODELAY is enabled for the duration of the handshake.  If its
    //!  previous value could be read, that value is restored afterward.
    //!
    //!  @c theirId and @c agreedKey are cleared on entry.  On success
    //!  @c theirId holds the peer's ID, @c agreedKey holds the shared key,
    //!  and true is returned.
    //------------------------------------------------------------------------
    bool PeerAuthenticator::Authenticate(int fd, string & theirId,
                                         string & agreedKey)
    {
      bool  rc = false;
      theirId.clear();
      agreedKey.clear();
      if (fd >= 0) {
        string  peerName = GetPeerName(fd);
        //  Remember the caller's TCP_NODELAY setting so it can be restored.
        //  If it can't be read (e.g. fd is not a TCP socket), the saved
        //  value would be meaningless, so it must not be written back.
        int        origTcpNoDelay = 0;
        socklen_t  tcpNoDelayLen = sizeof(origTcpNoDelay);
        const bool  haveOrigTcpNoDelay =
          (getsockopt(fd, IPPROTO_TCP, TCP_NODELAY,
                      reinterpret_cast<char *>(&origTcpNoDelay),
                      &tcpNoDelayLen) == 0);
        int tcpNoDelay = 1;
        setsockopt(fd, IPPROTO_TCP, TCP_NODELAY,
                   reinterpret_cast<char *>(&tcpNoDelay), sizeof(tcpNoDelay));
        Syslog(LOG_DEBUG, "Authenticating connection from %s on fd %d",
               peerName.c_str(), fd);
        Auth::ECDHAgreement  agreement;
        if (ECExchange(fd, peerName, agreement)) {
          string  theirIdCrypted, theirSig;
          if (IDAndSigExchange(fd, peerName, agreement, theirId,
                               theirIdCrypted, theirSig)) {
            if (Verify(theirId, theirIdCrypted, theirSig,
                       to_str(agreement.TheirPublic()))) {
              agreedKey = to_str(agreement.AgreedKey());
              rc = true;
            }
          }
          else {
            Syslog(LOG_ERR, "ID and signature exhange with %s on fd %d failed",
                   peerName.c_str(), fd);
          }
        }
        else {
          Syslog(LOG_ERR, "ECDH exhange with %s on fd %d failed",
                 peerName.c_str(), fd);
        }
        if (haveOrigTcpNoDelay) {
          setsockopt(fd, IPPROTO_TCP, TCP_NODELAY,
                     reinterpret_cast<char *>(&origTcpNoDelay),
                     sizeof(origTcpNoDelay));
        }
      }
      return rc;
    }

    //------------------------------------------------------------------------
    //!  Reads the current no_delay option of @c sck into @c orig, then
    //!  applies @c noDelay.  Stops after the first failure, which is logged.
    //!  Returns true only if both the get and the set succeeded.
    //------------------------------------------------------------------------
    static bool
    GetSetSocketNoDelayOption(boost::asio::ip::tcp::socket & sck,
                              boost::asio::ip::tcp::no_delay & orig,
                              boost::asio::ip::tcp::no_delay & noDelay)
    {
      boost::system::error_code  err;

      sck.get_option(orig, err);
      if (err) {
        Syslog(LOG_ERR, "sck.get_option(no_delay) failed");
        return false;
      }

      sck.set_option(noDelay, err);
      if (err) {
        Syslog(LOG_ERR, "sck.set_option(no_delay) failed");
        return false;
      }
      return true;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool
    PeerAuthenticator::Verify(const string & theirId,
                              const string & theirIdCrypted,
                              const string & theirSig,
                              const string & theirPublic) const
    {
      bool  rc = false;
      Auth::PublicKeysFile  pubKeysFile;
      pubKeysFile.Load(_pubKeysPath);
      Ed25519::PublicKey  theirPubKey;
      if (pubKeysFile.Find(theirId, theirPubKey)) {
        string  msgToVerify = theirPublic + theirIdCrypted;
        CryptoPP::ByteQueue  bt;
        theirPubKey.Save(bt);
        CryptoPP::ed25519::Verifier  verifier(bt);
        if (verifier.VerifyMessage((const uint8_t *)msgToVerify.c_str(),
                                   msgToVerify.size(),
                                   (const uint8_t *)theirSig.c_str(),
                                   theirSig.size())) {
          rc = true;
        }
        else {
          Syslog(LOG_ERR, "Failed to verify message from %s",
                 theirId.c_str());
        }
      }
      else {
        Syslog(LOG_ERR, "Failed to find %s in %s",
               theirId.c_str(), _pubKeysPath.c_str());
      }
      return rc;
    }
    
    //------------------------------------------------------------------------
    //!  Boost.Asio socket variant of Authenticate().
    //!
    //!  Enables no_delay on @c sck, then runs the ECDH exchange, the
    //!  ID/signature exchange and signature verification.  Afterward the
    //!  original no_delay setting is restored.
    //!
    //!  @c theirId and @c agreedKey are cleared on entry.  On success they
    //!  hold the peer's ID and the shared key.  Returns false if the
    //!  no_delay option can't be changed or restored, or if any
    //!  authentication step fails.
    //------------------------------------------------------------------------
    bool PeerAuthenticator::
    Authenticate(boost::asio::ip::tcp::socket & sck, string & theirId,
                 string & agreedKey)
    {
      theirId.clear();
      agreedKey.clear();

      boost::asio::ip::tcp::no_delay  savedNoDelay;
      boost::asio::ip::tcp::no_delay  wantNoDelay(true);
      if (! GetSetSocketNoDelayOption(sck, savedNoDelay, wantNoDelay)) {
        Syslog(LOG_ERR, "Failed to get/set socket no_delay option");
        return false;
      }

      bool  authenticated = false;
      Auth::ECDHAgreement  agreement;
      if (! ECExchange(sck, agreement)) {
        Syslog(LOG_ERR, "ECDH exhange failed");
      }
      else {
        string  theirIdCrypted, theirSig;
        if (! IDAndSigExchange(sck, agreement, theirId,
                               theirIdCrypted, theirSig)) {
          Syslog(LOG_ERR, "ID and signature exhange failed");
        }
        else if (Verify(theirId, theirIdCrypted, theirSig,
                        to_str(agreement.TheirPublic()))) {
          agreedKey = to_str(agreement.AgreedKey());
          authenticated = true;
        }
      }

      //  Restore the caller's no_delay setting.  wantNoDelay is reused as
      //  scratch space to receive the current value.
      if (! GetSetSocketNoDelayOption(sck, wantNoDelay, savedNoDelay)) {
        Syslog(LOG_ERR, "Failed to get/set socket no_delay option");
        authenticated = false;
      }
      return authenticated;
    }

  }  // namespace Auth

}  // namespace Dwm
