//===========================================================================
// @(#) $Name:$
// @(#) $Id: DwmMcBlockResponder.cc 9443 2017-06-06 00:26:17Z dwm $
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2017
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file DwmMcBlockResponder.cc
//!  \brief Dwm::McBlock::Responder class implementation
//---------------------------------------------------------------------------

extern "C" {
  #include <sys/types.h>
  #include <sys/ioctl.h>
  #include <sys/socket.h>
  #include <netinet/in.h>
  #include <netinet/tcp.h>
  #include <arpa/inet.h>
  #include <netdb.h>
  #include <pthread.h>
#ifndef __APPLE__
  #include <pthread_np.h>
#endif
}

#include <exception>

#include "DwmSysLogger.hh"
#include "DwmSvnTag.hh"
#include "DwmAuth.hh"
#include "DwmAuthPeerAuthenticator.hh"
#include "DwmMcBlockPortability.hh"
#include "DwmMcBlockRequestMessage.hh"
#include "DwmMcBlockResponseMessage.hh"
#include "DwmMcBlockServer.hh"
#include "DwmMcBlockTcpDrop.hh"

static const Dwm::SvnTag  svntag("$DwmPath: dwm/mcplex/mcblock/tags/mcblock-0.1.0/apps/mcblockd/DwmMcBlockResponder.cc 9443 $");

using namespace std;

namespace Dwm {

  namespace McBlock {

    //------------------------------------------------------------------------
    //!  Accepts one client connection on @c server's listening socket.  On
    //!  success, logs the peer address and starts Run() in a new thread to
    //!  service the connection.  On failure no thread is started, so
    //!  _thread is not joinable and Join() returns false.
    //------------------------------------------------------------------------
    Responder::Responder(Server & server)
        : _socket(), _theirId(), _server(server)
    {
      _run = false;
      _running = false;
      if (_server.GetSocket().Accept(_socket, _clientAddr)) {
        //  inet_ntop() into a local buffer rather than inet_ntoa(), whose
        //  result may live in a static buffer shared with other threads
        //  (e.g. the Responder threads that are already running).
        char  addrStr[INET_ADDRSTRLEN] = "?";
        inet_ntop(AF_INET, &_clientAddr.sin_addr, addrStr, sizeof(addrStr));
        Syslog(LOG_INFO, "Accepted connection from %s:%hu",
               addrStr, ntohs(_clientAddr.sin_port));
        _run = true;
        //  Set _running BEFORE starting the thread.  Run() clears it when it
        //  finishes.  If it were set after the thread started, a Run() that
        //  exits quickly (e.g. failed authentication) could have its 'false'
        //  overwritten, and Join() would then never join the thread.
        _running = true;
        _thread = std::thread(&Responder::Run, this);
      }
    }

    //------------------------------------------------------------------------
    //!  Clears _run so Run() will not read another request (it tests _run
    //!  before each read), then tries Join().  Does not wait for Run() to
    //!  finish.  Returns true only if the thread had already finished and
    //!  was joined.
    //------------------------------------------------------------------------
    bool Responder::Stop()
    {
      _run = false;
      const bool  joined = Join();
      return joined;
    }
    
    //------------------------------------------------------------------------
    //!  Joins the responder thread, but only if it is joinable and Run() has
    //!  already cleared _running, so this never blocks on an active client.
    //!  Returns true if the thread was joined, false otherwise.
    //------------------------------------------------------------------------
    bool Responder::Join()
    {
      if ((! _thread.joinable()) || _running) {
        //  No thread, or it is still servicing its client.
        return false;
      }
      _thread.join();
      return true;
    }

    //------------------------------------------------------------------------
    //!  Returns true if @c request is a JSON object with a string "type" of
    //!  "activate", a string "table", and a non-empty "prefixes" array whose
    //!  elements are all strings.
    //------------------------------------------------------------------------
    bool Responder::IsValidActivateRequest(const Json::Value & request)
    {
      //  isObject() first: jsoncpp's isMember() rejects (asserts/throws on)
      //  values that are neither objects nor null.
      if (! ((request.isObject())
             && (request.isMember("type"))
             && (request["type"].isString())
             && (request["type"].asString() == "activate")
             && (request.isMember("table"))
             && (request["table"].isString())
             && (request.isMember("prefixes"))
             && (request["prefixes"].isArray())
             && (request["prefixes"].size() > 0))) {
        return false;
      }
      //  HandleActivateRequest() calls asString() on every element, so all
      //  of them (not just the first) must be strings.
      const Json::Value  & prefixes = request["prefixes"];
      for (uint32_t i = 0; i < prefixes.size(); ++i) {
        if (! prefixes[i].isString()) {
          return false;
        }
      }
      return true;
    }

    //------------------------------------------------------------------------
    //!  Returns true if @c request is a JSON object with a string "type" of
    //!  "deactivate", a string "table", and a non-empty "prefixes" array
    //!  whose elements are all strings.
    //------------------------------------------------------------------------
    bool Responder::IsValidDeactivateRequest(const Json::Value & request)
    {
      //  isObject() first: jsoncpp's isMember() rejects (asserts/throws on)
      //  values that are neither objects nor null.
      if (! ((request.isObject())
             && (request.isMember("type"))
             && (request["type"].isString())
             && (request["type"].asString() == "deactivate")
             && (request.isMember("table"))
             && (request["table"].isString())
             && (request.isMember("prefixes"))
             && (request["prefixes"].isArray())
             && (request["prefixes"].size() > 0))) {
        return false;
      }
      //  HandleDeactivateRequest() calls asString() on every element, so all
      //  of them (not just the first) must be strings.
      const Json::Value  & prefixes = request["prefixes"];
      for (uint32_t i = 0; i < prefixes.size(); ++i) {
        if (! prefixes[i].isString()) {
          return false;
        }
      }
      return true;
    }

    //------------------------------------------------------------------------
    //!  Returns true if @c request is a JSON object with a string "type" of
    //!  "getActive" and a string "table" (which may be empty, meaning all
    //!  tables).
    //------------------------------------------------------------------------
    bool Responder::IsValidGetActiveRequest(const Json::Value & request)
    {
      //  isObject() rather than ! empty(): a non-empty array or scalar is
      //  not empty() either, and jsoncpp's isMember() rejects such values.
      return ((request.isObject())
              && (request.isMember("type"))
              && (request["type"].isString())
              && (request["type"].asString() == "getActive")
              && (request.isMember("table"))
              && (request["table"].isString()));
    }

    //------------------------------------------------------------------------
    //!  Returns true if @c request is a JSON object with a string "type" of
    //!  "search" and a string "ipv4addr".
    //------------------------------------------------------------------------
    bool Responder::IsValidSearchRequest(const Json::Value & request)
    {
      //  isObject() rather than ! empty(): a non-empty array or scalar is
      //  not empty() either, and jsoncpp's isMember() rejects such values.
      return ((request.isObject())
              && (request.isMember("type"))
              && (request["type"].isString())
              && (request["type"].asString() == "search")
              && (request.isMember("ipv4addr"))
              && (request["ipv4addr"].isString()));
    }
    
    //------------------------------------------------------------------------
    //!  Returns true if @c request is a JSON object with a string "type" of
    //!  "getAddRules" and a string "table".
    //------------------------------------------------------------------------
    bool Responder::IsValidGetAddRulesRequest(const Json::Value & request)
    {
      //  isObject() rather than ! empty(): a non-empty array or scalar is
      //  not empty() either, and jsoncpp's isMember() rejects such values.
      return ((request.isObject())
              && (request.isMember("type"))
              && (request["type"].isString())
              && (request["type"].asString() == "getAddRules")
              && (request.isMember("table"))
              && (request["table"].isString()));
    }

    //------------------------------------------------------------------------
    //!  Returns true if @c request is a JSON object with a string "type" of
    //!  "logHit", a string "table", a string "ipv4addr", and an unsigned
    //!  integer "logTime" (as accepted by Json::Value::isUInt()).
    //------------------------------------------------------------------------
    bool Responder::IsValidLogHitRequest(const Json::Value & request)
    {
      //  isObject() rather than ! empty(): a non-empty array or scalar is
      //  not empty() either, and jsoncpp's isMember() rejects such values.
      return ((request.isObject())
              && (request.isMember("type"))
              && (request["type"].isString())
              && (request["type"].asString() == "logHit")
              && (request.isMember("table"))
              && (request["table"].isString())
              && (request.isMember("ipv4addr"))
              && (request["ipv4addr"].isString())
              && (request.isMember("logTime"))
              && (request["logTime"].isUInt()));
    }

    //------------------------------------------------------------------------
    //!  Opens a session to the configured dwmrdapd server and fetches the
    //!  RDAP entries for @c addrs.  Returns an empty map if the session
    //!  cannot be opened.
    //------------------------------------------------------------------------
    RDAP::FetchedEntryMap
    Responder::GetRDAPEntries(const vector<Ipv4Address> & addrs)
    {
      RDAP::FetchedEntryMap  entries;
      RDAP::Fetcher          fetcher;
      const bool  opened =
        fetcher.OpenSession(_server.Config().DwmRDAPServer(),
                            _server.Config().DwmRDAPPrivateKeyFile(),
                            _server.Config().KnownServicesFile());
      if (opened) {
        entries = fetcher.GetEntries(addrs);
        fetcher.CloseSession();
      }
      return entries;
    }
    
    //------------------------------------------------------------------------
    //!  Handles an "activate" request.  Each of request["prefixes"] is added
    //!  to pf table request["table"] unless the table already contains a
    //!  matching prefix.  A newly added prefix also gets KillSourceState()
    //!  and TcpDropPrefix() applied, and is recorded in the table's database.
    //!  The response holds, per requested prefix index, either the "added"
    //!  or the already-"matched" prefix.  It is empty if the pf table does
    //!  not exist.
    //------------------------------------------------------------------------
    Json::Value Responder::HandleActivateRequest(const Json::Value & request)
    {
      Json::FastWriter  fw;
      Syslog(LOG_INFO, "request %s", fw.write(request).c_str());
      Json::Value          rc;
      const string         tableName = request["table"].asString();
      const Json::Value  & reqPrefixes = request["prefixes"];
      vector<Ipv4Prefix>   prefixes;
      vector<Ipv4Address>  addrs;

      Pf::Table  &&table = _server.GetPfDevice().GetTable("", tableName);
      if (table.Name() != tableName) {
        //  No such pf table.
        return rc;
      }

      prefixes.reserve(reqPrefixes.size());
      for (uint32_t i = 0; i < reqPrefixes.size(); ++i) {
        Ipv4Prefix  prefix(reqPrefixes[i].asString());
        prefixes.push_back(prefix);
        if (prefix.MaskLength() == 32) {
          //  Host prefixes are looked up in dwmrdapd below.
          addrs.push_back(prefix.FirstAddress());
        }
      }

      //  NOTE(review): the RDAP results are only logged here; they are not
      //  used for the database entries added below.
      RDAP::FetchedEntryMap  &&rdapEntries = GetRDAPEntries(addrs);
      for (auto & rde : rdapEntries) {
        Syslog(LOG_INFO, "RDAP country %s prefix %s",
               rde.second.Country().c_str(),
               rde.second.Prefix().ToShortString().c_str());
      }

      //  Unsigned index, matching vector::size() and Json::Value's unsigned
      //  array index type.
      for (uint32_t ai = 0; ai < prefixes.size(); ++ai) {
        Ipv4Prefix  match(prefixes[ai]);
        if (! table.Contains(prefixes[ai], match)) {
          if (table.Add(match)) {
            KillSourceState(match);
            TcpDropPrefix(match);
            AddToDatabase(table.Name(), prefixes[ai]);
            rc["table"] = request["table"];
            rc["prefixes"][ai]["requested"] = reqPrefixes[ai];
            rc["prefixes"][ai]["added"] = match.ToShortString();
          }
        }
        else {
          //  Already covered by an existing table entry.
          rc["table"] = request["table"];
          rc["prefixes"][ai]["requested"] = reqPrefixes[ai];
          rc["prefixes"][ai]["matched"] = match.ToShortString();
        }
      }
      return rc;
    }

    //------------------------------------------------------------------------
    //!  Handles a "deactivate" request against pf table request["table"].
    //!  Each requested prefix is handled as follows:
    //!  - If the table holds exactly that prefix, it is removed from the
    //!    table and the database, and reported as "deactivated".
    //!  - If it is covered by a different table prefix, that prefix is
    //!    reported as "not_deactivated".
    //!  - Prefixes not in the table are skipped.
    //!  The response is empty if the pf table does not exist.
    //------------------------------------------------------------------------
    Json::Value Responder::HandleDeactivateRequest(const Json::Value & request)
    {
      Json::Value          result;
      const Json::Value  & requested = request["prefixes"];
      Pf::Table  &&table = _server.GetPfDevice().GetTable("", request["table"].asString());
      if (table.Name() != request["table"].asString()) {
        return result;
      }
      for (uint32_t idx = 0; idx < requested.size(); ++idx) {
        Ipv4Prefix  wanted(requested[idx].asString());
        Ipv4Prefix  match(wanted);
        if (! table.Contains(wanted, match)) {
          continue;
        }
        if (match == wanted) {
          //  Exact entry: remove it.
          if (table.Remove(wanted)) {
            RemoveFromDatabase(table.Name(), wanted);
            result["table"] = request["table"];
            result["prefixes"][idx]["deactivated"] = requested[idx];
          }
        }
        else {
          //  Covered by a different (containing) entry; leave it alone.
          result["table"] = request["table"];
          result["prefixes"][idx]["not_deactivated"] = match.ToShortString();
        }
      }
      return result;
    }

    //------------------------------------------------------------------------
    //!  Handles a "getActive" request.  Lists the active entries of the
    //!  database named by request["table"], or of every database when
    //!  "table" is empty.  Each entry is reported with its prefix, country,
    //!  country name and whole days remaining, keyed by database name.
    //------------------------------------------------------------------------
    Json::Value Responder::HandleGetActiveRequest(const Json::Value & request)
    {
      Json::Value   result;
      const string  wanted(request["table"].asString());
      auto  dbi = (wanted.empty() ? _server.GetDatabases().begin()
                                  : _server.GetDatabases().find(wanted));
      while (dbi != _server.GetDatabases().end()) {
        vector<DbEntry>  active;
        dbi->second.GetAllActive(active);
        if (! active.empty()) {
          TimeValue     now(true);
          Json::Value & jsonPrefixes = result[dbi->first]["prefixes"];
          uint32_t      idx = 0;
          for (const auto & entry : active) {
            Json::Value  & jsonEntry = jsonPrefixes[idx];
            jsonEntry["prefix"] = entry.Prefix().ToShortString();
            jsonEntry["country"] = entry.Country();
            jsonEntry["countryName"] = DbEntry::GetCountryName(entry.Country());
            TimeValue  timeLeft(entry.Interval().End());
            timeLeft -= now;
            const uint32_t  daysLeft = (timeLeft.Secs() / (24 * 60 * 60));
            jsonEntry["daysRemaining"] = daysLeft;
            ++idx;
          }
        }
        if (! wanted.empty()) {
          //  Only the one requested database.
          break;
        }
        ++dbi;
      }
      return result;
    }

    //------------------------------------------------------------------------
    //!  Handles a "search" request.  For each database whose longest match
    //!  for request["ipv4addr"] is active, reports that entry's prefix,
    //!  country, country name and whole days remaining, keyed by database
    //!  name.
    //------------------------------------------------------------------------
    Json::Value Responder::HandleSearchRequest(const Json::Value & request)
    {
      Json::Value  result;
      Ipv4Address  target(request["ipv4addr"].asString());
      TimeValue    now(true);
      for (const auto & db : _server.GetDatabases()) {
        pair<Ipv4Prefix,DbEntry>  found;
        if (! db.second.FindLongest(target, found)) {
          continue;
        }
        DbEntry  & entry = found.second;
        if (! entry.IsActive()) {
          continue;
        }
        Json::Value  & out = result[db.first];
        out["prefix"] = entry.Prefix().ToShortString();
        out["country"] = entry.Country();
        out["countryName"] = DbEntry::GetCountryName(entry.Country());
        TimeValue  timeLeft(entry.Interval().End());
        timeLeft -= now;
        const uint32_t  daysLeft = (timeLeft.Secs() / (24 * 60 * 60));
        out["daysRemaining"] = daysLeft;
      }
      return result;
    }
      
    //------------------------------------------------------------------------
    //!  Handles a "getAddRules" request.  Returns the JSON form of the add
    //!  rules for request["table"], or a null value if the table has none.
    //------------------------------------------------------------------------
    Json::Value
    Responder::HandleGetAddRulesRequest(const Json::Value & request)
    {
      const string  tableName(request["table"].asString());
      auto  ari = _server.GetAddRules().RulesForTable(tableName);
      if (ari == _server.GetAddRules().RulesForTables().end()) {
        return Json::Value();
      }
      Json::Value  rules;
      rules = ari->second.Json();
      return rules;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    void Responder::GetRDAPEntry(const Ipv4Address & addr,
                                 Ipv4Prefix & prefix, string & country)
    {
      RDAP::Fetcher                  fetcher;
      pair<bool,RDAP::FetchedEntry>  rdapResp;  rdapResp.first = false;
      if (fetcher.OpenSession(_server.Config().DwmRDAPServer(),
                              _server.Config().DwmRDAPPrivateKeyFile(),
                              _server.Config().KnownServicesFile())) {
        rdapResp = fetcher.GetEntry(addr);
        fetcher.CloseSession();
      }
      if (rdapResp.first) {
        prefix = rdapResp.second.Prefix();
        country = rdapResp.second.Country();
      }
      return;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    void Responder::AddToDatabase(const string & tableName,
                                  const Ipv4Prefix & prefix,
                                  uint32_t days, const string & country)
    {
      auto  dbi = _server.GetDatabases().find(tableName);
      if (dbi != _server.GetDatabases().end()) {
        TimeValue  now(true);
        TimeValue  endTime(now);
        endTime.Set(endTime.Secs() + (days * 24 * 60 * 60), 0);
        TimeInterval  ti(now, endTime);
        DbEntry  dbEntry(prefix, ti, "", country);
        dbi->second.AddEntry(dbEntry);
        dbi->second.Save("");
      }
      else {
        Syslog(LOG_ERR, "Database not found for table %s", tableName.c_str());
      }
      return;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    void Responder::AddToDatabase(const string & tableName,
                                  const Ipv4Prefix & prefix)
    {
      auto  dbi = _server.GetDatabases().find(tableName);
      if (dbi != _server.GetDatabases().end()) {
        DbEntry  dbEntry;
        if (dbi->second.Find(prefix, dbEntry)) {
          dbi->second.DeleteEntry(prefix);
        }
        DbEntry  newDbEntry(prefix);
        dbi->second.AddEntry(newDbEntry);
        dbi->second.Save("");
      }
      return;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    void Responder::RemoveFromDatabase(const string & tableName,
                                       const Ipv4Prefix & prefix)
    {
      auto  dbi = _server.GetDatabases().find(tableName);
      if (dbi != _server.GetDatabases().end()) {
        if (dbi->second.DeleteEntry(prefix)) {
          dbi->second.Save("");
        }
      }
      return;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Responder::KillSourceState(const Ipv4Prefix & prefix)
    {
      struct pfioc_state_kill psk;
      memset(&psk, 0, sizeof(psk));
      
      psk.psk_af = AF_INET;
      psk.psk_src.addr.v.a.addr.v4.s_addr = prefix.Network().Raw();
      psk.psk_src.addr.v.a.mask.v4.s_addr = prefix.Netmask().Raw();
      
      return _server.GetPfDevice().Ioctl(DIOCKILLSTATES, &psk);
    }

    //------------------------------------------------------------------------
    //!  Handles a "logHit" request for address request["ipv4addr"] against
    //!  pf table request["table"].
    //!
    //!  If the table already contains the address, the covering prefix is
    //!  reported as "matched".  Otherwise a prefix is chosen: the /24 by
    //!  default, or the dwmrdapd prefix if one is found.  The table's add
    //!  rule for the address's country then limits the prefix's width.
    //!  Next:
    //!  - If the rule's LogThresh() <= 1, the prefix is blocked immediately.
    //!  - Otherwise the hit is recorded in the table's log entry tracker,
    //!    and the prefix is blocked once HitThreshold() is reached.
    //!
    //!  The response holds "requested" plus one of "matched", "added" or
    //!  "pending".  It is empty if the pf table or its add rules are not
    //!  found, or if pf refuses the add.
    //!
    //!  NOTE(review): the log entry trackers and databases reached through
    //!  _server are shared with other Responder threads, and no locking is
    //!  visible here.  Confirm that Server serializes access.
    //------------------------------------------------------------------------
    Json::Value Responder::HandleLogHitRequest(const Json::Value & request)
    {
      Json::Value   rc;
      const string  tableName = request["table"].asString();
      Ipv4Address   addr(request["ipv4addr"].asString());
      Pf::Table  &&pftable = _server.GetPfDevice().GetTable("", tableName);
      if (pftable.Name() != tableName) {
        //  No such pf table.
        return rc;
      }

      Ipv4Prefix  match;
      if (pftable.Contains(addr, match)) {
        //  already blocked
        rc["requested"] = request["ipv4addr"];
        rc["matched"] = match.ToShortString();
        return rc;
      }

      //  Not yet blocked.  Default to the /24 and an unknown country, then
      //  let dwmrdapd override both if it has an entry for the address.
      Ipv4Prefix  prefix(addr, 24);
      string      country("??");
      GetRDAPEntry(addr, prefix, country);

      //  Get the rules for the table
      auto  ari = _server.GetAddRules().RulesForTable(tableName);
      if (ari == _server.GetAddRules().RulesForTables().end()) {
        //  No rules for table!
        return rc;
      }

      //  And the rule for the country
      AddRule  addRule;
      if (ari->second.FindRuleForCountry(country, addRule)) {
        if (addRule.WidestMask() > prefix.MaskLength()) {
          //  Prefix is wider (shorter mask) than the rule allows; narrow it
          //  to the rule's widest mask.
          prefix = Ipv4Prefix(addr, addRule.WidestMask());
        }
      }
      else {
        //  No rule for country: fall back to the /24.  addRule stays
        //  default-constructed.
        prefix = Ipv4Prefix(addr, 24);
      }

      //  Add 'prefix' to the pf table.  On success, apply KillSourceState()
      //  and TcpDropPrefix() to it, record it in the database, log it and
      //  fill in the "added" response.
      auto  blockPrefix = [&] () {
        if (pftable.Add(prefix)) {
          KillSourceState(prefix);
          TcpDropPrefix(prefix);
          AddToDatabase(pftable.Name(), prefix, addRule.Days(), country);
          Syslog(LOG_INFO, "Added %s (%s) to %s for %u days",
                 prefix.ToShortString().c_str(), country.c_str(),
                 pftable.Name().c_str(), addRule.Days());
          rc["requested"] = request["ipv4addr"];
          rc["added"] = prefix.ToShortString();
        }
      };

      //  Log the hit level 'prefix' has reached in 'tracker' and fill in the
      //  "pending" response.
      auto  reportPending = [&] (TableLogEntryTracker & tracker) {
        Syslog(LOG_INFO, "Pending %s (%s) for %s, %u/%u",
               prefix.ToShortString().c_str(), country.c_str(),
               pftable.Name().c_str(),
               tracker.CurrentHitLevel(prefix, ari->second),
               addRule.LogThresh());
        rc["requested"] = request["ipv4addr"];
        rc["pending"] = prefix.ToShortString();
      };

      if (addRule.LogThresh() <= 1) {
        //  special case, no tracking necessary
        blockPrefix();
        return rc;
      }

      TimeValue64  logtime(request["logTime"].asUInt64(), 0);
      auto  leti = _server.GetLogEntryTrackers().find(tableName);
      if (leti != _server.GetLogEntryTrackers().end()) {
        leti->second.Add(prefix, country, logtime);
        if (leti->second.HitThreshold(prefix, ari->second)) {
          //  Above configured hit threshold
          blockPrefix();
          leti->second.Remove(prefix);
        }
        else {
          //  Not above threshold yet
          reportPending(leti->second);
        }
      }
      else {
        //  No log tracker for this table yet: create it with this first hit.
        TableLogEntryTracker  & letRef =
          _server.GetLogEntryTrackers()[tableName];
        letRef.Add(prefix, country, logtime);
        reportPending(letRef);
      }
      return rc;
    }
      
    //------------------------------------------------------------------------
    //!  Dispatches @c request to the handler whose IsValid*Request()
    //!  validator accepts it and returns that handler's response.  Logs an
    //!  error and returns a null Json::Value for anything else, including
    //!  requests that are not JSON objects.
    //------------------------------------------------------------------------
    Json::Value Responder::HandleRequest(const Json::Value & request)
    {
      if (! request.isObject()) {
        //  Every valid request is a JSON object.  Reject anything else
        //  before the validators call isMember(), which jsoncpp rejects
        //  (asserts/throws) for non-object, non-null values.  An exception
        //  here would escape Run() and terminate the process.
        Syslog(LOG_ERR, "Invalid request");
        return Json::Value();
      }
      if (IsValidActivateRequest(request)) {
        return HandleActivateRequest(request);
      }
      else if (IsValidDeactivateRequest(request)) {
        return HandleDeactivateRequest(request);
      }
      else if (IsValidSearchRequest(request)) {
        return HandleSearchRequest(request);
      }
      else if (IsValidGetActiveRequest(request)) {
        return HandleGetActiveRequest(request);
      }
      else if (IsValidGetAddRulesRequest(request)) {
        return HandleGetAddRulesRequest(request);
      }
      else if (IsValidLogHitRequest(request)) {
        return HandleLogHitRequest(request);
      }
      else {
        Syslog(LOG_ERR, "Invalid request");
        return Json::Value();
      }
    }

    //------------------------------------------------------------------------
    //!  Thread body.  Sets TCP_NODELAY and authenticates the peer.  It then
    //!  answers each request read from the socket with HandleRequest()
    //!  until a read fails, a write fails, or _run is cleared.  For a normal
    //!  return or a std::exception, it always zeroes the agreed key and
    //!  clears _running before returning, so Join() can reap the thread.
    //------------------------------------------------------------------------
    void Responder::Run()
    {
      //  Name the calling thread via pthread_self().  Reading _thread here
      //  would race with the constructor's assignment to _thread, which may
      //  not have completed when this thread starts running.
      set_pthread_name(pthread_self(), "responder");
      Syslog(LOG_DEBUG, "Responder started on fd %d", (int)_socket);

      //  Format the peer address once into a local buffer; inet_ntoa() may
      //  return a static buffer shared with other threads.
      char  peerAddrStr[INET_ADDRSTRLEN] = "?";
      inet_ntop(AF_INET, &_clientAddr.sin_addr, peerAddrStr,
                sizeof(peerAddrStr));
      const uint16_t  peerPort = ntohs(_clientAddr.sin_port);

      int  noDelay = 1;
      _socket.Setsockopt(IPPROTO_TCP, TCP_NODELAY, &noDelay,
                         sizeof(noDelay));
      try {
        RequestMessage  reqmsg;
        Auth::PeerAuthenticator  peerAuth(_server.Config().PrivateKeyFile(),
                                          _server.Config().AuthorizedKeysFile());
        if (peerAuth.Authenticate(_socket, _theirId, _agreedKey)) {
          Syslog(LOG_INFO, "Authenticated client %s from %s:%hu",
                 _theirId.c_str(), peerAddrStr, peerPort);
          while (_run && reqmsg.Read(_socket, _agreedKey)) {
            Json::Value  reqJson = reqmsg.Json();
            ResponseMessage  respmsg(HandleRequest(reqJson));
            if (! respmsg.Write(_socket, _agreedKey)) {
              Syslog(LOG_ERR, "Failed to send response message to %s:%hu",
                     peerAddrStr, peerPort);
              break;
            }
          }
          Syslog(LOG_INFO, "Done with client %s from %s:%hu",
                 _theirId.c_str(), peerAddrStr, peerPort);
        }
        else {
          Syslog(LOG_INFO, "Authentication failed for client from %s:%hu",
                 peerAddrStr, peerPort);
        }
      }
      catch (const std::exception & ex) {
        //  Never let an exception escape the thread function: that calls
        //  std::terminate() and takes down the whole daemon.
        Syslog(LOG_ERR, "Responder for client %s:%hu failed: %s",
               peerAddrStr, peerPort, ex.what());
      }

      Syslog(LOG_DEBUG, "Responder done for client %s:%hu",
             peerAddrStr, peerPort);

      //  Scrub the session key, then signal Join() that we're finished.
      _agreedKey.assign(_agreedKey.size(), '\0');
      _running = false;

      return;
    }

  }  // namespace McBlock

}  // namespace Dwm
