How to use a specific DNS server/nameserver for name resolution queries in C++ boost::asio?

536 Views Asked by At

I would like to bypass the system-configured nameserver and use my own nameserver (or a list of them) configured in the application. I can do this with nslookup on Windows. How can I do it in C++, preferably using boost::asio? I would like to avoid calling std::system("nslookup ... > output.txt") and reading the output file. I cannot see where to specify the nameserver to use for lookups in boost::asio.

#include <boost/asio.hpp>
#include <string>
#include <iostream>
using namespace boost;
int main()
{

asio::io_service io_service;
asio::ip::tcp::resolver resolver(io_service);//how to pass specific nameserver?

asio::ip::tcp::resolver::iterator itr = resolver.resolve("bbc.co.uk","");
asio::ip::tcp::resolver::iterator end;

for (int i = 1; itr != end; itr++, i++)
  std::cout << "hostname #" << i << ": " << itr->host_name() << "  " << itr->endpoint() << '\n';
   return 0;
}
3

There are 3 best solutions below

0
Shane Powell On

You can't. DNS resolving is handled by the sockets API, and it does not let you specify the DNS servers.

You will have to use either an OS-specific API to resolve names directly, such as DnsQueryEx on Win32, or a library such as LDNS.

0
Marian K. On
/*
* DNSResolver.h
*
*  Created on: Jan 12, 2023
*      Author: marian
*/

#ifndef DNSRESOLVER_H_
#define DNSRESOLVER_H_


#include <cstdint>
#include <cstddef>
#include <vector>
#include <string_view>
#include <string>
#include <ostream>
#include <iostream>
#include <sstream>
#include <random>
#include <utility>
#include <thread>
#include <chrono>
#include <boost/asio.hpp>

namespace tools
{
/// Calls `func` on the items of the `delim`-separated list `str`, starting at
/// offset `pos0`, until one is accepted (func returns true).
///
/// Each item is trimmed of surrounding spaces/tabs, and empty items are skipped,
/// so "a, b,,c" is visited as "a", "b", "c". Because empty items are never
/// offered, an empty result unambiguously means "nothing accepted".
///
/// @param str   the list; the returned view points into it.
/// @param func  predicate invoked as func(std::string_view).
/// @param pos0  offset where scanning starts; past-the-end yields an empty result.
/// @param delim item separator.
/// @return the first accepted item, or an empty view if none was accepted.
template <typename Func>
std::string_view  find_first_success(const std::string_view& str,const Func& func, size_t pos0=0, char delim=',')
{
    constexpr std::string_view blanks=" \t";
    while (pos0<=str.size())   // original threw std::out_of_range for pos0 > size()
    {
        const size_t pos=str.find(delim,pos0);
        std::string_view item=(pos==std::string_view::npos) ? str.substr(pos0) : str.substr(pos0,pos-pos0);

        const size_t first=item.find_first_not_of(blanks);
        if (first!=std::string_view::npos)
        {
            item=item.substr(first,item.find_last_not_of(blanks)-first+1);
            if (func(item))
                return item;
        }
        if (pos==std::string_view::npos)
            break;
        pos0=pos+1;
    }
    return {};
}

/// Reads a big-endian (network order) 16-bit value from hi[0], hi[1].
template<int unused=0>
uint16_t be2uint16(const uint8_t* hi)
{
    return static_cast<uint16_t>((static_cast<uint16_t>(hi[0])<<8)|hi[1]);
}

/// Reads a big-endian (network order) 32-bit value from hi[0..3].
template<int unused=0>
uint32_t be2uint32(const uint8_t* hi)
{
    return (static_cast<uint32_t>(hi[0])<<24)
         | (static_cast<uint32_t>(hi[1])<<16)
         | (static_cast<uint32_t>(hi[2])<<8)
         |  static_cast<uint32_t>(hi[3]);
}

/// Prints an IPv4 address held with its first octet in the high byte
/// (as produced by be2uint32) in dotted-decimal form.
template<int unused=0>
void out_IP4(std::ostream& os,uint32_t IP4)
{
    os << (IP4>>24) << '.' << ((IP4>>16)&0xFF) << '.' << ((IP4>>8)&0xFF) << '.' << (IP4&0xFF);
}
}
namespace DNS
{
const char *rcode_msg[]={"No error","Format error","Server failure","Name error","Not implemented","Refused"};
const char *opcode_msg[]={"Standard query","Inverse query","Status"};
const char *type_msg[]={"0","A","NS","MD","MF","CNAME","SOA","MB","MG","MR","NULL","WKS","PTR","HINFO","MINFO","MX","TX"};

/*
A               1 a host address
NS              2 an authoritative name server
MD              3 a mail destination (Obsolete - use MX)
MF              4 a mail forwarder (Obsolete - use MX)
CNAME           5 the canonical name for an alias
SOA             6 marks the start of a zone of authority
MB              7 a mailbox domain name (EXPERIMENTAL)
MG              8 a mail group member (EXPERIMENTAL)
MR              9 a mail rename domain name (EXPERIMENTAL)
NULL            10 a null RR (EXPERIMENTAL)
WKS             11 a well known service description
PTR             12 a domain name pointer
HINFO           13 host information
MINFO           14 mailbox or mail list information
MX              15 mail exchange
TXT             16 text strings
*/
template<int unused=0>
void disp_type(std::ostream&os,uint16_t type)
{
    os << (int)type;
    if (type<=16)
        os << " " << type_msg[type];
}


typedef boost::asio::detail::socket_option::integer<SOL_SOCKET, SO_RCVTIMEO> rcv_timeout_option; //somewhere in your headers to be used everywhere you need it


template<int unused=0>
class resolver
{
public:
    resolver(const std::string_view& nameservers="127.0.0.53",boost::asio::chrono::milliseconds timeout=boost::asio::chrono::milliseconds{3000})
    :nameservers_{nameservers},timeout_(timeout)
    {}

    std::vector<uint32_t> resolve(
                const std::string_view& hostname,
                std::string_view& nameserver_used,
                std::vector<uint32_t>* pTTLs=nullptr)
    {

        std::vector<uint8_t> result;
        result.resize(1024);
        nameserver_used=resolve(hostname,result);
        if (nameserver_used.empty())
            return {};

        std::string qname;
        std::vector<uint32_t> IP4s;
        int cnt=answer_A(result,qname,IP4s,pTTLs);
        (void)cnt;
        if (hostname!=qname)
            return {};
        return IP4s;
    }

    void resolve_ostream(
                const std::string_view& hostname,
                std::string_view& nameserver_used,
                std::ostream& os)
    {
        std::vector<uint8_t> result;
        result.resize(1024);
        nameserver_used=resolve(hostname,result);
        if (nameserver_used.empty())
            return;
        os << "nameserver:" << nameserver_used << std::endl;
        disp_result(os,result);
    }

private:
    static void out_name(std::ostream&os,const std::vector<uint8_t>& result,uint16_t& pos)
    {
        uint8_t sz=result[pos++];
        for (;sz;)
        {
            if (sz>=64)
            {
                if ((sz>>6)!=3)
                throw "label size exceed 64 character and not pointer, see RFC 1035";
                uint16_t offset=(((uint16_t)(sz&63))<<8)|result[pos++];
                out_name(os,result,offset);//recursion, hope not too deep
                return;
            }
            uint16_t epos=pos+sz;
            while (pos<epos)
                os << result[pos++];
            sz=result[pos++];
            if (sz)
            os << ".";
        }
    }



    static void disp_rr(std::ostream&os,const std::vector<uint8_t>& result,uint16_t& pos)
    {
        if (pos>=result.size())
            return;
        if (result[pos]==0)
            return;
        out_name(os,result,pos);
        if (((size_t)pos+10)>=(size_t)result.size())
        {
            throw "wrong packet size1";
        }
        uint16_t rr_type=tools::be2uint16(&result[pos]);
        os << "  (" ;
        os << "type:" ; disp_type(os,rr_type) ;
        pos+=2;
        os << "  class:" << tools::be2uint16(&result[pos]) ;
        pos+=2;
        os << "  TTL:" << tools::be2uint32(&result[pos]) << " secs)  ";
        pos+=4;
        uint16_t len=tools::be2uint16(&result[pos]);
        pos+=2;

        if (pos+len>result.size())
            throw "wrong packet size2";

        if (rr_type==1)
        {
            if (len!=4)
            {
                throw "unknown data size";
            }
            os << (int)result[pos] << "." << (int)result[pos+1] << "." << (int)result[pos+2] << "." << (int)result[pos+3];
        } else if (rr_type==2)
        {
            uint16_t pos2=pos;
            out_name(os,result,pos2);
        } else
        {
            throw "unimplement type";
        }
        pos+=len;

    }

    static void split_rr(const std::vector<uint8_t>& result,
                uint16_t& pos,
                std::string& name,
                uint16_t& rr_type,
                uint16_t& rr_class,
                uint32_t& TTL,
                uint32_t& IP4,
                std::string* pname2=nullptr
                )
    {
        if (pos>=result.size())
            return;
        if (result[pos]==0)
            return;
        std::stringstream ss;
        out_name(ss,result,pos);
        name=ss.str();
        if (((size_t)pos+10)>=(size_t)result.size())
        {
            throw "wrong packet size1";
        }
        rr_type=tools::be2uint16(&result[pos]);
        pos+=2;
        rr_class=tools::be2uint16(&result[pos]) ;
        pos+=2;
        TTL=tools::be2uint32(&result[pos]);
        pos+=4;
        uint16_t len=tools::be2uint16(&result[pos]);
        pos+=2;

        if (pos+len>result.size())
            throw "wrong packet size2";

        if (rr_type==1)
        {
            if (len!=4)
            {
                throw "unknown data size";
            }
            IP4=tools::be2uint32(&result[pos]); //*(uint32_t*)&result[pos];
        } else if (rr_type==2)
        {
            uint16_t pos2=pos;
            std::stringstream ss;
            out_name(ss,result,pos2);
            if (pname2)
            *pname2=ss.str();
        } else
        {
            throw "unimplement type";
        }
        pos+=len;

    }

    static int check_A(const std::vector<uint8_t>& result)
    {
        if (result.size()<12)
            return -1;

        uint8_t response=(result[2]>>7);
        if (response!=1)
            return -2;

        uint8_t rcode=result[3]&0xf;
        if (rcode!=0)
            return -3;

        uint8_t opcode=(result[2]>>3)&0xF;
        if (opcode!=0)
            return -4;

        return 0;
    }

    static int answer_A(const std::vector<uint8_t>& result,
                std::string& q_name,
                std::vector<uint32_t>& IP4s,
                std::vector<uint32_t>* pTTLs=nullptr)
    {
        int ret=check_A(result);
        if (ret<0)
            return ret;

        std::string name;
        uint16_t q_type=0;
        uint16_t q_class=0;
        uint16_t rr_type=0;
        uint16_t rr_class=0;
        uint32_t TTL=0;
        uint32_t IP4=0;
        uint16_t qc=tools::be2uint16(&result[4]);
        uint16_t pos=12;
        for (int i=0;i<qc;i++)//qc =1 usually
        {
            std::stringstream ss;
            out_name(ss,result,pos);
            q_name=ss.str();
            q_type=tools::be2uint16(&result[pos]);
            pos+=2;
            q_class=tools::be2uint16(&result[pos]);
            pos+=2;
        }
        uint16_t ac=tools::be2uint16(&result[6]);
        IP4s.resize(ac);
        if (pTTLs)
        pTTLs->resize(ac);
        for (int i=0;i<ac;i++)
        {
        split_rr(result,pos,name,rr_type,rr_class,TTL,IP4);
        if (name!=q_name)
            break;
        if (rr_type!=q_type)
            break;
        if (rr_class!=q_class)
            break;
        IP4s[i]=IP4;
        if (pTTLs)
            (*pTTLs)[i]=TTL;
        }
        return ac;
    }

    static void disp_result(std::ostream&os,const std::vector<uint8_t>& result)
    {
        if (result.size()<12)
            return;

        os << std::string((result[2]&0x80) ? "response" : "query") << std::endl;
        os << "size=" << result.size() << " bytes" << std::endl;
        uint8_t opcode=(result[2]>>3)&0xF;
        os << "Opcode:" << (int)opcode;
        if (opcode<=2)
            os << "  " << opcode_msg[opcode];
        os << std::endl;

        os << ((result[2]&1) ? "recursion asked" : "recursion NOT asked") << std::endl;
        if (result[2]&2)
        os << "response truncated" << std::endl;
        if (result[2]&4)
            os << "Authoritative answer" << std::endl;
        os << ((result[3]&0x80) ? "recursion available" : "recursion NOT available") << std::endl;
        uint8_t rcode=result[3]&0xf;
        os << "rcode:" << (int)rcode;
        if (rcode<=5)
            os << "  " << rcode_msg[rcode];
        os << std::endl;

        uint16_t qc=tools::be2uint16(&result[4]);
        os << "Query Count:" << qc << std::endl;
        uint16_t ac=tools::be2uint16(&result[6]);
        os << "Answer Count:" << ac << std::endl;
        uint16_t nc=tools::be2uint16(&result[8]);
        os << "Authoritative Name Server Count:" << nc << std::endl;
        uint16_t arc=tools::be2uint16(&result[10]);

        os << "Additional resource records Count:" << arc << std::endl;

        uint16_t pos=12;
        os << "_____________" << std::endl;
        os << "Query:" << std::endl;
        for (int i=0;i<qc;i++)//qc =1 usually
        {
            out_name(os,result,pos);
            os << "  (";
            os << "type:"; disp_type(os,tools::be2uint16(&result[pos]));
            pos+=2;
            os << " class:" << tools::be2uint16(&result[pos]) << " ) ";
            pos+=2;
            os << std::endl;
        }

        //answer
        if (ac)
        {
            os << "_________________________" << std::endl;
            os << "Answer:" << std::endl;
        }
        for (int i=0;i<ac;i++)
        {
            disp_rr(os,result,pos);
            os << std::endl;
        }
        if (nc)
        {
            os << "_________________________" << std::endl;
            os << "Authoritative nameservers:" << std::endl;
        }
        for (int i=0;i<nc;i++)
        {
            disp_rr(os,result,pos);
            os << std::endl;
        }
        if (arc)
        {
            os << "_________________________" << std::endl;
            os << "Additional resource records:" << std::endl;
        }
        for (int i=0;i<arc;i++)
        {
            disp_rr(os,result,pos);
            os << std::endl;
        }

    }


    size_t udp_request(const boost::asio::const_buffer& request,
                    const boost::asio::mutable_buffer& response,
                    const std::string_view& destination_ip,
                    const unsigned short port,
                    boost::system::error_code& ec)
    {
        using namespace boost;
        asio::ip::udp::socket socket(io_context_);
        auto remote = asio::ip::udp::endpoint(asio::ip::make_address(destination_ip), port);
        socket.open(boost::asio::ip::udp::v4());
        size_t sent=socket.send_to(request, remote);
        if (request.size()!=sent)
            return 0;

        return receive_from(socket,response,timeout_,ec);
    }

    std::size_t receive_from(
        boost::asio::ip::udp::socket& sock,
        const boost::asio::mutable_buffer& buffer,
        boost::asio::chrono::steady_clock::duration timeout,
        boost::system::error_code& ec)
    {
        std::size_t length = 0;
        sock.async_receive(boost::asio::buffer(buffer),
                [&](const boost::system::error_code& ec1,std::size_t sz){
                    ec=ec1;
                    length=sz;
                }
        );

        run(sock,timeout);

        return length;
    }

    void run(boost::asio::ip::udp::socket& sock,boost::asio::chrono::steady_clock::duration timeout)
    {
        // Restart the io_context, as it may have been left in the "stopped" state
        // by a previous operation.
        io_context_.restart();

        // Block until the asynchronous operation has completed, or timed out. If
        // the pending asynchronous operation is a composed operation, the deadline
        // applies to the entire operation, rather than individual operations on
        // the socket.
        io_context_.run_for(timeout);

        // If the asynchronous operation completed successfully then the io_context
        // would have been stopped due to running out of work. If it was not
        // stopped, then the io_context::run_for call must have timed out.
        if (!io_context_.stopped())
        {
        // Cancel the outstanding asynchronous operation.
        sock.cancel();

        // Run the io_context again until the operation completes.
        io_context_.run();
        }
    }

    static std::vector<uint8_t> make_dns_request(const std::string_view& hostname)
    {
        static uint16_t id=257;
        id++;
        std::vector<uint8_t> req;
        uint16_t epos=12+1+hostname.size();
        req.resize(epos+1+4);//plus null plus type (2 bytes) plus class (2 bytes)
        *((uint16_t*)&req[0])=id;
        req[2]=1;//recursive
        req[3]=32;//??, linux has 32, but 0 works too
        req[5]=1;//one query
        req[epos]=0;//end of string
        req[epos+2]=1;//type A - host address
        req[epos+4]=1;//class INT
        uint16_t cnt=0;
        for(int i=hostname.size()-1;i>=0;i--)
        {
            req[13+i]=(hostname[i]=='.') ? cnt : hostname[i];
            cnt=(hostname[i]=='.') ? 0 : cnt+1;
        }
        req[12]=cnt;
        return req;
    }


    std::string_view  resolve(const std::string_view& hostname,std::vector<uint8_t>& result)
    {
        auto myrequest=make_dns_request(hostname);

        return tools::find_first_success(nameservers_,
                [&](const std::string_view& nameserver)
                {
                    size_t sz=0;
                    try
                    {
                        boost::system::error_code ec;
                        sz=udp_request(boost::asio::buffer(myrequest),
                                boost::asio::mutable_buffer(&result[0],result.size()),
                                nameserver,53,ec);
                        if (ec.value()!=boost::system::errc::success)
                        {
                            std::cerr << "skipping due to error code not a success : error code=" << ec.value() << "  " << ec.message() << std::endl;
                            return false;
                        }
                        if (sz<35)//min packet , 12 header+3 name+4 type/class+ 16 (rr type 1)
                        {
                            std::cerr << "skipping because packet size < 35" << std::endl;
                            return false;
                        }
                        if ((myrequest[0]!=result[0])||(myrequest[1]!=result[1]))
                        {
                            std::cerr << "skipping because ID respond not matching ID request " << std::endl;
                            return false;
                        }
                        if (check_A(result)<0)
                        {
                            std::cerr << "skipping because packet check failed " << std::endl;
                            return false;
                        }
                    }
                    catch(std::exception& ex)
                    {
                        std::cerr << "Skipping because of exception:" << ex.what() << std::endl;
                        return false;
                    }
                    catch(...)
                    {
                        std::cerr << "Skipping because of unknown exception:" << std::endl;
                        return false;
                    }
                    result.resize(sz);
                    return true;

                });
    }



private:
    boost::asio::io_context io_context_;
    std::string nameservers_;
    boost::asio::chrono::milliseconds timeout_;
};

}

#endif /* DNSRESOLVER_H_ */
0
Marian K. On
/*
* main.cpp
*
*  Created on: Jan 12, 2023
*      Author: marian
*/
#include <iostream>
#include <vector>
#include "DNSResolver.h"
/// Command-line front end for DNS::resolver.
/// Usage: prog [-v] [-t timeout_msecs] [--nottl] [--nons] <query> [csv_ip_nameserver_list]
/// Exit codes: 0 success, 1 usage error, 2 no IPv4 address found,
///             3 the resolver threw (e.g. a malformed response).
int main (int argc, char **argv) {
    auto print_error=[&]()
        {
        std::cerr << "Usage: " << argv[0] << " [-v] [-t timeout_msecs] [--nottl] [--nons] <query> [csv_ip_nameserver_list] \n";
        std::cerr << "Example: " << argv[0] << " bbc.co.uk 127.0.0.53,8.8.8.8" << std::endl;
        std::cerr << "Example: " << argv[0] << " -v bbc.co.uk" << std::endl;
        };
    if (argc < 2) {
        print_error();
        return 1;
    }
    int offset=1;
    // Replaces the former check_offset macro: prints the usage and returns true
    // when the options have consumed every argument, i.e. <query> is missing.
    auto query_missing=[&]()
        {
        if (offset<argc)
            return false;
        print_error();
        return true;
        };

    const bool verbose= (std::string(argv[offset])=="-v");
    if (verbose)
        offset++;
    if (query_missing())
        return 1;

    int millisecs=3000;
    if (std::string(argv[offset])=="-t")
    {
        offset++;
        if (query_missing())
            return 1;
        try
        {
            millisecs=std::stoi(argv[offset++]);
        }
        catch (const std::exception&)   // std::invalid_argument / std::out_of_range from stoi
        {
            print_error();
            return 1;
        }
        if (millisecs<=0)
        {
            print_error();
            return 1;
        }
    }
    if (query_missing())
        return 1;

    bool ttl=true;
    if (std::string(argv[offset])=="--nottl")
    {
        ttl=false;
        offset++;
    }
    if (query_missing())
        return 1;

    bool ns=true;
    if (std::string(argv[offset])=="--nons")
    {
        ns=false;
        offset++;
    }
    if (query_missing())
        return 1;

    const char *dname=argv[offset];
    const char *nameservers= (offset+1<argc) ? argv[offset+1] : "127.0.0.53,8.8.8.8";

    if (ns)
        std::cout << "nameservers:" << nameservers << std::endl;

    DNS::resolver resolver{nameservers,boost::asio::chrono::milliseconds{millisecs}};

    try
    {
        if (verbose)
        {
            std::string_view nameserver_used;
            resolver.resolve_ostream(dname,nameserver_used,std::cout);
            return 0;
        }

        std::vector<uint32_t> TTLs;
        std::string_view nameserver_used;
        const std::vector<uint32_t> IP4s=resolver.resolve(dname,nameserver_used,&TTLs);
        if (ns&&(!nameserver_used.empty()))
            std::cout << "nameserver:" << nameserver_used << std::endl;
        if (IP4s.empty())
        {
            std::cerr << "Error: No IPs found" << std::endl;
            return 2;
        }

        for (std::size_t i=0;i<IP4s.size();i++)
        {
            tools::out_IP4(std::cout,IP4s[i]);
            if (ttl && i<TTLs.size())
                std::cout << "  (TTL:" << TTLs[i] << " sec)" ;
            std::cout << std::endl;
        }
    }
    catch (const char* msg)   // DNSResolver.h reports malformed packets by throwing string literals
    {
        std::cerr << "Error: " << msg << std::endl;
        return 3;
    }
    catch (const std::exception& ex)
    {
        std::cerr << "Error: " << ex.what() << std::endl;
        return 3;
    }

    return 0;

}