//===========================================================================
// @(#) $Name:$
// @(#) $Id: DwmIpv4CountryDb.cc 12241 2023-10-26 17:59:43Z dwm $
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2017
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file DwmIpv4CountryDb.cc
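//!  \brief Implementation of Dwm::Ipv4CountryDbValue and Dwm::Ipv4CountryDb,
//!    an IPv4 prefix to country-code database populated from RDAP
//!    responses and saved to / loaded from a file.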
//---------------------------------------------------------------------------

extern "C" {
  #include <fcntl.h>
  #include <unistd.h>
#ifdef __linux__
  #define O_SHLOCK  0
  #define O_EXLOCK  0
  #include <sys/file.h>
#endif
}

#include <fstream>
#include <regex>
#include <mutex>
#include <shared_mutex>

#include "DwmIO.hh"
#include "DwmBZ2IO.hh"
#include "DwmGZIO.hh"
#include "DwmSvnTag.hh"
#include "DwmSysLogger.hh"
#include "DwmIpv4CountryDb.hh"

//  Subversion identification string compiled into the object file; the
//  "@(#)" prefix is the marker that what(1) searches for.
static const Dwm::SvnTag svntag("@(#) $DwmPath: dwm/libDwmRDAP/tags/libDwmRDAP-0.3.4/src/DwmIpv4CountryDb.cc 12241 $");

using namespace std;

namespace Dwm {

  //--------------------------------------------------------------------------
  //!  
  //--------------------------------------------------------------------------
  Ipv4CountryDbValue::Ipv4CountryDbValue()
      : _data()
  {}

  //--------------------------------------------------------------------------
  //!  Constructs a value holding country code @c code and the timestamps
  //!  @c lastChanged and @c lastUpdated.
  //--------------------------------------------------------------------------
  Ipv4CountryDbValue::Ipv4CountryDbValue(const string & code,
                                         const TimeValue64 & lastChanged,
                                         const TimeValue64 & lastUpdated)
      : _data()
  {
    //  Each setter touches a distinct tuple element, so order is free.
    LastUpdated(lastUpdated);
    LastChanged(lastChanged);
    Code(code);
  }

  //--------------------------------------------------------------------------
  //!  Construction from a raw RDAP JSON response.
  //!  NOTE(review): not implemented -- @c rdapResponse is ignored and the
  //!  value is left default-constructed (empty code, default times).
  //--------------------------------------------------------------------------
  Ipv4CountryDbValue::Ipv4CountryDbValue(
    [[maybe_unused]] const nlohmann::json & rdapResponse)
      : _data()
  {
    //  TODO: populate _data from rdapResponse.
  }

  //--------------------------------------------------------------------------
  //!  Returns the stored country code.
  //--------------------------------------------------------------------------
  const string & Ipv4CountryDbValue::Code() const
  {
    const string & stored = get<0>(_data);
    return stored;
  }

  //--------------------------------------------------------------------------
  //!  Sets the country code to @c code and returns a reference to the
  //!  stored code.
  //!  The parameter is a by-value sink; top-level const in the declaration
  //!  is not part of the signature, so omitting it here lets the argument
  //!  be moved into place instead of being copied a second time.
  //--------------------------------------------------------------------------
  const string & Ipv4CountryDbValue::Code(string code)
  {
    string & stored = get<0>(_data);
    stored = std::move(code);
    return stored;
  }

  //--------------------------------------------------------------------------
  //!  Returns the stored last-updated time.
  //--------------------------------------------------------------------------
  const TimeValue64 & Ipv4CountryDbValue::LastUpdated() const
  {
    const TimeValue64 & stored = get<1>(_data);
    return stored;
  }

  //--------------------------------------------------------------------------
  //!  Sets the last-updated time to @c lastUpdated and returns a reference
  //!  to the stored time.
  //--------------------------------------------------------------------------
  const TimeValue64 &
  Ipv4CountryDbValue::LastUpdated(const TimeValue64 & lastUpdated)
  {
    TimeValue64 & stored = get<1>(_data);
    stored = lastUpdated;
    return stored;
  }

  //--------------------------------------------------------------------------
  //!  Returns the stored last-changed time.
  //--------------------------------------------------------------------------
  const TimeValue64 & Ipv4CountryDbValue::LastChanged() const
  {
    const TimeValue64 & stored = get<2>(_data);
    return stored;
  }

  //--------------------------------------------------------------------------
  //!  Sets the last-changed time to @c lastChanged and returns a reference
  //!  to the stored time.
  //--------------------------------------------------------------------------
  const TimeValue64 &
  Ipv4CountryDbValue::LastChanged(const TimeValue64 & lastChanged)
  {
    TimeValue64 & stored = get<2>(_data);
    stored = lastChanged;
    return stored;
  }
  
  //--------------------------------------------------------------------------
  //!  Reads the value from @c is via IO::Read(); returns @c is as given
  //!  back by IO::Read().
  //--------------------------------------------------------------------------
  istream & Ipv4CountryDbValue::Read(istream & is)
  {
    istream & result = IO::Read(is, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Reads the value from descriptor @c fd via IO::Read() and returns its
  //!  result (presumably the byte count, or negative on error -- confirm
  //!  against DwmIO).
  //--------------------------------------------------------------------------
  ssize_t Ipv4CountryDbValue::Read(int fd)
  {
    const ssize_t  result = IO::Read(fd, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Reads the value from stdio stream @c f via IO::Read() and returns
  //!  its result.
  //--------------------------------------------------------------------------
  size_t Ipv4CountryDbValue::Read(FILE * f)
  {
    const size_t  result = IO::Read(f, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Reads the value from gzip stream @c gzf via GZIO::Read() and returns
  //!  its result.
  //--------------------------------------------------------------------------
  int Ipv4CountryDbValue::Read(gzFile gzf)
  {
    const int  result = GZIO::Read(gzf, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Reads the value from bzip2 stream @c bzf via BZ2IO::BZRead() and
  //!  returns its result.
  //--------------------------------------------------------------------------
  int Ipv4CountryDbValue::BZRead(BZFILE *bzf)
  {
    const int  result = BZ2IO::BZRead(bzf, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Writes the value to @c os via IO::Write(); returns @c os as given
  //!  back by IO::Write().
  //--------------------------------------------------------------------------
  ostream & Ipv4CountryDbValue::Write(ostream & os) const
  {
    ostream & result = IO::Write(os, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Writes the value to descriptor @c fd via IO::Write() and returns its
  //!  result (presumably bytes written, negative on error -- confirm).
  //--------------------------------------------------------------------------
  ssize_t Ipv4CountryDbValue::Write(int fd) const
  {
    const ssize_t  result = IO::Write(fd, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Writes the value to stdio stream @c f via IO::Write() and returns
  //!  its result.
  //--------------------------------------------------------------------------
  size_t Ipv4CountryDbValue::Write(FILE *f) const
  {
    const size_t  result = IO::Write(f, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Writes the value to gzip stream @c gzf via GZIO::Write() and returns
  //!  its result.
  //--------------------------------------------------------------------------
  int Ipv4CountryDbValue::Write(gzFile gzf) const
  {
    const int  result = GZIO::Write(gzf, _data);
    return result;
  }
  
  //--------------------------------------------------------------------------
  //!  Writes the value to bzip2 stream @c bzf via BZ2IO::BZWrite() and
  //!  returns its result.
  //--------------------------------------------------------------------------
  int Ipv4CountryDbValue::BZWrite(BZFILE *bzf) const
  {
    const int  result = BZ2IO::BZWrite(bzf, _data);
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Returns the number of bytes IO::StreamedLength() reports for the
  //!  stored data.
  //--------------------------------------------------------------------------
  uint64_t Ipv4CountryDbValue::StreamedLength() const
  {
    const uint64_t  length = IO::StreamedLength(_data);
    return length;
  }

  //--------------------------------------------------------------------------
  //!  Two values are equal when code, last-updated and last-changed times
  //!  all compare equal.
  //--------------------------------------------------------------------------
  bool Ipv4CountryDbValue::operator == (const Ipv4CountryDbValue & val) const
  {
    const auto & lhs = _data;
    const auto & rhs = val._data;
    return (lhs == rhs);
  }

  //--------------------------------------------------------------------------
  //!  Returns a JSON object with members "country" (the code),
  //!  "lastUpdated" and "lastChanged" (both formatted by
  //!  DateTime::Formatted("%Y-%m-%d %H:%M")).
  //--------------------------------------------------------------------------
  nlohmann::json Ipv4CountryDbValue::Json() const
  {
    nlohmann::json  result;
    result["country"] = Code();
    result["lastUpdated"] =
      DateTime(LastUpdated()).Formatted("%Y-%m-%d %H:%M");
    result["lastChanged"] =
      DateTime(LastChanged()).Formatted("%Y-%m-%d %H:%M");
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Constructs the database bound to @c filename.  If @c filename is
  //!  non-empty and can be opened, its contents are read under a shared
  //!  lock (O_SHLOCK on open(), plus flock(LOCK_SH) on Linux).  Open and
  //!  read failures are ignored, leaving the database empty.
  //--------------------------------------------------------------------------
  Ipv4CountryDb::Ipv4CountryDb(const std::string & filename)
      : RDAP::Ipv4Routes<Ipv4CountryDbValue>(), _filename(filename)
  {
    if (_filename.empty()) {
      return;
    }
    const int  dbfd = open(_filename.c_str(), O_RDONLY|O_SHLOCK);
    if (dbfd < 0) {
      return;
    }
#ifdef __linux__
    flock(dbfd, LOCK_SH);
#endif
    Read(dbfd);
    close(dbfd);
  }

  //--------------------------------------------------------------------------
  //!  Adds an entry for every prefix covering the address range of
  //!  @c rdapResponse, using its country code and last-changed time
  //!  (replaced with DateTime(true) when that time is zero).
  //!  Returns true if the response had a usable range and country and
  //!  yielded at least one prefix.  NOTE(review): the result of each
  //!  Add() call is not checked.
  //--------------------------------------------------------------------------
  bool Ipv4CountryDb::AddEntry(const RDAP::IPv4Response & rdapResponse)
  {
    pair<Ipv4Address,Ipv4Address>  &&range = rdapResponse.AddressRange();
    const bool  haveRange = (range.first != INADDR_ANY)
                            && (range.second != INADDR_ANY);
    if (! haveRange) {
      return false;
    }

    string  &&code = rdapResponse.Country();
    if (code.empty()) {
      return false;
    }

    DateTime  &&changed = rdapResponse.LastChanged();
    if (changed.GetTimeValue64().Secs() == 0) {
      changed = DateTime(true);
    }
    Ipv4CountryDbValue  entryValue(code, changed.GetTimeValue64());

    vector<Ipv4Prefix>  &&pfxs = Ipv4RangePrefixes(range.first, range.second);
    for (auto & p : pfxs) {
      Add(p, entryValue);
    }
    return (! pfxs.empty());
  }

  //--------------------------------------------------------------------------
  //!  Adds or replaces entries for every prefix covering the address
  //!  range of @c rdapResponse.  Each entry gets the response's country
  //!  code, its last-changed time (replaced with DateTime(true) when zero)
  //!  and a last-updated time of TimeValue64(true).  An existing entry for
  //!  the same prefix is deleted first.  Covered more-specific entries
  //!  whose last-updated time is more than @c expireDays days old are
  //!  deleted as well.
  //!  Returns true if at least one prefix was added or replaced.
  //!  Logs an error and returns false when the response has no usable
  //!  address range or no country code.
  //--------------------------------------------------------------------------
  bool Ipv4CountryDb::Update(const RDAP::IPv4Response & rdapResponse,
                             uint64_t expireDays)
  {
    bool  rc = false;
    pair<Ipv4Address,Ipv4Address>  &&addrs = rdapResponse.AddressRange();
    if ((addrs.first != INADDR_ANY) && (addrs.second != INADDR_ANY)) {
      string  &&country = rdapResponse.Country();
      if (! country.empty()) {
        DateTime  &&dt = rdapResponse.LastChanged();
        if (dt.GetTimeValue64().Secs() == 0) {
          dt = DateTime(true);
        }
        TimeValue64  now(true);
        Ipv4CountryDbValue  value(country, dt.GetTimeValue64(), now);
        vector<Ipv4Prefix>  &&prefixes =
          Ipv4RangePrefixes(addrs.first, addrs.second);
        for (auto & pfx : prefixes) {
          //  rc is only ever set to true here, never reset.  Assigning
          //  each Add() result to rc made the return value depend solely
          //  on the LAST prefix, hiding earlier successful changes.
          Ipv4CountryDbValue  prevValue;
          if (Find(pfx, prevValue)) {
            //  Replace existing entry.
            Delete(pfx);
            DeleteCoveredExpired(pfx, expireDays);
            if (Add(pfx, value)) {
              rc = true;
              Syslog(LOG_INFO, "Updated %s (%s)",
                     pfx.ToShortString().c_str(), country.c_str());
            }
          }
          else if (Add(pfx, value)) {
            //  No existing entry found; added a new one.
            rc = true;
            Syslog(LOG_INFO, "Added %s (%s)",
                   pfx.ToShortString().c_str(), country.c_str());
          }
        }
      }
      else {
        Syslog(LOG_ERR, "No country in RDAP response");
      }
    }
    else {
      Syslog(LOG_ERR, "No good address range in RDAP response");
    }

    return rc;
  }

  //--------------------------------------------------------------------------
  //!  Deletes entries strictly more specific than @c prefix (mask lengths
  //!  prefix.MaskLength()+1 through 32) whose network is contained in
  //!  @c prefix and whose last-updated time is more than @c days days
  //!  before now.  Holds the exclusive lock on _mtx while scanning.
  //!  Returns true if any entry was deleted.
  //--------------------------------------------------------------------------
  bool Ipv4CountryDb::DeleteCoveredExpired(const Ipv4Prefix & prefix,
                                           uint64_t days)
  {
    static constexpr uint64_t  secsPerDay = 24 * 60 * 60;

    TimeValue64  expireTime(true);
    const uint64_t  nowSecs = expireTime.Secs();
    if (days > (nowSecs / secsPerDay)) {
      //  The cutoff would fall before the epoch, and days * secsPerDay
      //  could also overflow.  Subtracting would wrap around to a huge
      //  cutoff and expire every covered entry.  No entry can be older
      //  than such a cutoff, so there is nothing to delete.
      return false;
    }
    expireTime.Set(nowSecs - (days * secsPerDay), 0);

    bool  rc = false;
    std::lock_guard<std::shared_mutex>  lock(this->_mtx);
    for (uint8_t i = prefix.MaskLength() + 1; i < 33; ++i) {
      auto & entries = _maps[i];
      for (auto iter = entries.begin(); iter != entries.end(); ) {
        if (prefix.Contains(iter->first)
            && (iter->second.LastUpdated() < expireTime)) {
          iter = entries.erase(iter);
          rc = true;
        }
        else {
          ++iter;
        }
      }
    }
    return rc;
  }
  
  //--------------------------------------------------------------------------
  //!  
  //--------------------------------------------------------------------------
  bool Ipv4CountryDb::Reload()
  {
    bool  rc = false;
    Clear();
    int  fd = open(_filename.c_str(), O_RDONLY|O_SHLOCK);
    if (fd >= 0) {
#ifdef __linux__
      flock(fd, LOCK_SH);
#endif
      if (Read(fd)) {
        rc = true;
      }
      close(fd);
    }
    return rc;
  }
  
  //--------------------------------------------------------------------------
  //!  Writes the database to _filename (created with mode 0644 if needed)
  //!  while holding an exclusive file lock (O_EXLOCK, or flock(LOCK_EX)
  //!  on Linux).  Logs the outcome.  Returns true if Write() reported a
  //!  positive result.
  //--------------------------------------------------------------------------
  bool Ipv4CountryDb::Save() const
  {
    bool  rc = false;
    //  No O_TRUNC here: on Linux the lock is only taken by flock() after
    //  open() returns, so O_TRUNC would truncate the file underneath a
    //  reader still holding a shared lock.  Truncate once we own the lock.
    int  fd = open(_filename.c_str(), O_RDWR|O_CREAT|O_EXLOCK, 0644);
    if (fd >= 0) {
      bool  locked = true;
#ifdef __linux__
      locked = (flock(fd, LOCK_EX) == 0);
#endif
      if (! locked) {
        Syslog(LOG_ERR, "Failed to lock database '%s'", _filename.c_str());
      }
      else if (ftruncate(fd, 0) != 0) {
        Syslog(LOG_ERR, "Failed to truncate database '%s'",
               _filename.c_str());
      }
      else if (Write(fd) > 0) {
        rc = true;
        Syslog(LOG_INFO, "Database saved to %s", _filename.c_str());
      }
      else {
        Syslog(LOG_ERR, "Failed to save database to '%s'", _filename.c_str());
      }
      close(fd);
    }
    else {
      Syslog(LOG_ERR, "Failed to open database '%s'", _filename.c_str());
    }
    return rc;
  }
  
  //--------------------------------------------------------------------------
  //!  If @c os is good, writes all entries (in SortByKey() order) to it as
  //!  a compact JSON array.  Each element is the value's Json() object
  //!  plus a "prefix" member.  An empty database is written as "[]".
  //--------------------------------------------------------------------------
  void Ipv4CountryDb::PrintJson(ostream & os) const
  {
    if (os) {
      //  Start from an explicit array: a default-constructed json that is
      //  never indexed dumps as "null", not "[]".
      nlohmann::json  jv = nlohmann::json::array();
      vector<pair<Ipv4Prefix,Ipv4CountryDbValue>>  entries;
      SortByKey(entries);
      for (auto & entry : entries) {
        //  Reuse the value's own serialization so the two stay in sync.
        nlohmann::json  ej = entry.second.Json();
        ej["prefix"] = entry.first.ToShortString();
        jv.push_back(std::move(ej));
      }
      os << jv.dump();
    }
  }

  //--------------------------------------------------------------------------
  //!  For the entries of mask length @c masklen, groups their prefixes by
  //!  country code and coalesces each country's adjacent prefixes.  Every
  //!  resulting aggregate shorter than @c masklen is stored as an entry
  //!  with that country code and a last-changed time of TimeValue64(true),
  //!  overwriting any existing entry for the same prefix.  The /masklen
  //!  entries themselves are left in place.  Holds the exclusive lock on
  //!  _mtx throughout.
  //--------------------------------------------------------------------------
  void Ipv4CountryDb::CreateCountryAggregates(uint8_t masklen)
  {
    lock_guard<shared_mutex>  lock(this->_mtx);
    map<string, vector<Ipv4Prefix>>  countryPrefixes;
    for (auto & hm : _maps[masklen]) {
      //  Keys of _maps[masklen] are network addresses of /masklen
      //  prefixes; construct the prefix with its real length (as the rest
      //  of this file does).  Pushing the bare address relied on an
      //  implicit Ipv4Address -> Ipv4Prefix conversion, which cannot know
      //  masklen.
      countryPrefixes[hm.second.Code()].push_back(Ipv4Prefix(hm.first,
                                                             masklen));
    }
    for (auto & cp : countryPrefixes) {
      sort(cp.second.begin(), cp.second.end());
      CoalesceAdjacentPrefixes(cp.second);
      Ipv4CountryDbValue  val(cp.first, TimeValue64(true));
      for (auto & p : cp.second) {
        if (p.MaskLength() < masklen) {
          _maps[p.MaskLength()][p.Network()] = val;
          Syslog(LOG_INFO, "Created aggregate %s %s from /%hhus",
                 cp.first.c_str(), p.ToShortString().c_str(), masklen);
        }
      }
    }
  }

  //--------------------------------------------------------------------------
  //!  Intended to delete more-specific prefixes made redundant by a
  //!  less-specific entry with the same country code.
  //!  NOTE(review): the implementation is disabled with '#if 0', so this
  //!  function currently does nothing.  Before re-enabling, note that the
  //!  disabled draft:
  //!    - erases from a copy of _maps[i] ('auto hm = _maps[i]'), so the
  //!      database itself is never modified;
  //!    - advances the iterator returned by erase() again via '++pi',
  //!      skipping an element or stepping past end();
  //!    - matches ANY covering prefix with the same code rather than the
  //!      nearest covering entry, which can change lookup results when an
  //!      intermediate prefix has a different country;
  //!    - never checks /0 ('j > 0').
  //--------------------------------------------------------------------------
  void Ipv4CountryDb::DeleteRedundantSpecifics()
  {
#if 0
    lock_guard<shared_mutex>  lock(this->_mtx);
    for (int i = 32; i > 0; --i) {
      auto  hm = _maps[i];
      for (auto  pi = hm.begin(); pi != hm.end(); ++pi) {
        for (int j = i - 1; j > 0; --j) {
          Ipv4Prefix  pfx(pi->first, j);
          auto  wi = _maps[j].find(pfx.Network());
          if (wi != _maps[j].end()) {
            if (wi->second.Code() == pi->second.Code()) {
              pi = hm.erase(pi);
              break;
            }
          }
        }
      }
    }
#endif
    return;
  }

  //--------------------------------------------------------------------------
  //!  Runs CreateCountryAggregates(masklen) for every mask length from 32
  //!  down to 1.  Going from long to short means aggregates created at
  //!  one length are themselves candidates at the next, shorter length.
  //--------------------------------------------------------------------------
  void Ipv4CountryDb::CreateCountryAggregates()
  {
    for (int len = 32; len >= 1; --len) {
      CreateCountryAggregates(static_cast<uint8_t>(len));
    }
  }
  
  //--------------------------------------------------------------------------
  //!  Returns every prefix (all mask lengths, including /0) whose country
  //!  code equals one of @c countries.  Holds a shared lock on _mtx.
  //!  If @c minimalSet is true, contained prefixes are dropped and
  //!  adjacent ones coalesced (the result is then sorted).
  //--------------------------------------------------------------------------
  vector<Ipv4Prefix>
  Ipv4CountryDb::PrefixesForCountries(const vector<string> & countries,
                                      bool minimalSet)
  {
    vector<Ipv4Prefix>  result;
    shared_lock<shared_mutex>  lock(_mtx);
    uint8_t  len = 0;
    for (auto const & routes : _maps) {
      for (auto const & entry : routes) {
        const string & code = entry.second.Code();
        if (find(countries.begin(), countries.end(), code)
            != countries.end()) {
          result.push_back(Ipv4Prefix(entry.first, len));
        }
      }
      ++len;
    }
    if (minimalSet) {
      CoalesceContainedPrefixes(result);
      sort(result.begin(), result.end());
      CoalesceAdjacentPrefixes(result);
    }
    return result;
  }

  //--------------------------------------------------------------------------
  //!  Returns every prefix of mask length 1..32 whose country code matches
  //!  the ECMAScript regular expression @c regExpression (regex_search,
  //!  so a partial match suffices).  Unlike the vector<string> overload,
  //!  /0 entries are not considered.  If @c minimalSet is true, contained
  //!  prefixes are dropped and adjacent ones coalesced (the result is then
  //!  sorted).  Throws std::regex_error if @c regExpression is invalid.
  //--------------------------------------------------------------------------
  vector<Ipv4Prefix>
  Ipv4CountryDb::PrefixesForCountries(const string & regExpression,
                                      bool minimalSet)
  {
    vector<Ipv4Prefix>  rc;
    //  Compile before locking; construction may be slow or throw.
    regex   rgx(regExpression, regex::ECMAScript|regex::optimize);

    //  Hold a shared lock while walking _maps, as the vector<string>
    //  overload does; writers in this file take _mtx exclusively.
    shared_lock<shared_mutex>  lock(_mtx);
    for (uint8_t masklen = 1; masklen < 33; ++masklen) {
      for (auto const & v : _maps[masklen]) {
        if (regex_search(v.second.Code(), rgx)) {
          rc.push_back(Ipv4Prefix(v.first, masklen));
        }
      }
    }
    if (minimalSet) {
      CoalesceContainedPrefixes(rc);
      sort(rc.begin(), rc.end());
      CoalesceAdjacentPrefixes(rc);
    }

    return rc;
  }

  //--------------------------------------------------------------------------
  //!  Predicate: true if Ipv4Prefix::Combine() reports that @c p1 and
  //!  @c p2 can be merged into one prefix.
  //--------------------------------------------------------------------------
  static bool PrefixesCombine(const Ipv4Prefix & p1, const Ipv4Prefix & p2)
  {
    const auto  combined = p1.Combine(p2);
    return combined.first;
  }

  //--------------------------------------------------------------------------
  //!  Repeatedly replaces neighboring prefixes in @c vp that
  //!  Ipv4Prefix::Combine() can merge with the merged prefix, until no
  //!  neighboring pair combines.  Only neighbors are examined; callers in
  //!  this file sort @c vp first.
  //--------------------------------------------------------------------------
  void Ipv4CountryDb::CoalesceAdjacentPrefixes(vector<Ipv4Prefix> & vp)
  {
    bool  merged = true;
    while (merged) {
      //  A merge can make the result combinable with its predecessor, so
      //  keep making passes until one changes nothing.  The old loop set
      //  'done' when a merge consumed the last element and so missed such
      //  a follow-on merge (e.g. [A/23, B/24, C/24] -> [A/23, BC/23]
      //  was never merged into a /22).
      merged = false;
      auto  it = vp.begin();
      while ((it = adjacent_find(it, vp.end(), PrefixesCombine)) != vp.end()) {
        *it = it->Combine(*(it + 1)).second;
        //  vector::erase() invalidates only iterators at or after it + 1,
        //  so 'it' stays valid and the scan resumes from the merged prefix.
        vp.erase(it + 1);
        merged = true;
      }
    }
  }

  //--------------------------------------------------------------------------
  //!  Strict ordering of prefixes by network address, then by mask length
  //!  (shorter first), so a prefix sorts before longer prefixes that share
  //!  its network address.
  //--------------------------------------------------------------------------
  static bool ContainedPrefixLess(const Ipv4Prefix & p1, const Ipv4Prefix & p2)
  {
    if (p1.Network() == p2.Network()) {
      return (p1.MaskLength() < p2.MaskLength());
    }
    return (p1.Network() < p2.Network());
  }

  //--------------------------------------------------------------------------
  //!  Sorts @c vp with ContainedPrefixLess and removes every prefix that
  //!  is contained in the most recently retained prefix (including exact
  //!  duplicates, if Contains() reports a prefix as containing itself).
  //--------------------------------------------------------------------------
  void Ipv4CountryDb::CoalesceContainedPrefixes(vector<Ipv4Prefix> & vp)
  {
    sort(vp.begin(), vp.end(), ContainedPrefixLess);

    //  Single pass: compare each prefix against the last one kept.
    vector<Ipv4Prefix>  kept;
    kept.reserve(vp.size());
    for (const auto & pfx : vp) {
      if (kept.empty() || (! kept.back().Contains(pfx))) {
        kept.push_back(pfx);
      }
    }
    vp.swap(kept);
  }

  //--------------------------------------------------------------------------
  //!  
  //--------------------------------------------------------------------------
  class AddressesCombine
  {
  public:
    AddressesCombine(uint8_t masklen)
        : _masklen(masklen)
    {}
    
    bool operator () (const pair<Ipv4Address, Ipv4CountryDbValue> & a1,
                      const pair<Ipv4Address, Ipv4CountryDbValue> & a2) const
    {
      Ipv4Prefix  p1(a1.first,_masklen);
      Ipv4Prefix  p2(a2.first,_masklen);
      return (p1.Combine(p2).first
              && (a1.second.Code() == a2.second.Code()));
    }
  private:
    uint8_t  _masklen;
    AddressesCombine();
  };

  //--------------------------------------------------------------------------
  //!  
  //--------------------------------------------------------------------------
  void Ipv4CountryDb::AggregateAdjacents()
  {
    for (uint8_t masklen = 32; masklen > 0; --masklen) {
      lock_guard<shared_mutex>  lock(_mtx);
      map<Ipv4Address, Ipv4CountryDbValue>  & addrs = _maps[masklen];
      AddressesCombine  addrCombine(masklen);
      bool  done = false;
      while (! done) {
        done = true;
        auto  it = addrs.begin();
        while ((it = adjacent_find(it, addrs.end(), addrCombine))
               != addrs.end()) {
          auto  ei = it;
          ++ei;
          Ipv4Prefix  p1(it->first, masklen);
          Ipv4Prefix  p2(ei->first, masklen);
          pair<Ipv4Address, Ipv4CountryDbValue>
            entry(p1.Combine(p2).second.Network(), it->second);
          it = addrs.erase(it, ++ei);
          _maps[masklen - 1].insert(entry);
        }
      }
    }
    return;
  }
  

}  // namespace Dwm
