#!/usr/bin/env python3
# encoding=UTF-8

# Copyright © 2011-2022 Jakub Wilk <jwilk@jwilk.net>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the “Software”), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import argparse
import collections
import os
import re
import socket
import sys

# Interpreter version guard. Nothing visible in this module calls the
# coroutine below; it appears to exist only so that the file fails to compile
# on interpreters that cannot parse this syntax (an "await" inside an f-string
# in an "async def"). Per the embedded message that means Python >= 3.7; an
# older interpreter is expected to stop with a SyntaxError that echoes the
# line, message included.
# pylint: disable=multiple-statements
async def _(): return f'{await "# Python >= 3.7 is required #"}'
# pylint: enable=multiple-statements

# Program version; printed by the --version action.
__version__ = '0.2.1'

# Text handed to argparse.ArgumentParser as the program description.
description = \
'''
DNS lookup with Graphviz output
'''

# Opening of the emitted DOT document: a left-to-right digraph with default
# edge and node attributes. Mapping.print() strips the surrounding blank
# lines before printing it.
template_begin = '''
digraph {
    rankdir=LR
    edge [arrowsize=0.5, arrowhead="vee"]
    node [fontsize=10, width=0, height=0, shape=box]
'''

# Closing brace of the DOT document; also stripped before printing.
template_end = '''
}
'''

def _html_escape(match):
    '''
    re.sub() callback: render the matched character as a decimal
    HTML character reference, e.g. '"' becomes '&#34;'
    '''
    return f'&#{ord(match[0])};'

def dot_escape(s):
    '''
    return s wrapped in double quotes for use as a DOT identifier

    Backslashes are doubled; "&" and '"' are replaced by numeric HTML
    character references.
    '''
    # The DOT language reference
    # (https://www.graphviz.org/doc/info/lang.html) states:
    #
    #    In quoted strings […] the only escaped character is double-quote
    #    ("). That is, in quoted strings, the dyad \" is converted to "; all
    #    other characters are left unchanged. In particular, \\ remains \\.
    #
    # That is neither very sensible nor how dot(1) behaves in practice,
    # and doubling the backslashes can't hurt much, so do it anyway.
    doubled = s.replace('\\', '\\\\')
    escaped = re.sub('[&"]', _html_escape, doubled)
    return '"' + escaped + '"'

class Mapping():
    '''
    directed graph of DNS relations between addresses

    Starting from the initial addresses, every address is looked up in both
    directions and each newly seen name or IP is queued for its own lookup,
    until nothing is left to resolve.

    Attributes:
        initial: frozenset of the addresses the graph was started from
        pending: set of addresses still waiting to be looked up
        done:    set of addresses that need no (further) lookup
        mapping: dict of source address -> set of target addresses
    '''

    def __init__(self, pending=()):
        # Materialise the iterable once, so that one-shot iterables
        # (e.g. generators) populate both sets.
        self.initial = frozenset(pending)
        self.pending = set(self.initial)
        self.done = set()
        self.mapping = collections.defaultdict(set)

    def add(self, source, target):
        '''
        record the edge source -> target

        The source is marked as done. Self-references are not recorded.
        A target that has not been handled yet is queued for lookup.
        '''
        self.done.add(source)
        if source == target:
            return
        if target not in self.done:
            self.pending.add(target)
        self.mapping[source].add(target)

    def update_all(self):
        '''
        process pending addresses until none are left
        '''
        while self.pending:
            self.update()

    def update(self):
        '''
        take one pending address and look it up in both directions

        Resolution failures are ignored: the address then simply
        contributes no edges. Names that cannot be IDNA-encoded (empty or
        over-long labels) make the socket functions raise UnicodeError
        before any lookup happens; they are treated as unresolvable too.

        Raises KeyError if there is no pending address.
        '''
        address = self.pending.pop()
        # Mark the address as handled right away, so that it is never queued
        # again: neither by its own lookup results, nor after a failed
        # resolution.
        self.done.add(address)
        try:
            host, _, ips = socket.gethostbyaddr(address)
        except (socket.gaierror, socket.herror, UnicodeError):
            pass
        else:
            for ip in ips:
                self.add(ip, host)
        try:
            infos = socket.getaddrinfo(address, 0)
        except (socket.gaierror, UnicodeError):
            return
        for *_, sockaddr in infos:
            # The first item of the socket address is the IP address.
            self.add(address, sockaddr[0])

    def print(self):
        '''
        write the graph to stdout in DOT format

        Initial addresses that have outgoing edges are drawn in bold.
        Nodes and edges are emitted in sorted order, so that the output
        is reproducible between runs.
        '''
        print(template_begin.strip())
        x = dot_escape
        for source in sorted(self.initial):
            if source in self.mapping:
                print(f'    {x(source)} [style=bold]')
        for source, targets in sorted(self.mapping.items()):
            for target in sorted(targets):
                print(f'    {x(source)} -> {x(target)}')
        print(template_end.strip())

class VersionAction(argparse.Action):
    '''
    argparse --version action

    Prints the program version, the Python version and, where it can be
    determined, the C library version; then exits via parser.exit().
    '''

    def __init__(self, option_strings, dest=argparse.SUPPRESS):
        super().__init__(
            option_strings=option_strings,
            dest=dest,
            nargs=0,
            help='show version information and exit'
        )

    @staticmethod
    def _libc_version():
        '''
        return the CS_GNU_LIBC_VERSION configuration string,
        or None if it cannot be determined
        '''
        # os.confstr() is only available on Unix.
        confstr = getattr(os, 'confstr', None)
        if confstr is None:
            return None
        try:
            # May itself return None when the value is not defined.
            return confstr('CS_GNU_LIBC_VERSION')
        except (ValueError, OSError):
            # Unknown configuration name, or the C library rejected it.
            return None

    def __call__(self, parser, namespace, values, option_string=None):
        print(f'{parser.prog} {__version__}')
        print('+ Python {0}.{1}.{2}'.format(*sys.version_info))  # pylint: disable=consider-using-f-string
        libc = self._libc_version()
        if libc is not None:
            print(f'+ {libc}')
        parser.exit()

def main():
    '''
    command-line entry point: resolve the addresses given as arguments
    and print the resulting graph to stdout
    '''
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--version', action=VersionAction)
    parser.add_argument(
        'addresses',
        metavar='ADDRESS',
        nargs='+',
        help='IP address or domain name',
    )
    args = parser.parse_args()
    graph = Mapping(pending=args.addresses)
    graph.update_all()
    # Switch stdout to UTF-8 before the graph is written.
    sys.stdout.reconfigure(encoding='UTF-8')
    graph.print()

# Run the command-line interface only when this file is executed as
# a script, not when it is imported as a module.
if __name__ == '__main__':
    main()

# vim:ts=4 sts=4 sw=4 et
