#!/usr/bin/python3
#
# Run a test.  Just the test spec is provided on stdin.
#


import datetime
import icmperror
import pscheduler
import re

# Logger used for the debug messages emitted by this script.
log = pscheduler.Log(quiet=True, prefix='tool-tcpping')

# Load the JSON run request.  NOTE: the name shadows the builtin
# input(); it is kept because the rest of the script refers to it.
input = pscheduler.json_load(exit_on_error=True)

# TODO: Validate the input

# Only participant 0 is accepted; anything else is rejected.
participant = input['participant']
if participant != 0:
    pscheduler.fail("Invalid participant.")

# The test specification holds all of the parameters read below.
test = input['test']
spec = test['spec']

# Accumulates the time the program is allowed to run; starts at zero.
run_timeout = datetime.timedelta(seconds=0)


#
# Figure out how to invoke the program
#

# Command line for the program, built up piece by piece; it starts
# with just the program name.
argv = ['tcpping']


#
# Address-related parameters
#

# Source is optional.
source = spec.get('source')

# Destination is mandatory.
try:
    dest = spec['dest']
except KeyError:
    pscheduler.fail("Dest is a required parameter")

# IP version is optional.
ipversion = spec.get('ip-version')

# Work out the addresses and the IP version up front so the program
# being run does not have to.
source_ip = dest_ip = None

if source is None:
    # No source given.  If no IP version was requested either, try to
    # derive one (and an address) from the destination alone.
    if ipversion is None:
        ipversion, dest_ip = pscheduler.ip_addr_version(dest)
        if ipversion is None:
            # No version could be determined, so treat the
            # destination as unresolved.
            dest_ip = None
else:
    # With a source, both ends are looked up together; the error
    # message below indicates they need to be in the same family.
    source_ip, dest_ip = pscheduler.ip_normalize_version(
        source, dest, ip_version=ipversion)
    if source_ip is None and dest_ip is None:
        pscheduler.fail("Can't find same-family IPs for %s and %s." % (source, dest))
    if dest_ip is not None:
        ipversion, dest_ip = pscheduler.ip_addr_version(dest_ip)

# Source: only passed along when an address was found for it.

if source_ip is not None:
    argv.append('--s')
    argv.append(source_ip)

# Destination

# Make a best guess at the IP if we didn't find one above.
if dest_ip is None:
    dest_ip_map = pscheduler.dns_bulk_resolve([dest], ip_version=ipversion)
    # For failed resolution, just try the hostname and let tcpping deal
    # with it.  A failed lookup can be present in the map with a value
    # of None (the reverse-lookup code near the end of this script
    # checks for exactly that), so a default passed to .get() is not
    # enough: it would hand back None, which would later be
    # stringified to 'None' on the command line.
    dest_ip = dest_ip_map.get(dest) or dest

argv.append('--d')
argv.append(dest_ip)


# IP version: passed only when one was requested or determined.
if ipversion is not None:
    argv.extend(['--v', ipversion])

# Port: passed only when the spec names one.
if 'port' in spec:
    argv.extend(['--p', spec['port']])

# Count: number of attempts, defaulting to five.  Two seconds are
# added to the run timeout for each attempt.
count = spec.get('count', 5)
argv.extend(['--c', count])
run_timeout = run_timeout + count * datetime.timedelta(seconds=2)



#
# Run the test
#

# Stringify the arguments; some of them (port, count, IP version) may
# have been appended as non-string values.
argv = [str(arg) for arg in argv]

# Human-readable command line, used for logging and diagnostics.
argv_string = ' '.join(argv)
log.debug("Running %s", argv_string)

# Add some run slop
run_timeout += datetime.timedelta(seconds=2)
log.debug("Timeout is %s", run_timeout)

run_timeout_secs = pscheduler.timedelta_as_seconds(run_timeout)

# Only the lookup itself is guarded, so that a KeyError raised by
# anything else (logging, sleeping) is not misreported as a missing
# start time.
try:
    start_at = input['schedule']['start']
except KeyError:
    pscheduler.fail("Unable to find start time in input")

log.debug("Sleeping until %s", start_at)
pscheduler.sleep_until(start_at)
log.debug("Starting")

status, stdout, stderr = pscheduler.run_program(
    argv, timeout=run_timeout_secs)

# Log whichever stream is relevant to the outcome.
log.debug("Program exited %d: %s",
          status, stdout if status == 0 else stderr)

# A non-zero exit status is reported as a failed run, with stderr as
# the error.  NOTE(review): the code after this point assumes
# succeed_json() does not return -- confirm.
if status:
    pscheduler.succeed_json({
        'succeeded': False,
        'diags': argv_string,
        'error': stderr,
        'result': None
    })


#
# Dissect the results
#

# Envelope for a successful run.  The diagnostics carry the command
# line followed by the program's raw output.  The 'result' member is
# attached at the very end of the script.
#
# The 'result' dictionary itself is not initialized here: the parsing
# section further down assigns it unconditionally before first use, so
# an assignment at this point would be a dead store.
final_result = {
    'schema': 1,
    'succeeded': True,
    'diags': argv_string + '\n\n' + stdout,
    'error': None
}


#
# Parse the results
#

# Sample:
# $ tcpping --d 1.2.3.4 --c 3
# Connection timed out!
# Connection timed out!
# Connection timed out!
#
# TCP Ping Results: Connections (Total/Pass/Fail): [3/0/3] (Failed: 100.000%)
#
#
# $ tcpping --d 1.2.3.4 --c 3 --p 8080
# Connected to 198.6.1.1 (198.6.1.1) [53]: tcp_seq=0 time=32.67 ms
# Connected to 198.6.1.1 (198.6.1.1) [53]: tcp_seq=1 time=33.15 ms
# Connected to 198.6.1.1 (198.6.1.1) [53]: tcp_seq=2 time=32.66 ms
#
# TCP Ping Results: Connections (Total/Pass/Fail): [3/3/0] (Failed: 0%)
# rtt min/avg/max/mdev = 32.660/32.827/33.150/0.280 ms


# One line per successful connection attempt, e.g.:
#   Connected to 198.6.1.1 (198.6.1.1) [53]: tcp_seq=0 time=32.67 ms
# Captures the parenthesized address, the tcp_seq value and the time
# in milliseconds.
CONNECTED = re.compile(
    r'^Connected to\s*[^\s]+\s+\(([^)]+)\)\s+.*:'
    r'\s+tcp_seq=([0-9]+)\s+time=([0-9.]+) ms$'
)

# Summary of the attempts, e.g.:
#   TCP Ping Results: Connections (Total/Pass/Fail): [3/3/0] (Failed: 0%)
# Captures the total, pass and fail counts.
COUNTED = re.compile(
    r'^TCP Ping Results: Connections \(Total/Pass/Fail\): '
    r'\[([0-9]+)/([0-9]+)/([0-9]+)\]'
)

# Round-trip statistics, e.g.:
#   rtt min/avg/max/mdev = 32.660/32.827/33.150/0.280 ms
# Captures the min, avg, max and mdev values in milliseconds.
STATS = re.compile(
    r'^rtt min/avg/max/mdev\s+=\s+'
    + '/'.join([r'([0-9.]+)'] * 4)
    + r'\s+ms'
)



# Walk the program's output one line at a time.  The per-attempt
# lines come first; the first blank line switches to "endgame" mode,
# in which only the summary lines (counts and rtt statistics) are
# looked for.

endgame = False
result = {
    "schema": 1,
    "succeeded": True
}

seq = 0           # Sequence number assigned to each attempt, from 1
roundtrips = []   # One entry per attempt, successful or not
ips = {}          # Used as a set of the addresses seen

for line in stdout.split('\n'):

    if line == '':
        endgame = True
        continue

    if not endgame:

        # Successful attempt
        matches = CONNECTED.match(line)
        if matches is not None:
            seq += 1
            ip, _seq, rtt = matches.groups()
            # TODO: seq and _seq should match.
            roundtrips.append({
                'seq': seq,
                'ip': ip,
                # Milliseconds in the output, seconds in the ISO 8601
                # duration.
                'rtt': 'PT%fS' % (float(rtt) / 1000.0)
            })
            ips[ip] = 1
            continue

        # Failed attempt
        if line == 'Connection timed out!' or line.startswith("OS Error:"):
            seq += 1
            roundtrips.append({
                'seq': seq,
                # No IP in the output; use the best available info.
                'ip': dest_ip,
                'error': 'host-unreachable'
            })
            ips[dest_ip] = 1
            continue

        # Ignore anything else.

    else:

        # Next-to-last line
        matches = COUNTED.match(line)
        if matches is not None:
            total, passed, failed = [int(value) for value in matches.groups()]
            result['sent'] = total
            result['received'] = passed
            result['lost'] = failed
            # Guard against a summary reporting zero attempts, which
            # would otherwise raise ZeroDivisionError.
            result['loss'] = float(failed) / float(total) if total else 0.0
            # This tool doesn't detect these.  Assume none.
            result['reorders'] = 0
            result['duplicates'] = 0

        # Last line
        matches = STATS.match(line)
        if matches is not None:
            # Milliseconds in the output, seconds in the ISO 8601
            # durations.
            rtt_min, rtt_avg, rtt_max, rtt_mdev = [
                "PT%fS" % (float(value) / 1000.0) for value in matches.groups()
            ]
            result['min'] = rtt_min
            result['max'] = rtt_max
            result['mean'] = rtt_avg
            result['stddev'] = rtt_mdev
            # Nothing of interest follows the statistics.
            break

        # Ignore anything else.


# When hostnames are wanted (the default) and at least one address
# was recorded, reverse-resolve them all in one go and attach the
# names that were found to the matching round trips.

want_hostnames = spec.get('hostnames', True)

if want_hostnames and ips:
    log.debug("Reverse-resolving IPs: %s", str(ips))
    revmap = pscheduler.dns_bulk_resolve(
        ips, threads=len(roundtrips), reverse=True)
    for hop in roundtrips:
        if 'ip' not in hop:
            # No IP is fine.
            continue
        hostname = revmap.get(hop['ip'])
        # Addresses that did not resolve are left without a hostname.
        if hostname is not None:
            hop['hostname'] = hostname

# Assemble the final document and emit it.

result.update(roundtrips=roundtrips)
final_result.update(result=result)
pscheduler.succeed_json(final_result)
