#!/usr/bin/env python3
# 20250423

# fakedns is a fake DNS server for testing purposes.
# It answers every query, over UDP or TCP, with an A record for 127.0.0.1.

import socket
import logging
import argparse
import select
import struct
import sys
from dnslib import *

class DNSServer:
    '''
    Fake DNS server that answers every query with a single A record.

    One UDP socket and one TCP listening socket are bound to the same
    IP/port. Requests are served one at a time from a single select() loop.
    '''

    # Address returned in the A record of every reply.
    ANSWER_IP = '127.0.0.1'

    # Seconds to wait on an accepted TCP client. Without a timeout, a client
    # that connects but never sends would stall the whole single-threaded
    # server.
    TCP_TIMEOUT = 5.0

    def __init__(self, ip, port):
        '''
        Bind the UDP socket and bind/listen the TCP socket on (ip, port).

        :param ip: IPv4 address to bind to (sockets are AF_INET).
        :param port: port number shared by the UDP and TCP sockets.
        :raises OSError: if either socket cannot be created or bound.
        '''

        # UDP
        self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.udp.bind((ip, port))
        logging.debug('bind UDP %s:%s', ip, port)

        # TCP; SO_REUSEADDR allows rebinding while old connections sit in
        # TIME_WAIT after a restart.
        self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.tcp.bind((ip, port))
        self.tcp.listen()
        logging.debug('bind/listen TCP %s:%s', ip, port)

        logging.info('running DNSServer %s:%s', ip, port)

    def packet(self, data, flagtcp):
        '''
        Build the wire-format reply to one DNS query.

        The reply is created with dnslib's DNSRecord.reply() and carries one
        A record for ANSWER_IP (TTL 60), whatever name was asked for.

        :param data: raw DNS query message (without any TCP length prefix).
        :param flagtcp: True if the query arrived over TCP (currently unused).
        :return: packed response bytes.
        :raises Exception: presumably a dnslib parse error if data is not a
            valid DNS message (confirm against dnslib).
        '''
        query = DNSRecord.parse(data)
        qname = query.q.qname
        answer = query.reply()
        ip = self.ANSWER_IP
        answer.add_answer(RR(qname, QTYPE.A, rdata=A(ip), ttl=60))
        logging.info('query: %s, answer: %s', qname, ip)
        return answer.pack()

    def run(self):
        '''
        Serve queries until interrupted with Ctrl-C, then close the sockets.

        Each ready socket is handled on its own. A failing request (malformed
        query, client hangup, timeout) is logged and dropped. It does not stop
        the server or skip the other sockets that were ready.
        '''
        try:
            while True:
                readable, _, _ = select.select([self.udp, self.tcp], [], [])
                for sock in readable:
                    try:
                        if sock is self.udp:
                            self._handle_udp()
                        else:
                            self._handle_tcp()
                    except Exception as e:
                        # Best effort: drop this request and keep serving.
                        logging.error('request failed: %s', e)
        except KeyboardInterrupt:
            pass
        finally:
            self.close()

    def close(self):
        '''Close the UDP and TCP sockets.'''
        self.udp.close()
        self.tcp.close()

    def _handle_udp(self):
        '''Read one datagram from the UDP socket and send the reply back to its sender.'''
        data, addr = self.udp.recvfrom(4096)
        self.udp.sendto(self.packet(data, False), addr)

    def _handle_tcp(self):
        '''Accept one TCP client, answer a single query, then close the connection.'''
        conn, _ = self.tcp.accept()
        with conn:  # the connection is closed even if handling raises
            conn.settimeout(self.TCP_TIMEOUT)
            # DNS over TCP frames each message with a 2-byte big-endian
            # length (RFC 1035 section 4.2.2).
            (length,) = struct.unpack('!H', self._recv_exact(conn, 2))
            answer = self.packet(self._recv_exact(conn, length), True)
            # sendall's second argument is socket flags, not a length.
            conn.sendall(struct.pack('!H', len(answer)) + answer)

    @staticmethod
    def _recv_exact(conn, size):
        '''
        Read exactly size bytes from a stream socket.

        :param conn: connected stream socket.
        :param size: number of bytes to read (0 returns b'').
        :return: the bytes read.
        :raises ConnectionError: if the peer closes before size bytes arrive.
        '''
        chunks = []
        remaining = size
        while remaining:
            chunk = conn.recv(remaining)
            if not chunk:
                raise ConnectionError(
                    'peer closed after %d of %d bytes' % (size - remaining, size))
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)


if __name__ == '__main__':

    # command-line interface: optional -v (repeatable), then IP and PORT
    cli = argparse.ArgumentParser(description='fakedns.py')
    cli.add_argument('-v', '--verbose', action='count', default=0,
                     help='verbosity')
    cli.add_argument('IP', type=str, help='DNS IP address')
    cli.add_argument('PORT', type=int, help='DNS port')
    options = cli.parse_args()

    # map the -v count onto a log level; extra -v flags saturate at DEBUG
    LOG_LEVELS = ["INFO", "DEBUG"]
    options.verbose = min(options.verbose, len(LOG_LEVELS) - 1)
    logging.basicConfig(level=LOG_LEVELS[options.verbose])

    # start the fake DNS server; run() loops until Ctrl-C
    DNSServer(options.IP, options.PORT).run()
