//===========================================================================
// @(#) $Name:$
// @(#) $Id: DwmPfTable.cc 12347 2024-05-03 02:40:01Z dwm $
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2017
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file DwmPfTable.cc
//!  \brief Dwm::Pf::Table and Dwm::Pf::Table::IOCTable implementation
//---------------------------------------------------------------------------

extern "C" {
  #include <sys/ioctl.h>
  #include <sys/types.h>
  #include <sys/socket.h>
  #include <net/if.h>
  #include <netinet/in.h>
  #include <netinet/tcp.h>
  #include <arpa/inet.h>
  #include <fcntl.h>
  #include <string.h>
  #include <unistd.h>
}

#include <iomanip>
#include <new>

#include "DwmSvnTag.hh"
#include "DwmSysLogger.hh"
#include "DwmPfDevice.hh"

//  File-scope SvnTag holding this file's revision string ("@(#)" is the
//  what(1) identification prefix).
static const Dwm::SvnTag svntag("@(#) $DwmPath: dwm/libDwmPf/tags/libDwmPf-0.1.6/src/DwmPfTable.cc 12347 $");

//  Lets the definitions below use unqualified string / vector.
using namespace std;

namespace Dwm {

  namespace Pf {

    //------------------------------------------------------------------------
    //!  Builds a zero-filled pfioc_table whose embedded pfr_table holds
    //!  @c table's anchor and name (an empty anchor or name is left zeroed).
    //!  Throws std::length_error, carrying the offending string, if the
    //!  anchor or name does not fit in its pfr_table field.
    //------------------------------------------------------------------------
    Table::IOCTable::IOCTable(const Table & table)
    {
      memset(&_tbl, 0, sizeof(_tbl));
      auto  & pfrTable = _tbl.pfrio_table;

      if (! table._anchor.empty()) {
        //  strlcpy() returns the full source length; a result that does
        //  not fit in the destination means the copy was truncated.
        const size_t  needed = strlcpy(pfrTable.pfrt_anchor,
                                       table._anchor.c_str(),
                                       sizeof(pfrTable.pfrt_anchor));
        if (needed >= sizeof(pfrTable.pfrt_anchor)) {
          throw std::length_error(table._anchor);
        }
      }

      if (! table._name.empty()) {
        const size_t  needed = strlcpy(pfrTable.pfrt_name,
                                       table._name.c_str(),
                                       sizeof(pfrTable.pfrt_name));
        if (needed >= sizeof(pfrTable.pfrt_name)) {
          throw std::length_error(table._name);
        }
      }
    }

    //------------------------------------------------------------------------
    //!  Default constructor: every byte of the pfioc_table, including the
    //!  buffer pointer and counts, starts out zero.
    //------------------------------------------------------------------------
    Table::IOCTable::IOCTable()
    {
      memset(&_tbl, 0, sizeof _tbl);
    }

    //------------------------------------------------------------------------
    //!  Copy constructor.  Copies the whole pfioc_table and, when the
    //!  source has a non-empty buffer, gives this object its own
    //!  malloc()ed copy of it.  The source's buffer pointer is never
    //!  shared, so the two objects can't free() the same buffer.
    //!  Throws std::bad_alloc if the buffer copy can't be allocated.
    //------------------------------------------------------------------------
    Table::IOCTable::IOCTable(const Table::IOCTable & iocTable)
    {
      memcpy(&_tbl, &iocTable._tbl, sizeof(_tbl));
      //  The memcpy() above also copied the source's buffer pointer.  Drop
      //  it: a buffer with pfrio_size == 0 (e.g. left behind after an
      //  ioctl reported an empty table) would otherwise be shared and
      //  freed twice.  A private copy is made below when there is data.
      _tbl.pfrio_buffer = nullptr;
      if (iocTable._tbl.pfrio_buffer
          && (iocTable._tbl.pfrio_size > 0)
          && (iocTable._tbl.pfrio_esize > 0)) {
        //  NOTE(review): assumes pfrio_size still describes the allocated
        //  buffer; after a DIOCRGET* ioctl the kernel may report a larger
        //  count than was allocated.  Confirm callers never copy then.
        const size_t  len =
          static_cast<size_t>(iocTable._tbl.pfrio_size)
          * static_cast<size_t>(iocTable._tbl.pfrio_esize);
        _tbl.pfrio_buffer = malloc(len);
        if (nullptr == _tbl.pfrio_buffer) {
          throw std::bad_alloc();
        }
        memcpy(_tbl.pfrio_buffer, iocTable._tbl.pfrio_buffer, len);
      }
    }

    //------------------------------------------------------------------------
    //!  Copy assignment.  Replaces this object's pfioc_table with a copy of
    //!  @c iocTable's, giving this object its own malloc()ed copy of any
    //!  non-empty buffer (the source's pointer is never shared).  The new
    //!  buffer is built before the old one is released, so if allocation
    //!  fails std::bad_alloc is thrown and *this is left unchanged.
    //------------------------------------------------------------------------
    Table::IOCTable &
    Table::IOCTable::operator = (const Table::IOCTable & iocTable)
    {
      Syslog(LOG_DEBUG, "IOCTable::operator = called");
      if (&iocTable != this) {
        void  *newBuffer = nullptr;
        if (iocTable._tbl.pfrio_buffer
            && (iocTable._tbl.pfrio_size > 0)
            && (iocTable._tbl.pfrio_esize > 0)) {
          const size_t  len =
            static_cast<size_t>(iocTable._tbl.pfrio_size)
            * static_cast<size_t>(iocTable._tbl.pfrio_esize);
          newBuffer = malloc(len);
          if (nullptr == newBuffer) {
            throw std::bad_alloc();
          }
          memcpy(newBuffer, iocTable._tbl.pfrio_buffer, len);
        }
        //  Only now, with nothing left that can fail, replace our state.
        ClearBuffer();
        memcpy(&_tbl, &iocTable._tbl, sizeof(_tbl));
        //  Adopt our private copy (or nullptr), never the source's pointer,
        //  so the two objects can't double-free() one buffer.
        _tbl.pfrio_buffer = newBuffer;
      }
      return *this;
    }
      
    //------------------------------------------------------------------------
    //!  Replaces the buffer with a single IPv4 pfr_addr describing
    //!  @c prefix (not negated, fback 0).  Any previous buffer contents are
    //!  discarded.  Returns false if the pfr_addr couldn't be allocated.
    //------------------------------------------------------------------------
    bool Table::IOCTable::AddPrefix(const Ipv4Prefix & prefix)
    {
      if (! AllocPfrAddrs(1)) {
        return false;
      }
      auto  *addr = static_cast<struct pfr_addr *>(_tbl.pfrio_buffer);
      addr->pfra_af = AF_INET;
      addr->pfra_ip4addr.s_addr = prefix.NetworkRaw();
      addr->pfra_net = prefix.MaskLength();
      addr->pfra_not = 0;
      addr->pfra_fback = 0;
      return true;
    }

    //------------------------------------------------------------------------
    //!  Replaces the buffer with one IPv4 pfr_addr per element of
    //!  @c prefixes (in order, not negated, fback 0).  Any previous buffer
    //!  contents are discarded.  Returns false, leaving the buffer
    //!  untouched, if @c prefixes is empty; returns false if the
    //!  pfr_addrs couldn't be allocated.
    //------------------------------------------------------------------------
    bool Table::IOCTable::AddPrefixes(const vector<Ipv4Prefix> & prefixes)
    {
      bool  rc = false;
      //  Short-circuit keeps the original behaviour of not clearing the
      //  buffer when there is nothing to add.
      if ((! prefixes.empty()) && AllocPfrAddrs(prefixes.size())) {
        struct pfr_addr  *newaddr = (struct pfr_addr *)_tbl.pfrio_buffer;
        //  Bind by const reference; NetworkRaw() and MaskLength() are
        //  callable on const prefixes (see AddPrefix()), so no copy is
        //  needed per element.
        for (const auto & pfx : prefixes) {
          newaddr->pfra_af = AF_INET;
          newaddr->pfra_ip4addr.s_addr = pfx.NetworkRaw();
          newaddr->pfra_net = pfx.MaskLength();
          newaddr->pfra_not = 0;
          newaddr->pfra_fback = 0;
          ++newaddr;
        }
        rc = true;
      }
      return rc;
    }
            
    //------------------------------------------------------------------------
    //!  Destructor.  Releases any owned buffer and zeroes the pfioc_table
    //!  through Clear().
    //------------------------------------------------------------------------
    Table::IOCTable::~IOCTable()
    {
      this->Clear();
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    void Table::IOCTable::Clear()
    {
      ClearBuffer();
      memset(&_tbl, 0, sizeof(_tbl));
      return;
    }

    //------------------------------------------------------------------------
    //!  Frees the buffer (if any) and resets the buffer pointer, element
    //!  count and element size.  Other pfioc_table fields are kept.
    //------------------------------------------------------------------------
    void Table::IOCTable::ClearBuffer()
    {
      free(_tbl.pfrio_buffer);    // free(nullptr) is a harmless no-op
      _tbl.pfrio_buffer = nullptr;
      _tbl.pfrio_size = 0;
      _tbl.pfrio_esize = 0;
    }
    
    //------------------------------------------------------------------------
    //!  Replaces the buffer with @c numAddrs zero-filled pfr_addr
    //!  structures and sets pfrio_esize / pfrio_size to match.  Returns
    //!  true if the buffer was allocated.  Returns false, leaving the
    //!  buffer empty, if @c numAddrs is not positive or calloc() fails
    //!  (the latter is logged).
    //------------------------------------------------------------------------
    bool Table::IOCTable::AllocPfrAddrs(int numAddrs)
    {
      ClearBuffer();

      //  '> 0' (as in AllocPfrTables()/AllocPfrAstats()) so a negative
      //  count is rejected instead of becoming a huge size_t request.
      if (numAddrs > 0) {
        const size_t  len = sizeof(pfr_addr) * static_cast<size_t>(numAddrs);
        _tbl.pfrio_buffer = calloc(1, len);
        if (_tbl.pfrio_buffer) {
          _tbl.pfrio_esize = sizeof(pfr_addr);
          _tbl.pfrio_size = numAddrs;
        }
        else {
          //  %zu matches size_t; %u was undefined behaviour on LP64.
          Syslog(LOG_ERR, "calloc(1,%zu) failed: %m", len);
        }
      }
      return (_tbl.pfrio_size > 0);
    }

    //------------------------------------------------------------------------
    //!  Replaces the buffer with @c numTables zero-filled pfr_table
    //!  structures and sets pfrio_esize / pfrio_size to match.  Returns
    //!  true if the buffer was allocated.  Returns false, leaving the
    //!  buffer empty, if @c numTables is not positive or calloc() fails
    //!  (the latter is logged).
    //------------------------------------------------------------------------
    bool Table::IOCTable::AllocPfrTables(int numTables)
    {
      ClearBuffer();
      if (numTables > 0) {
        const size_t  len =
          sizeof(pfr_table) * static_cast<size_t>(numTables);
        _tbl.pfrio_buffer = calloc(1, len);
        if (_tbl.pfrio_buffer) {
          _tbl.pfrio_esize = sizeof(pfr_table);
          _tbl.pfrio_size = numTables;
        }
        else {
          //  %zu matches size_t; %u was undefined behaviour on LP64.
          Syslog(LOG_ERR, "calloc(1,%zu) failed: %m", len);
        }
      }
      return (_tbl.pfrio_size > 0);
    }

    //------------------------------------------------------------------------
    //!  Replaces the buffer with @c numAddrs zero-filled pfr_astats
    //!  structures and sets pfrio_esize / pfrio_size to match.  Returns
    //!  true if the buffer was allocated.  Returns false, leaving the
    //!  buffer empty, if @c numAddrs is not positive or calloc() fails
    //!  (the latter is logged).
    //------------------------------------------------------------------------
    bool Table::IOCTable::AllocPfrAstats(int numAddrs)
    {
      ClearBuffer();
      if (numAddrs > 0) {
        const size_t  len =
          sizeof(pfr_astats) * static_cast<size_t>(numAddrs);
        _tbl.pfrio_buffer = calloc(1, len);
        if (_tbl.pfrio_buffer) {
          _tbl.pfrio_esize = sizeof(pfr_astats);
          _tbl.pfrio_size = numAddrs;
        }
        else {
          //  %zu matches size_t; %u was undefined behaviour on LP64.
          Syslog(LOG_ERR, "calloc(1,%zu) failed: %m", len);
        }
      }
      return (_tbl.pfrio_size > 0);
    }

    //------------------------------------------------------------------------
    //!  Constructs a Table for the table @c name in anchor @c anchor on
    //!  device @c dev, recording the given @c flags and @c fback values.
    //------------------------------------------------------------------------
    Table::Table(const Device & dev, const string & anchor,
                 const string & name, uint32_t flags, uint8_t fback)
        : _dev(dev),
          _anchor(anchor),
          _name(name),
          _flags(flags),
          _fback(fback)
    {
    }

    //------------------------------------------------------------------------
    //!  Constructs a Table on @c dev from a pfr_table, taking its anchor,
    //!  name, flags and fback.  The anchor and name are fixed-size char
    //!  arrays; they are read with strnlen() bounded by the array size so
    //!  a field lacking a terminating NUL can't cause a read past its end.
    //------------------------------------------------------------------------
    Table::Table(const Device & dev, const struct pfr_table & pfrTable)
        : _dev(dev),
          _anchor(pfrTable.pfrt_anchor,
                  strnlen(pfrTable.pfrt_anchor,
                          sizeof(pfrTable.pfrt_anchor))),
          _name(pfrTable.pfrt_name,
                strnlen(pfrTable.pfrt_name, sizeof(pfrTable.pfrt_name))),
          _flags(pfrTable.pfrt_flags), _fback(pfrTable.pfrt_fback)
    {
    }

    //------------------------------------------------------------------------
    //!  Copy constructor: copies the device, anchor, name, flags and fback
    //!  members of @c table.
    //------------------------------------------------------------------------
    Table::Table(const Table & table)
        : _dev(table._dev),
          _anchor(table._anchor),
          _name(table._name),
          _flags(table._flags),
          _fback(table._fback)
    {
    }

    //------------------------------------------------------------------------
    //!  Returns the anchor the table lives in (empty for the main ruleset
    //!  anchor, as far as this class is concerned).
    //------------------------------------------------------------------------
    const string & Table::Anchor() const
    {
      return this->_anchor;
    }
    
    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    const std::string & Table::Name() const
    {
      return _name;
    }

    //------------------------------------------------------------------------
    //!  Adds @c prefix to the table with DIOCRADDADDRS.  Returns true only
    //!  if the kernel reports exactly one address added.  Logs an error if
    //!  the device is not open or the kernel added nothing (e.g. the
    //!  prefix may already be present).
    //------------------------------------------------------------------------
    bool Table::Add(const Ipv4Prefix & prefix) const
    {
      bool  rc = false;
      if (_dev.IsOpen()) {
        IOCTable  ioctbl(*this);
        //  AddPrefix() fails only when allocating the pfr_addr fails
        //  (AllocPfrAddrs() logs that); don't issue the ioctl without a
        //  buffer in that case.
        if (ioctbl.AddPrefix(prefix)
            && _dev.Ioctl(DIOCRADDADDRS, ioctbl.pfioc())) {
          if (ioctbl.pfioc()->pfrio_nadd == 1) {
            rc = true;
            Syslog(LOG_DEBUG, "Added %s to %s",
                   prefix.ToShortString().c_str(), _name.c_str());
          }
          else {
            Syslog(LOG_ERR, "Failed to add %s to %s",
                   prefix.ToShortString().c_str(), _name.c_str());
          }
        }
      }
      else {
        Syslog(LOG_ERR, "Pf::Device is not open");
      }
      return rc;
    }
    
    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::Add(const Ipv4Address & address) const
    {
      return Add(Ipv4Prefix(address, 32));
    }
    
    //------------------------------------------------------------------------
    //!  Removes @c prefix from the table with DIOCRDELADDRS.  Returns true
    //!  only if the kernel reports exactly one address deleted; a prefix
    //!  that wasn't in the table yields false without logging.  Logs an
    //!  error if the device is not open.
    //------------------------------------------------------------------------
    bool Table::Remove(const Ipv4Prefix & prefix) const
    {
      bool  rc = false;
      if (_dev.IsOpen()) {
        IOCTable  ioctbl(*this);
        //  AddPrefix() fails only when allocating the pfr_addr fails
        //  (AllocPfrAddrs() logs that); don't issue the ioctl without a
        //  buffer in that case.
        if (ioctbl.AddPrefix(prefix)
            && _dev.Ioctl(DIOCRDELADDRS, ioctbl.pfioc())) {
          if (ioctbl.pfioc()->pfrio_ndel == 1) {
            rc = true;
            Syslog(LOG_DEBUG, "Removed %s from %s",
                   prefix.ToShortString().c_str(), _name.c_str());
          }
        }
      }
      else {
        Syslog(LOG_ERR, "Pf::Device is not open");
      }
      return rc;
    }
    
    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::Remove(const Ipv4Address & address) const
    {
      return Remove(Ipv4Prefix(address, 32));
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::Contains(const Ipv4Prefix & prefix, Ipv4Prefix & result) const
    {
      bool  rc = false;
      std::vector<Ipv4Prefix>  &&entries = GetEntries();
      if (! entries.empty()) {
        for (auto & entry : entries) {
          if (entry.Contains(prefix)) {
            result = entry;
            rc = true;
            break;
          }
        }
      }
      return rc;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::Contains(const Ipv4Address & address,
                         Ipv4Prefix & result) const
    {
      bool  rc = false;
      if (_dev.IsOpen()) {
        IOCTable  ioctbl(*this);
        ioctbl.AddPrefix(Ipv4Prefix(address, 32));
        ioctbl.pfioc()->pfrio_flags |= PFR_FLAG_REPLACE;
        if (_dev.Ioctl(DIOCRTSTADDRS, ioctbl.pfioc())) {
          if (ioctbl.pfioc()->pfrio_nmatch) {
            pfr_addr *addr =
              (pfr_addr *)(ioctbl.pfioc()->pfrio_buffer);
            result.Set(addr->pfra_ip4addr.s_addr, addr->pfra_net);
            rc = true;
          }
        }
      }
      else {
        Syslog(LOG_ERR, "Pf::Device is not open");
      }
      return rc;
    }
    
    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    std::vector<Ipv4Prefix> Table::GetEntries() const
    {
      std::vector<Ipv4Prefix>  entries;
      if (_dev.IsOpen()) {
        IOCTable  ioctbl(*this);
        if (ioctbl.AllocPfrAddrs(1)) {
          pfioc_table  *ioctblp = ioctbl.pfioc();
          if (_dev.Ioctl(DIOCRGETADDRS, ioctblp)) {
            int  numAddrs = ioctblp->pfrio_size;
            if (numAddrs > 0) {
              if (ioctbl.AllocPfrAddrs(numAddrs)) {
                if (_dev.Ioctl(DIOCRGETADDRS, ioctblp)) {
                  pfr_addr  *addr = (pfr_addr *)ioctblp->pfrio_buffer;
                  for (int i = 0; i < numAddrs; ++i) {
                    if (addr->pfra_af == AF_INET) {
                      entries.push_back(Ipv4Prefix(addr->pfra_ip4addr.s_addr,
                                                   addr->pfra_net));
                    }
                    ++addr;
                  }
                }
              }
              else {
                Syslog(LOG_ERR, "Failed to allocate %d pfr_addrs", numAddrs);
              }
            }
          }
        }
        else {
          Syslog(LOG_ERR, "Failed to allocate a pfr_addr");
        }
      }
      else {
        Syslog(LOG_ERR, "Pf::Device is not open");
      }
      
      return entries;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    std::vector<TableEntryStat> Table::GetStats() const
    {
      std::vector<TableEntryStat>  stats;
      if (_dev.IsOpen()) {
        IOCTable  ioctbl(*this);
        if (ioctbl.AllocPfrAstats(1)) {
          pfioc_table  *tblp = ioctbl.pfioc();
          if (_dev.Ioctl(DIOCRGETASTATS, tblp)) {
            int  numAddrs = tblp->pfrio_size;
            if (ioctbl.AllocPfrAstats(numAddrs)) {
              if (_dev.Ioctl(DIOCRGETASTATS, tblp)) {
                pfr_astats  *asp = (pfr_astats *)tblp->pfrio_buffer;
                for (int i = 0; i < numAddrs; ++i) {
                  if (asp->pfras_a.pfra_af == AF_INET) {
                    stats.push_back(TableEntryStat(*asp));
                  }
                  ++asp;
                }
              }
            }
            else {
              Syslog(LOG_ERR, "Failed to allocate %d pfr_astats", numAddrs);
            }
          }
        }
        else {
          Syslog(LOG_ERR, "Failed to allocate pfr_astats");
        }
      }
      else {
        Syslog(LOG_ERR, "Pf::Device is not open");
      }
      return stats;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::Persist() const
    {
      uint32_t  flags = GetFlags();
      return ((flags & PFR_TFLAG_PERSIST) == PFR_TFLAG_PERSIST);
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::Persist(bool persist) const
    {
      bool      rc = true;
      uint32_t  flags = GetFlags();
      uint32_t  newFlags = (persist ?
                            (flags | PFR_TFLAG_PERSIST)
                            : (flags & (~PFR_TFLAG_PERSIST)));
      if (flags != newFlags) {
        rc = SetFlags(newFlags);
      }
      return rc;
    }

    //------------------------------------------------------------------------
    //!  Returns true if the table currently has PFR_TFLAG_CONST set.
    //!  Returns false if the flags can't be read (GetFlags() returns 0).
    //------------------------------------------------------------------------
    bool Table::Constant() const
    {
      //  (The unused 'rc' local the original declared has been removed.)
      const uint32_t  flags = GetFlags();
      return ((flags & PFR_TFLAG_CONST) == PFR_TFLAG_CONST);
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::Constant(bool constant) const
    {
      bool      rc = true;
      uint32_t  flags = GetFlags();
      uint32_t  newFlags = (constant ?
                            (flags | PFR_TFLAG_CONST)
                            : (flags & (~PFR_TFLAG_CONST)));
      if (flags != newFlags) {
        rc = SetFlags(newFlags);
      }
      return rc;
    }

#ifndef __APPLE__
    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::CountersEnabled() const
    {
      uint32_t  flags = GetFlags();
      return ((flags & PFR_TFLAG_COUNTERS) == PFR_TFLAG_COUNTERS);
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::CountersEnabled(bool enabled) const
    {
      bool      rc = true;
      uint32_t  flags = GetFlags();
      uint32_t  newFlags = (enabled ?
                            (flags | PFR_TFLAG_COUNTERS)
                            : (flags & (~PFR_TFLAG_COUNTERS)));
      if (flags != newFlags) {
        rc = SetFlags(newFlags);
      }
      return rc;
    }
#endif
    
    //------------------------------------------------------------------------
    //!  Returns the table's current flags as reported by the device's
    //!  GetTable().  Returns 0 if the device is not open, or (with an
    //!  error logged) if the device returns a table with a different
    //!  anchor or name, i.e. this table was not found.
    //------------------------------------------------------------------------
    uint32_t Table::GetFlags() const
    {
      if (! _dev.IsOpen()) {
        return 0;
      }
      const auto  & found = _dev.GetTable(_anchor, _name);
      if ((found._anchor != _anchor) || (found._name != _name)) {
        Syslog(LOG_ERR, "Table (\"%s\",\"%s\") not found",
               _anchor.c_str(), _name.c_str());
        return 0;
      }
      return found._flags;
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    bool Table::SetFlags(uint32_t flags) const
    {
      bool  rc = false;
      if ((flags & (~PFR_TFLAG_ALLMASK)) == 0) {
        if (_dev.TableExists(*this)) {
          uint32_t  prevFlags = GetFlags();
          if (prevFlags != flags) {
            Table::IOCTable  ioctbl(*this);
            ioctbl.AllocPfrTables(1);
            pfioc_table  *tblp = ioctbl.pfioc();
            tblp->pfrio_setflag = (~(prevFlags & flags)) & flags;
            tblp->pfrio_clrflag = (~flags) & prevFlags;
            pfr_table  *newtbl = (pfr_table *)tblp->pfrio_buffer;
            strlcpy(newtbl->pfrt_anchor, _anchor.c_str(),
                    sizeof(newtbl->pfrt_anchor));
            strlcpy(newtbl->pfrt_name, _name.c_str(),
                    sizeof(newtbl->pfrt_name));
            if (_dev.Ioctl(DIOCRSETTFLAGS, tblp)) {
              if ((prevFlags & PFR_TFLAG_PERSIST)
                  && (! (tblp->pfrio_clrflag & PFR_TFLAG_PERSIST))) {
                if (tblp->pfrio_nchange == 1) {
                  rc = true;
                }
              }
              else {
                if ((tblp->pfrio_nchange + tblp->pfrio_ndel) == 1) {
                  rc = true;
                }
              }
            }
          }
          else {
            // flags already set correctly
            rc = true;
          }
        }
      }
      return rc;
    }
    

  }  // namespace Pf

}  // namespace Dwm
