#!/usr/bin/env python3
#
# Run a test.  Just the test spec is provided on stdin.
#

import ipaddress
import pscheduler
import random
import re

# Seed the generator used by the "random" scan mode.
random.seed()

log = pscheduler.Log(prefix='tool-nmapreach', quiet=True)

# NOTE: 'input' shadows the builtin, but the name is read again further
# down in this script, so it is kept as-is.
input = pscheduler.json_load(exit_on_error=True)

spec = input["test"]["spec"]

# The timeout arrives as an ISO 8601 duration; convert it to seconds.
timeout_iso = spec.get("timeout", "PT3S")
timeout_delta = pscheduler.iso8601_as_timedelta(timeout_iso)
timeout = pscheduler.timedelta_as_seconds(timeout_delta)


# ------------------------------------------------------------------------------

def __addr_yields(network, limit):
    """Work out how many addresses can be scanned and how many to yield.

    Arguments:
    network -- an ipaddress network object
    limit -- maximum number of addresses to yield, or None for no limit

    Returns a tuple of (addresses, yields) where addresses is the
    network's size minus two (the scan generators in this file never
    index the first or last address) and yields is that count capped
    by limit.  Neither value is ever negative, so networks with fewer
    than three addresses produce (0, 0).
    """
    # Clamp at zero; a negative count would be truthy in the callers'
    # loop conditions.
    addresses = max(network.num_addresses - 2, 0)

    if limit is None:
        yields = addresses
    else:
        yields = max(min(limit, addresses), 0)

    return (addresses, yields)


# Generator functions for the scan modes

def net_up(network, limit=None, exclude=None):
    """Yield addresses as strings, walking upward from network[1].

    Arguments:
    network -- an ipaddress network object
    limit -- maximum number of addresses to yield
    exclude -- an address to skip, as an address object or as a string
               in the form str() produces for an address; None to
               skip nothing

    The first and last addresses of the network are never yielded.
    A skipped address does not count against limit.
    """
    (addresses, yields) = __addr_yields(network, limit)

    # Compare in string form so the exclusion works whether the caller
    # passed an address object or its string representation.
    skip = None if exclude is None else str(exclude)

    address = 1
    while addresses and yields:
        candidate = str(network[address])
        address += 1
        addresses -= 1

        if candidate == skip:
            continue

        yield candidate
        yields -= 1


def net_down(network, limit=None, exclude=None):
    """Yield addresses as strings, walking downward from the
    next-to-last address of the network.

    Arguments:
    network -- an ipaddress network object
    limit -- maximum number of addresses to yield
    exclude -- an address to skip, as an address object or as a string
               in the form str() produces for an address; None to
               skip nothing

    The first and last addresses of the network are never yielded.
    A skipped address does not count against limit.
    """
    (addresses, yields) = __addr_yields(network, limit)

    # Compare in string form so the exclusion works whether the caller
    # passed an address object or its string representation.
    skip = None if exclude is None else str(exclude)

    address = network.num_addresses - 2
    while addresses and yields:
        candidate = str(network[address])
        address -= 1
        addresses -= 1

        if candidate == skip:
            continue

        yield candidate
        yields -= 1


def net_edges(network, limit=None, exclude=None):
    """Yield addresses as strings, alternating between the low and
    high ends of the network and working inward.

    Arguments:
    network -- an ipaddress network object
    limit -- maximum number of addresses to yield
    exclude -- an address to skip, as an address object or as a string
               in the form str() produces for an address; None to
               skip nothing

    The first and last addresses of the network are never yielded.
    A skipped address does not count against limit, and the next
    address is taken from the same end as the skipped one.
    """
    (addresses, yields) = __addr_yields(network, limit)

    # Compare in string form so the exclusion works whether the caller
    # passed an address object or its string representation.
    skip = None if exclude is None else str(exclude)

    lower_address = 1
    upper_address = -2
    use_lower = True

    while addresses and yields:
        addresses -= 1
        if use_lower:
            candidate = str(network[lower_address])
            lower_address += 1
        else:
            candidate = str(network[upper_address])
            upper_address -= 1

        if candidate == skip:
            continue

        yield candidate
        yields -= 1
        use_lower = not use_lower


def net_random(network, limit=None, exclude=None):
    """Yield distinct, randomly-chosen addresses as strings.

    Arguments:
    network -- an ipaddress network object
    limit -- maximum number of addresses to yield
    exclude -- an address to skip, as an address object or as a string
               in the form str() produces for an address; None to
               skip nothing

    The first and last addresses of the network are never yielded and
    no address is yielded more than once.  Fewer than limit addresses
    are produced if the network runs out of candidates first.
    """
    (addresses, yields) = __addr_yields(network, limit)

    # Compare in string form so the exclusion works whether the caller
    # passed an address object or its string representation.
    skip = None if exclude is None else str(exclude)

    # Indexes already drawn.  Tracking them prevents duplicates and
    # guarantees termination once every candidate has been tried.
    seen = set()

    while yields and len(seen) < addresses:
        # randrange() excludes its upper bound, hence the + 1 to make
        # the highest scannable index reachable.
        address = random.randrange(1, addresses + 1)
        if address in seen:
            continue
        seen.add(address)

        candidate = str(network[address])
        if candidate == skip:
            continue

        yield candidate
        yields -= 1


# Scan mode name (the spec's "scan" value) -> address generator function
generators = dict(
    up=net_up,
    down=net_down,
    edges=net_edges,
    random=net_random,
)

# ------------------------------------------------------------------------------

# Perform the test

network = ipaddress.ip_network(spec["network"])

# The first and last addresses of the network are not scanned.
network_scannable = network.num_addresses - 2
run_max_ips = min(spec.get("limit", network_scannable), network_scannable)
log.debug("Limiting scan to %d addresses", run_max_ips)

# Number of addresses handed to a single nmap invocation
parallel = spec.get("parallel", 100)


ip_list = []
generator = generators[spec.get("scan", "edges")]


# The gateway may be given as an integer index into the network or as
# an address string.  Either way it is held as a string, because that
# is what gets passed to nmap and what nmap's output is compared
# against, and it is placed first in the first batch so it is probed.
spec_gateway = spec.get("gateway", None)
if isinstance(spec_gateway, int):
    gateway = str(network[spec_gateway])
elif isinstance(spec_gateway, str):
    gateway = str(ipaddress.ip_address(spec_gateway))
else:
    gateway = None

if gateway is not None:
    ip_list.append(gateway)


# Arguments common to every nmap invocation; the addresses to probe
# are appended for each batch.
nmap_first_args = [ 'nmap', '-n', '-sP', '-oG', '-', '--host-timeout', str(timeout) ]

# The lines we want look like this:
# Host: 192.168.1.1 ()       Status: Up
nmap_output_matcher = re.compile(r'^Host:\s*([^\s]+)\s.*Status:\s*Up')

# Allow the nmap process half again as long as the per-host timeout.
nmap_timeout = timeout * 1.5



def run_nmap(ips, run_number, gateway):
    """Execute nmap and determine what was up.

    Arguments:
    ips -- list of address strings to probe
    run_number -- zero-based batch number; the gateway is only looked
                  for in batch 0
    gateway -- the gateway address (string or address object) or None

    Returns a tuple of (gateway_up, others_up, diags): whether the
    gateway was reported up in this batch, whether any non-gateway
    address was reported up, and a list of diagnostic strings.

    If nmap exits with a non-zero status, a failure result is emitted
    through pscheduler.succeed_json().
    """

    diags = []

    args = nmap_first_args + ips
    diags.append(' '.join(args))
    (status, out, err) = pscheduler.run_program(args, timeout=nmap_timeout)

    if status != 0:
        # Include the command line that was attempted so the failure
        # can be diagnosed.
        pscheduler.succeed_json({
            "succeeded": False,
            "diags": "\n".join(diags),
            "error": "Nmap failed: %s" % (err)
        })

    diags.append("")
    diags.append(out)

    ips_up = [
        match.group(1)
        for match in map(nmap_output_matcher.match, out.split("\n"))
        if match is not None
    ]

    log.debug("Found these IPs up: %s", ips_up)

    # nmap's output holds strings, so compare the gateway in string
    # form.  Membership is used rather than position so the result
    # does not depend on the order in which nmap lists hosts.
    gateway_text = None if gateway is None else str(gateway)
    gateway_up = run_number == 0 and gateway_text in ips_up

    if gateway_up:
        # The gateway doesn't count toward the network being up.
        ips_up = [ip for ip in ips_up if ip != gateway_text]

    return (gateway_up, bool(ips_up), diags)



def succeed(spec, gateway_was_up, network_was_up, diags):
    """Hand the final result to pscheduler.succeed_json().

    The gateway's state is reported only when the test spec named a
    gateway.
    """

    outcome = {
        "succeeded": True,
        "network-up": network_was_up
    }

    if "gateway" in spec:
        outcome["gateway-up"] = gateway_was_up

    pscheduler.succeed_json({
        "succeeded": True,
        "diags": diags or "No diagnostics.",
        "result": outcome
    })



gateway_up = False
run_number = 0
diags = []

# Wait for the scheduled start time.  Only the lookup is guarded so
# that a KeyError raised anywhere else isn't misreported as a missing
# start time.
try:
    start_at = input['schedule']['start']
except KeyError:
    pscheduler.fail("Unable to find start time in input")
else:
    log.debug("Sleeping until %s", start_at)
    pscheduler.sleep_until(start_at)
    log.debug("Starting")


def _batches(addresses, first_batch, size):
    """Group addresses into lists for successive nmap runs.

    The first list starts out with the contents of first_batch.  A
    list is handed out once it holds size entries and more addresses
    remain; whatever is left over is handed out last.  Empty lists
    are never produced once the addresses run out.
    """
    batch = list(first_batch)
    for address in addresses:
        if len(batch) >= size:
            yield batch
            batch = []
        batch.append(address)
    if batch:
        yield batch


scan_targets = generator(network, exclude=gateway, limit=run_max_ips)

for (run_number, batch) in enumerate(_batches(scan_targets, ip_list, parallel)):
    log.debug("Batch %d IPs: %s", run_number, batch)
    (batch_gateway_up, network_up, nmap_diags) = run_nmap(batch, run_number, gateway)
    # The gateway is only probed in the first batch; don't let later
    # batches overwrite what was learned there.
    gateway_up = gateway_up or batch_gateway_up
    diags += nmap_diags
    if network_up:
        succeed(spec, gateway_up, True, diags)


# If nothing else happened, nothing on the network is up.

succeed(spec, gateway_up, False, diags)

