#!/usr/bin/env python3
#
# Run a twping test
#

import datetime
import ipaddress
import json
import math
import pscheduler
import re
import tempfile
import shutil
import sys
import time
from twping_defaults import *
from twping_utils import CLIENT_ROLE, SERVER_ROLE, get_role, get_config

# Record when this run starts
start_time = datetime.datetime.now()

# Init logging
log = pscheduler.Log(prefix="tool-twping", quiet=True)

#DEBUGGING: Set static values below
# participant = 0
# participant_data = [{}, {'ctrl-port': 11000, 'data-port-range': '11001-11002'}]
# test_spec = {'source': '10.0.1.28', 'dest': '10.0.1.25', "packet-timeout": 7}
# duration = pscheduler.iso8601_as_timedelta('PT15S')

# Parse the JSON run input and pull out the pieces used below.  Any
# problem with the input ends the run via pscheduler.fail().
input = pscheduler.json_load(exit_on_error=True)
try:
    participant = input['participant']
    participant_data = input['participant-data']
    test_spec = input['test']['spec']
    test_type = input['test']['type']
    start = pscheduler.iso8601_as_datetime(input['schedule']['start'])
    duration = pscheduler.iso8601_as_timedelta(input['schedule']['duration'])
except KeyError as ex:
    pscheduler.fail("Missing required key in run input: %s" % ex)
except Exception:
    # Deliberately not a bare except, so KeyboardInterrupt and
    # SystemExit are not reported as input-parsing errors.
    pscheduler.fail("Error parsing run input: %s" % sys.exc_info()[0])
reverse = test_spec.get('reverse', False)


# These are the only test types supported.  Checked explicitly rather
# than with assert, since assertions are removed when Python runs
# with -O.
if test_type not in ('latency', 'rtt'):
    pscheduler.fail('Unsupported test type "%s"' % (test_type))


# Constants

# Factor used for millisecond conversions
TIME_SCALE = 0.001

# Default location of the twping executable
DEFAULT_TWPING_CMD = "/usr/bin/twping"
DEFAULT_TWPING_LENGTH = 41

# Default histogram bucket width (one TIME_SCALE unit, i.e. ms)
DEFAULT_BUCKET_WIDTH = TIME_SCALE

# Delay buckets are rounded to, and formatted with, 2 decimal places
DELAY_BUCKET_DIGITS = 2
DELAY_BUCKET_FORMAT = "%.2f"

# Raw packets are not displayed by default
DEFAULT_RAW_OUTPUT = False

# Number of digits used when rounding clock error
CLOCK_ERROR_DIGITS = 2

# (test spec key, twping switch) pairs for range-valued options
TWPING_RANGE_ARGS = [("data-ports", "-P")]

# (test spec key, twping switch) pairs for single-valued options
TWPING_VAL_ARGS = [("ip-tos", "-D"), ("packet-padding", "-s")]

# Work out whether this participant is the client (twping) or the
# server (twampd)
role = get_role(participant, test_spec)

# Run the test.  The result starts out marked as not succeeded and is
# only flipped once the selected role has finished its work.
results = {'schema': LATENCY_SCHEMA_VERSION, 'succeeded': False}
if role == CLIENT_ROLE:
    # Take the twping path from the config file when it names one,
    # otherwise fall back to the built-in default.
    config = get_config()
    if config and config.has_option(CONFIG_SECTION, CONFIG_OPT_TWPING_CMD):
        twping_cmd = config.get(CONFIG_SECTION, CONFIG_OPT_TWPING_CMD)
    else:
        twping_cmd = DEFAULT_TWPING_CMD

    # Raw output (-R) is always requested
    twping_args = [twping_cmd, '-R']

    # Single-valued options are copied straight from the test spec
    for spec_key, switch in TWPING_VAL_ARGS:
        if spec_key in test_spec:
            twping_args += [switch, str(test_spec[spec_key])]

    # Range options are rendered as "lower-upper"
    for spec_key, switch in TWPING_RANGE_ARGS:
        if spec_key in test_spec:
            bounds = test_spec[spec_key]
            twping_args += [switch, "%d-%d" % (bounds['lower'], bounds['upper'])]

    # Set packet size for rtt tests to match ping.  Only a requested
    # length above the default changes the reported length and adds
    # a -s switch.
    rtt_length = DEFAULT_TWPING_LENGTH
    if test_spec.get('length', DEFAULT_TWPING_LENGTH) > DEFAULT_TWPING_LENGTH:
        rtt_length = test_spec['length'] + 8
        twping_args += ['-s', str(test_spec['length'] - 6)]

    # Set count, interval and timeout so they are consistent with the
    # duration.  The RTT test gives its interval and timeout as
    # ISO8601 durations while latency uses decimal seconds, so the
    # spec key and the conversion both depend on the test type.
    latency_test = test_type == 'latency'

    packet_count = test_spec.get('packet-count', test_spec.get('count', DEFAULT_PACKET_COUNT))
    twping_args += ['-c', str(packet_count)]

    # Interval between packets (-i), in seconds
    try:
        raw_interval = test_spec['packet-interval' if latency_test else 'interval']
    except KeyError:
        interval = DEFAULT_PACKET_INTERVAL
    else:
        if latency_test:
            interval = float(raw_interval)
        else:
            interval = pscheduler.timedelta_as_seconds(pscheduler.iso8601_as_timedelta(raw_interval))
    twping_args += ['-i', str(interval)]

    # Packet timeout (-L), in seconds
    try:
        raw_timeout = test_spec['packet-timeout' if latency_test else 'timeout']
    except KeyError:
        timeout = DEFAULT_PACKET_TIMEOUT
    else:
        if latency_test:
            timeout = float(raw_timeout)
        else:
            timeout = pscheduler.timedelta_as_seconds(pscheduler.iso8601_as_timedelta(raw_timeout))
    twping_args += ['-L', str(timeout)]

    
    # Restrict twping to IPv4 or IPv6 when the spec names a version;
    # a missing ip-version leaves the choice open.
    try:
        ip_version = int(test_spec.get('ip-version'))
    except TypeError:
        ip_version = None

    version_switch = {4: '-4', 6: '-6'}.get(ip_version)
    if version_switch is not None:
        twping_args.append(version_switch)

    # Control port to contact on the destination
    control_port = int(test_spec.get('ctrl-port', DEFAULT_TWAMPD_PORT))

    # Work out the destination argument, plus the IP and hostname
    # reported for rtt tests
    dest = test_spec['dest']

    try:
        addr = ipaddress.ip_address(dest)
    except ValueError:
        # Not an IP, assume a hostname
        dest_hostname = dest
        # None here is okay and will be sorted out later.
        dest_ip = pscheduler.dns_resolve(dest_hostname, ip_version=ip_version)
        dest_arg = '%s:%d' % (dest, control_port)
    else:
        # It's an IP at this point; IPv6 literals are bracketed so the
        # port suffix stays unambiguous
        dest_ip = dest
        dest_hostname = None
        dest_format = '[%s]:%d' if addr.version == 6 else '%s:%d'
        dest_arg = dest_format % (dest, control_port)
    
    # Width used when rounding delay values into histogram buckets
    bucket_width = test_spec.get('bucket-width', DEFAULT_BUCKET_WIDTH)

    # Whether raw packets are returned along with the summary
    raw_output = test_spec.get('output-raw', DEFAULT_RAW_OUTPUT)

    # Pass -Z when the spec sets traverse-nat
    if test_spec.get('traverse-nat', False):
        twping_args.append('-Z')

    # Finally the addresses: the optional source first, then the
    # destination
    if 'source' in test_spec:
        twping_args.extend(['-S', test_spec['source']])
    twping_args.append(dest_arg)

    # Force all args to be strings
    twping_args = list(map(str, twping_args))

    # Wait for the scheduled start time.  (A single wait is enough;
    # repeating it once the start time has passed adds nothing.)
    log.debug("Sleeping until %s", start)
    pscheduler.sleep_until(start)
    log.debug("Starting")

    #Run the process
    #time.sleep(DEFAULT_CLIENT_SLEEP) #wait for server to boot
    log.debug("Running command: %s" % " ".join(twping_args))
    try:
        returncode, stdout, stderr = pscheduler.run_program(twping_args)
    except OSError as ex:
        log.error("twping encountered an OS error: %s" % ex)
        #Note: Avoids reporting sensitive system details in error message
        results['error'] = "The twping command failed during execution. See server logs for more details."
        pscheduler.succeed_json(results)
    except Exception:
        log.error("twping failed to complete execution: %s" % sys.exc_info()[0])
        results['error'] = "The twping command failed during execution. See server logs for more details."
        pscheduler.succeed_json(results)

    # A non-zero exit status is reported as a failed result (with
    # 'succeeded' still False), using twping's stderr as the message
    # when there is one.
    log.debug("twping returned status %d" % returncode)
    if returncode:
        if stderr:
            twp_error = stderr.strip().replace("\n", ";")
            results['error'] = "twping returned an error: %s" % twp_error
        else:
            results['error'] = "twping returned an error status but no message"
        pscheduler.succeed_json(results)
        
    # RAW ascii format for TWAMP is:
    # "SSEQ STIME SS SERR SRTIME SRS SRERR STTL RSEQ RSTIME RSS RSERR RTIME RS RERR RTTL\n"
    # name     desc                           type
    # SSEQ     send sequence number           unsigned long
    # STIME    sendtime                       owptimestamp (%020 PRIu64)
    # SS       send synchronized              boolean unsigned
    # SERR     send err estimate              float (%g)
    # SRTIME   send (receive) time            owptimestamp (%020 PRIu64)
    # SRS      send (receive) synchronized    boolean unsigned
    # SRERR    send (receive) err estimate    float (%g)
    # STTL     send ttl                       unsigned short
    # RSEQ     reflect sequence number        unsigned long
    # RSTIME   reflected (send) time          owptimestamp (%020 PRIu64)
    # RSS      reflected (send) synchronized  boolean unsigned
    # RSERR    reflected (send) err estimate  float (%g)
    # RTIME    recvtime                       owptimestamp (%020 PRIu64)
    # RS       recv synchronized              boolean unsigned
    # RERR     recv err estimate              float (%g)
    # RTTL     reflected ttl                  unsigned short
    twping_regex = re.compile(r'^(\d+) (\d+) (\d) ([-.0-9e+]*) (\d+) (\d) ([-.0-9e+]*) (\d+) (\d+) (\d+) (\d) ([-.0-9e+]*) (\d+) (\d) ([-.0-9e+]*) (\d+)$')
    results['packets-sent'] = 0
    results['packets-received'] = 0
    results['packets-duplicated'] = 0
    results['packets-reordered'] = 0
    results['histogram-latency'] = {}
    results['histogram-ttl'] = {}
    if raw_output:
        results['raw-packets'] = []
    # Sequence numbers already counted, used to detect duplicates
    packets_seen = set()
    prev_seq_number = -1
    stdout_lines = []
    roundtrips = []
    rtts = []
    if stdout:
        stdout_lines = stdout.split("\n")
    for line in stdout_lines:
        twping_match = twping_regex.match(line)
        if not twping_match:
            # Anything that is not a raw packet record is ignored
            continue

        # Fields 1-8 are the send-side columns and 9-16 the
        # reflect-side columns of the format above; a reverse test
        # reports on the latter.
        first_group = 9 if reverse else 1
        (seq_number,
         source_timestamp,
         source_synchronized,
         source_error,
         destination_timestamp,
         destination_synchronized,
         destination_error,
         ttl) = twping_match.group(*range(first_group, first_group + 8))

        #collect rtt data: always first send time (STIME) to final
        #receive time (RTIME), whatever the direction
        ts1 = float(twping_match.group(2))
        ts2 = float(twping_match.group(13))
        if ts1 != 0 and ts2 != 0:
            rtt = (ts2 - ts1) / pow(2, 32)
            rtts.append(rtt)
            rt = {
                'seq': int(seq_number) + 1,
                'length': int(rtt_length),
                'ttl': int(ttl),
                'rtt': pscheduler.timedelta_as_iso8601(datetime.timedelta(seconds=rtt))
            }
            if dest_hostname is not None:
                rt['hostname'] = dest_hostname
            if dest_ip is not None:
                rt['ip'] = dest_ip
            roundtrips.append(rt)

        #publish raw pings
        if raw_output:
            #init packet object
            packet = {
                'seq-num': int(seq_number),
                'src-ts': int(source_timestamp),
                'src-clock-sync': source_synchronized == "1",
                'src-clock-err': source_error,
                'dst-ts': int(destination_timestamp),
                'dst-clock-sync': destination_synchronized == "1",
                'dst-clock-err': destination_error,
                'ip-ttl': int(ttl)
            }
            #The error values may not exist, but format as floats if they do
            if source_error:
                packet['src-clock-err'] = float(source_error)
            if destination_error:
                packet['dst-clock-err'] = float(destination_error)
            #add to list
            results['raw-packets'].append(packet)

        #duplicates: a sequence number seen before is counted once as
        #a duplicate and excluded from all the statistics below
        if seq_number in packets_seen:
            results['packets-duplicated'] += 1
            continue
        packets_seen.add(seq_number)

        #sent
        results['packets-sent'] += 1

        #packet lost
        if int(destination_timestamp) == 0:
            continue

        #received
        results['packets-received'] += 1

        #delay histogram
        ##calculate delay in terms of seconds. TWAMP uses odd timestamps so need the divide by 2 ^ 32
        delay = (float(destination_timestamp) - float(source_timestamp)) / pow(2, 32)
        #round and add 0 to prevent -0.00
        delay_bucket = DELAY_BUCKET_FORMAT%(round(delay/bucket_width, DELAY_BUCKET_DIGITS) + 0)
        results['histogram-latency'][delay_bucket] = results['histogram-latency'].get(delay_bucket, 0) + 1

        #TTL histogram
        results['histogram-ttl'][ttl] = results['histogram-ttl'].get(ttl, 0) + 1

        #reorders
        if int(seq_number) < int(prev_seq_number):
            results['packets-reordered'] += 1
        prev_seq_number = seq_number

        #clock error: keep the largest combined source + destination estimate
        if source_error and destination_error:
            clock_error = float(source_error) + float(destination_error)
            if ('max-clock-error' not in results) or (clock_error > results['max-clock-error']):
                results['max-clock-error'] = clock_error

    # Lost packets are recorded for convenience: sent minus received
    results['packets-lost'] = results['packets-sent'] - results['packets-received']

    # Convert clock error to ms
    if 'max-clock-error' in results:
        results['max-clock-error'] = round(results['max-clock-error'] / TIME_SCALE, CLOCK_ERROR_DIGITS)

    # RTT tests report a different result layout, so rebuild the result
    # from the counters gathered above and add summary statistics
    if test_type == 'rtt':

        def _seconds_as_iso8601(seconds):
            """Return a number of seconds as an ISO8601 duration string."""
            return pscheduler.timedelta_as_iso8601(datetime.timedelta(seconds=seconds))

        sent = results['packets-sent']
        received = results['packets-received']
        results = {
            'schema': 1,
            'sent': sent,
            'received': received,
            'lost': results['packets-lost'],
            'reorders': results['packets-reordered'],
            'duplicates': results['packets-duplicated'],
            'roundtrips': roundtrips,
        }

        # Loss is the fraction of sent packets that were not received
        results['loss'] = float(sent - received) / float(sent) if sent > 0 else 0

        if rtts:
            # Sorted so the extremes sit at either end of the list
            rtts.sort()
            sample_count = len(rtts)
            mean = sum(rtts) / sample_count

            # Sample standard deviation; zero for a single sample
            stddev = 0.0
            if sample_count > 1:
                stddev = math.sqrt(sum((rtt - mean) ** 2 for rtt in rtts) / (sample_count - 1))

            results['min'] = _seconds_as_iso8601(rtts[0])
            results['max'] = _seconds_as_iso8601(rtts[-1])
            results['mean'] = _seconds_as_iso8601(mean)
            results['stddev'] = _seconds_as_iso8601(stddev)

    results['succeeded'] = True
elif role == SERVER_ROLE:
    # The server side does no work of its own here: the only live
    # statement in this branch marks the result as succeeded.  The
    # commented-out block below is retained code for spawning a twampd
    # for the test.
    # NOTE(review): that code refers to `flip`, `Popen` and `PIPE`, none
    # of which are visibly defined or imported in this file -- confirm
    # before re-enabling it.
#######
## Code to spawn a server. Unfortunately can only use once we have true resource management
#     #get and validate participant data
#     try:
#         recv_participant_data = participant_data[1]
#     except IndexError:
#         pscheduler.fail("Unable to find participant data for dest")
#     if 'ctrl-port' not in recv_participant_data:
#         pscheduler.fail("Unable to find control port in participant data")
#     
#     if 'data-port-range' not in recv_participant_data:
#         pscheduler.fail("Unable to find data port range in participant data")
#     
#     #ctrl-port must be an integer
#     try:
#         int(recv_participant_data['ctrl-port'])
#     except ValueError:
#         pscheduler.fail("Control port must be an integer")
#     
#     #data-port-range must be in form N-M where N < M and both are integers
#     range_match = re.compile(r'(\d+)-(\d+)').match(recv_participant_data['data-port-range'])
#     if not range_match:
#         pscheduler.fail("Data port range is not a valid range. Must be in form N-M where N < M and both are integers")
#     elif range_match.group(1) >= range_match.group(2):
#         pscheduler.fail("Data port range is not a valid range. The first value must be less than the second")
#     
#     #init command
#     #TODO: Get command path from config file
#     twampd_args = ['/usr/bin/twampd', '-Z' ]
#     twampd_args.append('-f') #TODO: Remove this when we don't run as root
#     
#     #set data port range
#     twampd_args.append('-P')
#     twampd_args.append(str(recv_participant_data['data-port-range']))
#     
#     #set listening address
#     if flip:
#         listen_addr = test_spec['source']
#     else:
#         listen_addr = test_spec['dest']
#     twampd_args.append('-S')
#     twampd_args.append("[%s]:%d" % (listen_addr, recv_participant_data['ctrl-port']))
#     
#     #create temp dir for data and pid
#     tmpdir = tempfile.mkdtemp()
#     twampd_args.append('-R')
#     twampd_args.append(tmpdir)
#     twampd_args.append('-d')
#     twampd_args.append(tmpdir)
#     
#     #Run the process
#     log.debug("Running command: %s" % " ".join(twampd_args))
#     try:
#         proc = Popen(twampd_args,stdout=PIPE, cwd=tmpdir, stderr=PIPE, shell=False)
#     except OSError as e:
#         log.error("twampd encountered an OS error: %s" % e)
#         #Note: Avoids reporting sensitive system details in error message
#         pscheduler.fail("The twampd command failed during execution. See server logs for more details.")
#     except Exception:
#         log.error("twampd failed to complete execution: %s" % sys.exc_info()[0])
#         pscheduler.fail("The twampd command failed during execution. See server logs for more details.")
#     
#     #sleep while we wait for test to complete in allotted time
#     try:
#         time.sleep(pscheduler.timedelta_as_seconds(duration - (datetime.datetime.now() - start_time)))
#     except IOError as e:
#         log.error("Unable to sleep while server runs. Your test may be out of time")
#         log.debug("Sleep error: %s" % e)
#     
#     #check if server failed
#     if proc.poll():
#         #if we are here, the server died
#         twp_error = ''
#         for line in proc.stderr:
#             twp_error += line.rstrip().lstrip() + " "
#         pscheduler.fail("twampd returned an error: %s" % twp_error)
#     else:
#         #process is still running like it should be, kill it
#         proc.kill()
#     
#     #log stdout in debug mode
#     for line in proc.stdout:
#         log.debug(line)
#     
#     #Remove our tmpdir, but don't fail the test if it doesn't remove
#     try:
#         shutil.rmtree(tmpdir, ignore_errors=False)
#     except:
#         log.warn("Unable to remove twampd temporary directory %s: %s" % (tmpdir, sys.exc_info()[0]))

    # Nothing to run on this side, so report success
    results['succeeded'] = True

# Log the final result at debug level, then hand it to pScheduler as
# this run's JSON output
results_message = "Results: %s" % results
log.debug(results_message)
pscheduler.succeed_json(results)
