//===========================================================================
// @(#) $DwmPath: dwm/DwmDns/tags/DwmDns-0.1.0/classes/src/DwmDnsMessage.cc 10133 $
// @(#) $Id: DwmDnsMessage.cc 10133 2018-01-27 17:41:32Z dwm $
//===========================================================================
//  Copyright (c) Daniel W. McRobb 2000, 2016, 2018
//  All rights reserved.
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions
//  are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//  3. The names of the authors and copyright holders may not be used to
//     endorse or promote products derived from this software without
//     specific prior written permission.
//
//  IN NO EVENT SHALL DANIEL W. MCROBB BE LIABLE TO ANY PARTY FOR
//  DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
//  INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE,
//  EVEN IF DANIEL W. MCROBB HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
//  DAMAGE.
//
//  THE SOFTWARE PROVIDED HEREIN IS ON AN "AS IS" BASIS, AND
//  DANIEL W. MCROBB HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT,
//  UPDATES, ENHANCEMENTS, OR MODIFICATIONS. DANIEL W. MCROBB MAKES NO
//  REPRESENTATIONS AND EXTENDS NO WARRANTIES OF ANY KIND, EITHER
//  IMPLIED OR EXPRESS, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
//  WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE,
//  OR THAT THE USE OF THIS SOFTWARE WILL NOT INFRINGE ANY PATENT,
//  TRADEMARK OR OTHER RIGHTS.
//===========================================================================

//---------------------------------------------------------------------------
//!  \file DwmDnsMessage.cc
//!  \brief Dwm::Dns::Message class implementation
//---------------------------------------------------------------------------

extern "C" {
  #include <sys/types.h>
  #include <sys/socket.h>
}

#include <iostream>
#include <stdexcept>
#include <vector>

#include "DwmSvnTag.hh"
#include "DwmDnsMessage.hh"

//  File-scope version-tag object constructed from the keyword string below.
//  NOTE(review): presumably this exists so the "@(#)" string is retained in
//  the compiled object for what(1)/ident-style version lookup; the $...$
//  keyword text looks VCS-expanded, so don't edit it by hand.
static const Dwm::SvnTag svntag("@(#) $DwmPath: dwm/DwmDns/tags/DwmDns-0.1.0/classes/src/DwmDnsMessage.cc 10133 $");

using namespace std;

namespace Dwm {

  namespace Dns {

    //------------------------------------------------------------------------
    //!  Constructs an empty message: the header and the question, answer,
    //!  authority and additional sections are all value-initialized.
    //------------------------------------------------------------------------
    Message::Message()
        : _header(),
          _questions(),
          _answers(),
          _authorities(),
          _additional()
    {
    }

    //------------------------------------------------------------------------
    //!  Returns true when @c message has an equal header and equal question,
    //!  answer, authority and additional sections.  Parts are compared in
    //!  that order and comparison stops at the first mismatch.
    //------------------------------------------------------------------------
    bool Message::operator == (const Message & message) const
    {
      if (! (_header == message._header)) {
        return false;
      }
      if (! (_questions == message._questions)) {
        return false;
      }
      if (! (_answers == message._answers)) {
        return false;
      }
      if (! (_authorities == message._authorities)) {
        return false;
      }
      return (_additional == message._additional);
    }

    //------------------------------------------------------------------------
    //!  Resets the message: empties all four record sections and calls
    //!  Clear() on the header.
    //------------------------------------------------------------------------
    void Message::Clear()
    {
      //  The sections and the header are independent; order is irrelevant.
      _additional.clear();
      _authorities.clear();
      _answers.clear();
      _questions.clear();
      _header.Clear();
    }
    
    //------------------------------------------------------------------------
    //!  Encodes the message into the buffer whose start is @c pkt, writing
    //!  from @c ptr onward, and returns a pointer just past the last byte
    //!  encoded.  @c pktlen is passed to every component Encode() as the
    //!  buffer limit.
    //!
    //!  The header's four section counts are first set from the current
    //!  section sizes.  The header and all questions must fit: a
    //!  std::out_of_range thrown while encoding them propagates to the
    //!  caller.  If an answer, authority or additional record does not fit
    //!  (its Encode() throws std::out_of_range), that record and every later
    //!  one are dropped.  The header's TC flag is set and its counts are
    //!  reduced to what was actually encoded.  The header is then re-encoded
    //!  at the position where it was first written.  The returned pointer
    //!  is the end of the last record that fit.
    //!
    //!  NOTE(review): the counts and TC flag are written to the (mutable)
    //!  _header member and persist after the call; nothing here clears a
    //!  TC flag set by an earlier truncated Encode().
    //------------------------------------------------------------------------
    uint8_t *Message::Encode(uint8_t *pkt, uint8_t *ptr,
                             uint16_t pktlen) const
    {
      //  Remember where the header goes so a truncation rewrite lands in
      //  the same place (ptr need not equal pkt).
      uint8_t  *hdrptr = ptr;

      _header.QuestionCount(_questions.size());
      _header.AnswerCount(_answers.size());
      _header.AuthorityCount(_authorities.size());
      _header.AdditionalCount(_additional.size());
      ptr = _header.Encode(pkt, hdrptr, pktlen);

      LabelPositions  lps;
      for (auto & question : _questions) {
        ptr = question.Encode(pkt, ptr, pktlen, lps);
      }

      bool  stop = false;
      for (size_t i = 0; i < _answers.size(); ++i) {
        uint8_t  *prevptr = ptr;
        try {
          ptr = _answers[i].Encode(pkt, ptr, pktlen, lps);
        }
        catch (const std::out_of_range &) {
          //  Answer i doesn't fit: keep answers [0, i), drop the rest of
          //  the message.
          stop = true;
          _header.IsTruncated(true);
          _header.AnswerCount(static_cast<uint16_t>(i));
          _header.AuthorityCount(0);
          _header.AdditionalCount(0);
          _header.Encode(pkt, hdrptr, pktlen);
          ptr = prevptr;
          break;
        }
      }
      if (! stop) {
        for (size_t i = 0; i < _authorities.size(); ++i) {
          uint8_t  *prevptr = ptr;
          try {
            ptr = _authorities[i].Encode(pkt, ptr, pktlen, lps);
          }
          catch (const std::out_of_range &) {
            //  Authority i doesn't fit: keep authorities [0, i), drop the
            //  additional section.
            stop = true;
            _header.IsTruncated(true);
            _header.AuthorityCount(static_cast<uint16_t>(i));
            _header.AdditionalCount(0);
            _header.Encode(pkt, hdrptr, pktlen);
            ptr = prevptr;
            break;
          }
        }
      }
      if (! stop) {
        for (size_t i = 0; i < _additional.size(); ++i) {
          uint8_t  *prevptr = ptr;
          try {
            ptr = _additional[i].Encode(pkt, ptr, pktlen, lps);
          }
          catch (const std::out_of_range &) {
            //  Additional record i doesn't fit: keep [0, i).
            _header.IsTruncated(true);
            _header.AdditionalCount(static_cast<uint16_t>(i));
            _header.Encode(pkt, hdrptr, pktlen);
            ptr = prevptr;
            break;
          }
        }
      }
      return ptr;
    }
    
    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    const uint8_t *Message::Decode(const uint8_t *pkt, const uint8_t *ptr,
                                   uint16_t pktlen)
    {
      _questions.clear();
      _answers.clear();
      _authorities.clear();
      _additional.clear();
      
      ptr = _header.Decode(pkt, ptr, pktlen);
      for (uint16_t i = 0; i < _header.QuestionCount(); ++i) {
        MessageQuestion  question;
        ptr = question.Decode(pkt, ptr, pktlen);
        _questions.push_back(question);
      }
      for (uint16_t i = 0; i < _header.AnswerCount(); ++i) {
        ResourceRecord  answer;
        ptr = answer.Decode(pkt, ptr, pktlen);
        _answers.push_back(answer);
      }
      for (uint16_t i = 0; i < _header.AuthorityCount(); ++i) {
        ResourceRecord  authority;
        ptr = authority.Decode(pkt, ptr, pktlen);
        _authorities.push_back(authority);
      }
      for (uint16_t i = 0; i < _header.AdditionalCount(); ++i) {
        ResourceRecord  additional;
        ptr = additional.Decode(pkt, ptr, pktlen);
        _additional.push_back(additional);
      }
      return ptr;
    }

    //------------------------------------------------------------------------
    //!  Encodes the message into a 4096-byte local buffer and sends it with
    //!  sendto(2), passing @c sockFd, @c flags, @c to and @c tolen through.
    //!  Returns sendto()'s result, or -1 if encoding threw.
    //------------------------------------------------------------------------
    ssize_t Message::SendTo(int sockFd, int flags, const struct sockaddr *to,
                            int tolen) const
    {
      uint8_t  buf[4096];
      try {
        uint8_t  *end = Encode(buf, buf, sizeof(buf));
        return sendto(sockFd, buf, (end - buf), flags, to, tolen);
      }
      catch (...) {
        //  Any encoding failure is reported like a send failure.
      }
      return -1;
    }

    //------------------------------------------------------------------------
    //!  Writes the message to @c fd as a length-prefixed PDU: a 2-byte
    //!  message length in network byte order, followed by the encoded
    //!  message (at most 4094 bytes).  Short writes are retried until the
    //!  whole PDU is written.  Returns the number of bytes written
    //!  (2 + message length), or -1 if @c fd is negative, encoding throws,
    //!  or write(2) fails or returns 0 before the PDU is complete.
    //------------------------------------------------------------------------
    ssize_t Message::Write(int fd) const
    {
      ssize_t  rc = -1;
      if (fd >= 0) {
        uint8_t  pkt[4096];
        try {
          //  The DNS message itself starts after the 2-byte length prefix.
          uint8_t   *msg = pkt + 2;
          uint8_t   *end = Encode(msg, msg, sizeof(pkt) - 2);
          uint16_t   msglen = static_cast<uint16_t>(end - msg);
          //  Store the length big-endian a byte at a time; storing through
          //  a uint16_t* into a uint8_t array is an aliasing/alignment
          //  violation.
          pkt[0] = static_cast<uint8_t>(msglen >> 8);
          pkt[1] = static_cast<uint8_t>(msglen & 0xff);
          size_t   pdulen = static_cast<size_t>(msglen) + 2;
          size_t   written = 0;
          while (written < pdulen) {
            ssize_t  w = write(fd, pkt + written, pdulen - written);
            if (w < 1) {
              break;
            }
            written += static_cast<size_t>(w);
          }
          if (written == pdulen) {
            rc = static_cast<ssize_t>(pdulen);
          }
        }
        catch (...) {
          rc = -1;
        }
      }
      return rc;
    }
    

    //------------------------------------------------------------------------
    //!  Receives one datagram from @c sockFd with recvfrom(2) into a
    //!  4096-byte local buffer and decodes it into this message.  @c flags,
    //!  @c from and @c fromlen are passed straight to recvfrom().  Returns
    //!  the number of bytes received, or -1 if recvfrom() failed or the
    //!  datagram could not be decoded.  When recvfrom() fails, Decode() is
    //!  not called.
    //------------------------------------------------------------------------
    ssize_t Message::RecvFrom(int sockFd, int flags, sockaddr *from,
                              socklen_t *fromlen)
    {
      uint8_t  pkt[4096];
      ssize_t  rc = recvfrom(sockFd, pkt, sizeof(pkt), flags, from, fromlen);
      if (rc < 0) {
        //  Don't decode: converting -1 to uint16_t would tell Decode()
        //  the buffer holds 65535 bytes.
        return -1;
      }
      try {
        Decode(pkt, pkt, static_cast<uint16_t>(rc));
      }
      catch (...) {
        rc = -1;
      }
      return rc;
    }

    //------------------------------------------------------------------------
    //!  Reads one length-prefixed message from @c fd and decodes it into
    //!  this message.  The framing is a 2-byte length in network byte order
    //!  followed by that many message bytes, the same framing Write()
    //!  produces.  Short reads are retried until the full prefix and body
    //!  arrive.
    //!
    //!  Returns the total number of bytes consumed (2 + message length).
    //!  Returns -1 if @c fd is negative, if the stream ends or errors before
    //!  the full message is read, or if the length prefix is zero.
    //!  Exceptions thrown by Decode() (e.g. std::out_of_range on a malformed
    //!  message) propagate to the caller; the read buffer is still released.
    //------------------------------------------------------------------------
    ssize_t Message::Read(int fd)
    {
      if (fd < 0) {
        return -1;
      }

      //  Reads up to len bytes into buf, retrying short reads.  Returns the
      //  number of bytes read, which is less than len on EOF or error.
      auto readFully = [fd](uint8_t *buf, size_t len) -> size_t {
        size_t  got = 0;
        while (got < len) {
          ssize_t  b = read(fd, buf + got, len - got);
          if (b < 1) {
            break;
          }
          got += static_cast<size_t>(b);
        }
        return got;
      };

      uint8_t  lenbuf[2];
      if (readFully(lenbuf, sizeof(lenbuf)) != sizeof(lenbuf)) {
        return -1;
      }
      uint16_t  msglen = static_cast<uint16_t>((lenbuf[0] << 8) | lenbuf[1]);
      if (msglen == 0) {
        //  An empty message can't hold a DNS header.
        return -1;
      }

      //  RAII buffer: released even if Decode() throws.
      std::vector<uint8_t>  buf(msglen);
      if (readFully(buf.data(), buf.size()) != buf.size()) {
        return -1;
      }
      Decode(buf.data(), buf.data(), msglen);
      return static_cast<ssize_t>(sizeof(msglen) + msglen);
    }

    //------------------------------------------------------------------------
    //!  
    //------------------------------------------------------------------------
    void Message::EnableEDNS(uint16_t pduSize, bool doDnssec)
    {
      bool  enabled = false;
      for (auto & arr : _additional) {
        if (arr.Type() == ResourceRecord::k_typeOPT) {
          enabled = true;
          arr.Class(pduSize);
          if (doDnssec) {
            arr.TTL(arr.TTL() | 0x8000);
          }
          else {
            arr.TTL(arr.TTL() & (~0x8000));
          }
          break;
        }
      }
      if (! enabled) {
        ResourceRecord  optrr;
        optrr.Type(ResourceRecord::k_typeOPT);
        optrr.Class(pduSize);
        if (doDnssec) {
          optrr.TTL(0x8000);
        }
        else {
          optrr.TTL(optrr.TTL() & (~0x8000));
        }
        _additional.push_back(optrr);
      }
    }

    
  }  // namespace Dns

}  // namespace Dwm
