//===========================================================================
// @(#) $Name:$
// @(#) $Id: mcblockc.cc 9417 2017-06-04 03:07:40Z dwm $
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2017
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file mcblockc.cc
//!  \brief mcblock command-line client
//---------------------------------------------------------------------------

extern "C" {
  #include <sys/types.h>
  #include <sys/socket.h>
  #include <netinet/in.h>
  #include <netinet/tcp.h>
  #include <arpa/inet.h>
  #include <netdb.h>
  #include <pwd.h>
  #include <unistd.h>
  #include <fcntl.h>
}

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <locale>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "DwmAuth.hh"
#include "DwmAuthCountedString.hh"
#include "DwmIO.hh"
#include "DwmIpv4Routes.hh"
#include "DwmOptArgs.hh"
#include "DwmSignal.hh"
#include "DwmSocket.hh"
#include "DwmSvnTag.hh"
#include "DwmSysLogger.hh"
#include "DwmAuthPeerAuthenticator.hh"
#include "DwmMcBlockRequestMessage.hh"
#include "DwmMcBlockResponseMessage.hh"

static const Dwm::SvnTag svntag("@(#) $DwmPath: dwm/mcplex/mcblock/tags/mcblock-0.1.0/apps/mcblockc/mcblockc.cc 9417 $");

using namespace std;
using namespace Dwm;

//----------------------------------------------------------------------------
//!  Returns the home directory of the effective user as recorded in the
//!  password database (getpwuid(geteuid())), or an empty string if no
//!  password entry exists for that uid.
//----------------------------------------------------------------------------
static string GetHomeDirectory()
{
  const struct passwd  *entry = getpwuid(geteuid());
  if (entry == nullptr) {
    return string();
  }
  return string(entry->pw_dir);
}

//----------------------------------------------------------------------------
//!  Returns the directory holding the client's authentication files
//!  (<home>/.dwmauth), or an empty string when the effective user's home
//!  directory cannot be determined.
//----------------------------------------------------------------------------
static string GetKeyDirectory()
{
  const string  home(GetHomeDirectory());
  return home.empty() ? string() : (home + "/.dwmauth");
}

//----------------------------------------------------------------------------
//!  Runs the Dwm::Auth peer authentication handshake on @c s, using the
//!  private key <keydir>/id_rsa and the known-services file
//!  <keydir>/known_services (keydir from GetKeyDirectory()).
//!  @c theirId and @c agreedKey are passed straight through to
//!  PeerAuthenticator::Authenticate (presumably receiving the peer's
//!  identity and the negotiated session key -- confirm against DwmAuth).
//!  Returns true if authentication succeeded.
//----------------------------------------------------------------------------
static bool Authenticate(Dwm::Socket & s, string & theirId,
                         string & agreedKey)
{
  const string             keyDir(GetKeyDirectory());
  Auth::PeerAuthenticator  authenticator(keyDir + "/id_rsa",
                                         keyDir + "/known_services");
  const bool  authenticated =
    authenticator.Authenticate(s, theirId, agreedKey) ? true : false;
  return authenticated;
}

//----------------------------------------------------------------------------
//!  Sends @c request on @c s and then reads the server's reply into
//!  @c response, passing @c secret to both the write and the read.
//!  Logs (LOG_ERR) which half of the exchange failed.  Returns true only
//!  if both the write and the read succeed.
//----------------------------------------------------------------------------
bool RequestResponse(Socket & s, const string & secret,
                     const McBlock::RequestMessage & request,
                     McBlock::ResponseMessage & response)
{
  if (! request.Write(s, secret)) {
    Syslog(LOG_ERR, "Failed to write server request");
    return false;
  }
  if (! response.Read(s, secret)) {
    Syslog(LOG_ERR, "Failed to read server response");
    return false;
  }
  return true;
}

//----------------------------------------------------------------------------
//!  Resolves @c host (IPv4 only) and opens a TCP connection to it on port
//!  1001 using @c s.  Before connecting, the receive buffer is enlarged to
//!  256 KiB; after connecting, TCP_NODELAY is enabled.  Both socket options
//!  are best effort (their results are not checked).
//!
//!  Returns false (after printing a message to stderr for lookup
//!  failures) if the host cannot be resolved to an IPv4 address, the
//!  socket cannot be opened, or the connection fails.
//----------------------------------------------------------------------------
bool Connect(const string & host, Socket & s)
{
  struct hostent  *hostEntry = gethostbyname(host.c_str());
  if ((hostEntry == nullptr) || (hostEntry->h_addr_list == nullptr)
      || (hostEntry->h_addr_list[0] == nullptr)) {
    cerr << "Host '" << host << "' not found.\n";
    return false;
  }
  //  We build a sockaddr_in below, so only an IPv4 answer is usable;
  //  anything else would be misread as IPv4 address bytes.
  if ((hostEntry->h_addrtype != AF_INET)
      || (static_cast<size_t>(hostEntry->h_length) != sizeof(struct in_addr))) {
    cerr << "Host '" << host << "' has no IPv4 address.\n";
    return false;
  }

  if (! s.Open(PF_INET, SOCK_STREAM, 0)) {
    return false;
  }

  //  Best effort; the result is deliberately ignored.
  int  rcvbufSize = 256 * 1024;
  s.Setsockopt(SOL_SOCKET, SO_RCVBUF, &rcvbufSize, sizeof(rcvbufSize));

  struct sockaddr_in  servAddr;
  memset(&servAddr, 0, sizeof(servAddr));
#if defined(__FreeBSD__) || defined(__APPLE__) || defined(__NetBSD__) \
  || defined(__OpenBSD__) || defined(__DragonFly__)
  //  sin_len only exists on BSD-derived systems.
  servAddr.sin_len = sizeof(servAddr);
#endif
  servAddr.sin_family = AF_INET;
  servAddr.sin_port = htons(1001);
  //  h_addr_list[0] is a char *; copy the bytes instead of casting and
  //  dereferencing, which is an alignment / strict-aliasing hazard.
  memcpy(&servAddr.sin_addr, hostEntry->h_addr_list[0],
         sizeof(servAddr.sin_addr));

  if (! s.Connect(servAddr)) {
    return false;
  }

  int  noDelay = 1;
  s.Setsockopt(IPPROTO_TCP, TCP_NODELAY, &noDelay, sizeof(noDelay));
  return true;
}

//----------------------------------------------------------------------------
//!  Writes a one-line-per-command usage summary to stderr, prefixing each
//!  line with the program name @c argv0.
//----------------------------------------------------------------------------
static void Usage(const string & argv0)
{
  static const char * const  commandLines[] = {
    " activate tableName prefix(es)\n",
    " deactivate tableName prefix(es)\n",
    " getactive tableName\n",
    " getaddrules tableName\n",
    " loghit tableName ipv4addr\n",
    " search ipv4addr\n"
  };
  const char  *lead = "Usage: ";
  for (const char *line : commandLines) {
    cerr << lead << argv0 << line;
    lead = "       ";   // align continuation lines under "Usage: "
  }
}

//----------------------------------------------------------------------------
//!  Asks the server to activate @c prefixes in @c table and prints the
//!  server's JSON reply (styled) to stdout.  Returns true if the exchange
//!  succeeded and the reply JSON is non-empty.
//----------------------------------------------------------------------------
static bool ActivatePrefixes(Socket & s, const string & secret,
                             const string & table,
                             const vector<Ipv4Prefix> & prefixes)
{
  McBlock::ResponseMessage  reply;
  McBlock::RequestMessage   request(McBlock::ActivateReq(table, prefixes));
  if (! RequestResponse(s, secret, request, reply)) {
    return false;
  }
  if (reply.Json().empty()) {
    return false;
  }
  cout << reply.Json().toStyledString() << '\n';
  return true;
}

//----------------------------------------------------------------------------
//!  Asks the server to deactivate @c prefixes in @c table and prints the
//!  server's JSON reply (styled) to stdout.  Returns true if the exchange
//!  succeeded and the reply JSON is non-empty.
//----------------------------------------------------------------------------
static bool DeactivatePrefixes(Socket & s, const string & secret,
                               const string & table,
                               const vector<Ipv4Prefix> & prefixes)
{
  McBlock::ResponseMessage  reply;
  McBlock::RequestMessage   request(McBlock::DeactivateReq(table, prefixes));
  const bool  exchanged = RequestResponse(s, secret, request, reply);
  if (exchanged && (! reply.Json().empty())) {
    cout << reply.Json().toStyledString() << '\n';
    return true;
  }
  return false;
}

//----------------------------------------------------------------------------
//!  
//----------------------------------------------------------------------------
static bool GetActivePrefixes(Socket & s, const string & secret,
                              const string & table)
{
  bool  rc = false;
  McBlock::ResponseMessage  resp;
  McBlock::RequestMessage   req((McBlock::GetActiveReq(table)));
  if (RequestResponse(s, secret, req, resp)) {
    if (! resp.Json().empty()) {
      if (resp.Json().isObject()) {
        auto  && members = resp.Json().getMemberNames();
        for (auto tblName : members) {
          map<string,Ipv4Routes<uint8_t>>  countryRoutes;
          if (resp.Json()[tblName].isMember("prefixes")
              && resp.Json()[tblName]["prefixes"].isArray()) {
            cout << tblName << ":\n";
            const Json::Value &  pfxArr = resp.Json()[tblName]["prefixes"];
            for (uint32_t pn = 0; pn < pfxArr.size(); ++pn) {
              Ipv4Prefix  prefix(pfxArr[pn]["prefix"].asString());
              string      country(pfxArr[pn]["country"].asString());
              cout << setiosflags(ios::left)
                   << "  "
                   << setw(18) << pfxArr[pn]["prefix"].asString() << ' '
                   << setiosflags(ios::right)
                   << setw(6) << pfxArr[pn]["daysRemaining"].asUInt() << "d "
                   << setw(3) << pfxArr[pn]["country"].asString() << " ("
                   << pfxArr[pn]["countryName"].asString() << ")\n"
                   << resetiosflags(ios::right) << resetiosflags(ios::left);
              countryRoutes[country][prefix] = 1;
            }
          }
          vector<pair<string,uint64_t>>  countryAddrSpace;
          for (auto & cre : countryRoutes) {
            countryAddrSpace.push_back(pair<string,uint64_t>(cre.first,
                                                             cre.second.AddressesCovered()));
          }
          sort (countryAddrSpace.begin(), countryAddrSpace.end(),
                [&] (const pair<string,uint64_t> & a1,
                     const pair<string,uint64_t> & a2)
                { return (a1.second > a2.second); });
          cout << "\n  Addresses covered per country:\n";
          // std::locale::global(std::locale(""));
          std::cout.imbue(std::locale("en_US.UTF-8"));
          for (auto & cas : countryAddrSpace) {
            cout << "    " << cas.first << ' ' << cas.second << '\n';
            auto  cre = countryRoutes[cas.first];
            vector<pair<uint8_t,uint32_t>>  hmSizes;
            cre.HashSizes(hmSizes);
            for (auto hms : hmSizes) {
              cout << "      /" << setw(2) << (uint16_t)hms.first
                   << " networks: " << setw(4) << hms.second << " ("
                   << hms.second * (1UL << (32 - hms.first))
                   << " addresses)\n";
            }
          }
        }
      }
      rc = true;
    }
  }
  return rc;
}

//----------------------------------------------------------------------------
//!  Asks the server which tables contain @c addr and prints each match
//!  (prefix, days remaining, country code and name), grouped by table.
//!  Returns true if the exchange succeeded and the reply is a non-empty
//!  JSON object.
//----------------------------------------------------------------------------
static bool Search(Socket & s, const string & secret,
                   const Ipv4Address & addr)
{
  McBlock::ResponseMessage  resp;
  McBlock::RequestMessage   req((McBlock::SearchReq(addr)));
  if (! RequestResponse(s, secret, req, resp)) {
    return false;
  }
  const Json::Value  & json = resp.Json();
  if (json.empty() || (! json.isObject())) {
    return false;
  }

  bool  sawTable = false;
  for (const auto & tblName : json.getMemberNames()) {
    const Json::Value  & match = json[tblName];
    const bool  hasPrefix =
      match.isMember("prefix") && match["prefix"].isString();
    if (hasPrefix) {
      cout << tblName << ":\n"
           << setiosflags(ios::left)
           << "  "
           << setw(18) << match["prefix"].asString()
           << ' ' << setiosflags(ios::right)
           << setw(6) << match["daysRemaining"].asUInt() << "d "
           << setw(3) << match["country"].asString() << " ("
           << match["countryName"].asString() << ")\n"
           << resetiosflags(ios::right) << resetiosflags(ios::left);
    }
    sawTable = true;
  }
  return sawTable;
}

            
//----------------------------------------------------------------------------
//!  Requests the add rules for @c table and prints the server's JSON reply
//!  (styled) to stdout.  Returns true if the exchange succeeded and the
//!  reply JSON is non-empty.
//----------------------------------------------------------------------------
static bool GetAddRules(Socket & s, const string & secret,
                        const string & table)
{
  McBlock::ResponseMessage  reply;
  McBlock::RequestMessage   request((McBlock::GetAddRulesReq(table)));
  if (! RequestResponse(s, secret, request, reply)) {
    return false;
  }
  if (reply.Json().empty()) {
    return false;
  }
  cout << reply.Json().toStyledString() << '\n';
  return true;
}

//----------------------------------------------------------------------------
//!  Reports a hit from @c addrstr (parsed as an IPv4 address) against
//!  @c table, timestamped with the current time, and prints the server's
//!  JSON reply (styled) to stdout.  Returns true if the exchange succeeded
//!  and the reply JSON is non-empty.
//----------------------------------------------------------------------------
static bool LogHit(Socket & s, const string & secret,
                   const string & table, const string & addrstr)
{
  Dwm::TimeValue64  now(true);
  Ipv4Address       hitAddr(addrstr);
  McBlock::ResponseMessage  reply;
  McBlock::RequestMessage   request((McBlock::LogHitReq(table, hitAddr, now)));
  const bool  exchanged = RequestResponse(s, secret, request, reply);
  if ((! exchanged) || reply.Json().empty()) {
    return false;
  }
  cout << reply.Json().toStyledString() << '\n';
  return true;
}

//----------------------------------------------------------------------------
//!  Returns the minimum number of arguments that must follow command word
//!  @c cmd on the command line, or -1 if @c cmd is not a known command.
//----------------------------------------------------------------------------
static int MinimumCommandArgs(const string & cmd)
{
  static const map<string,int>  minArgs = {
    { "activate",    2 },   // tableName prefix [prefix ...]
    { "deactivate",  2 },   // tableName prefix [prefix ...]
    { "getactive",   0 },   // [tableName]
    { "getaddrules", 1 },   // tableName
    { "loghit",      2 },   // tableName ipv4addr
    { "search",      1 }    // ipv4addr
  };
  auto  it = minArgs.find(cmd);
  return ((it != minArgs.end()) ? it->second : -1);
}

//----------------------------------------------------------------------------
//!  mcblockc entry point.
//!
//!  Usage: mcblockc [-h host] command args...
//!  The server host comes from -h, or else from $MCBLOCKDHOST.  The command
//!  and its arguments are validated before connecting, so usage errors do
//!  not require (or touch) the server.  Returns 0 if the command succeeded,
//!  1 otherwise.
//----------------------------------------------------------------------------
int main(int argc, char *argv[])
{
  OptArgs  optargs;
  optargs.AddOptArg("h:", "host", false, "", "server host");

  int  nextArg = optargs.Parse(argc, argv);

  //  A command word must follow any options.  Checking argc alone is not
  //  enough: for 'mcblockc -h host' argv[nextArg] would be argv[argc],
  //  a null pointer.  (The negative check is defensive; Parse()'s error
  //  convention is not visible here.)
  if ((nextArg < 0) || (nextArg >= argc)) {
    Usage(argv[0]);
    return 1;
  }

  const string  cmd(argv[nextArg]);
  const int     minArgs = MinimumCommandArgs(cmd);
  ++nextArg;   // now the index of the command's first argument
  if ((minArgs < 0) || ((argc - nextArg) < minArgs)) {
    Usage(argv[0]);
    return 1;
  }

  SysLogger::Open("mcblockc", LOG_PERROR, LOG_USER);
  SysLogger::MinimumPriority("info");

  string  mcblockdHost = optargs.Get<string>('h');
  if (mcblockdHost.empty()) {
    const char  *envHost = getenv("MCBLOCKDHOST");
    if (envHost != nullptr) {
      mcblockdHost = envHost;
    }
    else {
      cerr << "mcblockd host required (or MCBLOCKDHOST set in environment)\n";
      Usage(argv[0]);
      return 1;
    }
  }

  //  Block SIGPIPE so a server hang-up surfaces as a write error rather
  //  than terminating the process.
  Signal  sigPipe(SIGPIPE);
  sigPipe.Block();

  Socket  s;
  string  theirId, secret;
  if (! Connect(mcblockdHost, s)) {
    cerr << "Failed to connect to " << mcblockdHost << '\n';
    return 1;
  }
  if (! Authenticate(s, theirId, secret)) {
    cerr << "Failed to authenticate!\n";
    return 1;
  }

  bool  ok = false;
  if ((cmd == "activate") || (cmd == "deactivate")) {
    const string  table(argv[nextArg]);
    vector<Ipv4Prefix>  prefixes;
    for (++nextArg; nextArg < argc; ++nextArg) {
      prefixes.push_back(Ipv4Prefix(argv[nextArg]));
    }
    if (cmd == "activate") {
      ok = ActivatePrefixes(s, secret, table, prefixes);
    }
    else {
      ok = DeactivatePrefixes(s, secret, table, prefixes);
    }
  }
  else if (cmd == "search") {
    ok = Search(s, secret, Ipv4Address(argv[nextArg]));
  }
  else if (cmd == "getactive") {
    //  The table name is optional; an empty name is sent when omitted.
    //  A server failure is reported via the exit status only (printing
    //  usage here would be misleading).
    const string  table((nextArg < argc) ? argv[nextArg] : "");
    ok = GetActivePrefixes(s, secret, table);
  }
  else if (cmd == "getaddrules") {
    ok = GetAddRules(s, secret, argv[nextArg]);
  }
  else if (cmd == "loghit") {
    ok = LogHit(s, secret, argv[nextArg], argv[nextArg + 1]);
  }

  return (ok ? 0 : 1);
}
