#!/usr/bin/env python3
#
# Run an ethr test
#

import datetime
import pscheduler
import re
import sys
import time

import ethr_utils

from ethr_defaults import *

# Note the moment this run began.
start_time = datetime.datetime.now()

logger = pscheduler.Log(prefix='tool-ethr', quiet=True)
logger.debug('starting ethr tool')

# Load the JSON run input.
input = pscheduler.json_load(exit_on_error=True)

logger.debug('Input is %s' % input)

# Pull out the top-level members every run needs; give up if any is
# absent.
try:
    participant, participant_data, test_spec = (
        input['participant'],
        input['participant-data'],
        input['test']['spec'],
    )
except KeyError as ex:
    pscheduler.fail('Missing required key in run input: %s' % ex)
except Exception as ex:
    pscheduler.fail('Error parsing run input: %s' % type(ex))

# Settings taken from the test spec.  Only 'dest' is looked up without
# a fallback.
source = test_spec.get('source')
destination = test_spec['dest']
ip_version = test_spec.get('ip-version')
single_ended = test_spec.get('single-ended', False)
loopback = test_spec.get('loopback', False)

# 'omit' is an ISO 8601 duration in the spec; keep it as whole seconds.
# Intervals ending at or before this mark are flagged as omitted when
# the results are assembled.
omit_delta = pscheduler.iso8601_as_timedelta(test_spec.get("omit", "P0D"))
omit = int(pscheduler.timedelta_as_seconds(omit_delta))

parallel = test_spec.get('parallel', 1)

# Two participants are the norm; a single one is acceptable only for
# single-ended or loopback tests.
participants = len(participant_data)
enough_participants = (
    participants == 2
    or (participants == 1 and (single_ended or loopback))
)
if not enough_participants:
    pscheduler.fail('ethr requires exactly 2 participants, got %s' % (len(participant_data)))

# Every participant has to report an Ethr major version greater than
# zero; anything else is treated as incompatible with this plugin.
# TODO: This assumes we're going to be on 1.x for a while.
num_compatible = sum(
    1 for entry in participant_data
    if entry.get('ethr_major_version', 0) > 0
)
if num_compatible < participants:
    pscheduler.succeed_json({
        'succeeded': False,
        'error': 'Cannot interoperate with systems using older, incompatible versions of Ethr.'
    })



config = ethr_utils.get_config()

# Configured path of the local ethr command.
ethr_cmd = config['ethr_cmd']

# Pick the server port: single-ended tests take it from the spec (with
# a default), loopback always uses the default, and everything else
# uses what the second participant reported.
if single_ended:
    server_port = test_spec.get('single-ended-port', DEFAULT_SERVER_PORT)
elif loopback:
    server_port = DEFAULT_SERVER_PORT
else:
    server_port = participant_data[1]['server_port']

# The test duration arrives as an ISO 8601 duration and is kept as
# whole seconds; fall back to the default when none was given.
test_duration = test_spec.get('duration')
if not test_duration:
    test_duration = DEFAULT_DURATION
else:
    delta = pscheduler.iso8601_as_timedelta(test_duration)
    test_duration = int(pscheduler.timedelta_as_seconds(delta))

# Whether a reverse test was requested.
reverse = test_spec.get('reverse', False)

# Arguments common to the client and server command lines.
ethr_first_args = [ethr_cmd, '-no']
if ip_version is not None:
    ethr_first_args.append(f'-{ip_version}')

   

def run_client():
    """Run ethr in client mode and return a result dict.

    Builds the client command line from the test spec, sleeps until the
    scheduled start time plus DEFAULT_WAIT_SLEEP seconds (to let the
    server on the other side start), runs ethr and hands its output to
    _make_result().

    As a side effect, the module-level ip_version is updated when no
    version was specified and one had to be derived from the addresses.

    Returns a dict that always carries 'succeeded'; failures also carry
    'error'.
    """

    global ip_version

    diags = [
        'Plugin is using Ethr 1.x',
        ''
    ]

    ethr_args = ethr_first_args.copy()
    ethr_args.extend([
        # Test bandwidth (the default)
        '-t', 'b',
        '-port', server_port
    ])

    # Decide which address, if any, to bind to.  An explicit
    # local-address takes precedence over the source.
    local_address = test_spec.get('local-address')
    if local_address is not None:
        bind_label, bind_from = 'local-address', local_address
    else:
        bind_label, bind_from = 'source', source

    if bind_from is None:

        # Nothing to bind to; just name the destination.
        ethr_args.extend(['-c', destination])

    elif ip_version is not None:

        # The IP version was dictated by the spec, so the destination
        # is passed through untouched and only the bind address is
        # normalized.
        # TODO: These need to be an IP
        bind_to, _ = pscheduler.ip_normalize_version(bind_from, destination)
        if bind_to is not None:
            ethr_args.extend(['-ip', bind_to])
        ethr_args.extend(['-c', destination])

    else:

        # No IP version was given, so normalize both ends together
        # (ip_normalize_version() is expected to settle on a version
        # they have in common).  A missing bind address is tolerated
        # in case there is some external factor we don't know about; a
        # missing destination address is not.
        bind_ip, dest_ip = pscheduler.ip_normalize_version(bind_from, destination)
        if bind_ip is not None:
            ethr_args.extend(['-ip', bind_ip])
        if dest_ip is None:
            return {'succeeded': False,
                    'error': 'Unable to find common IP version between %s %s and dest %s'
                             % (bind_label, bind_from, destination)}
        ethr_args.extend(['-c', dest_ip])
        ip_version = pscheduler.ip_addr_version(dest_ip)[0]

    # Duration
    ethr_args.extend(['-d', f'{test_duration}s'])

    # Optional settings mapped straight from the test spec; each is
    # skipped when absent or null.
    for flag, key in (('-b', 'bandwidth'), ('-l', 'buffer-length')):
        value = test_spec.get(key)
        if value is not None:
            ethr_args.extend([flag, value])

    if parallel is not None:
        ethr_args.extend(['-n', parallel])

    # Protocol
    ethr_args.extend([
        '-p',
        'udp' if test_spec.get('udp', False) else 'tcp'
    ])

    if reverse:
        ethr_args.append('-r')

    ip_tos = test_spec.get('ip-tos')
    if ip_tos is not None:
        ethr_args.extend(['-tos', ip_tos])

    # join and run_program want these all to be string types, so
    # just to be safe cast everything in the list to a string
    ethr_args = [str(x) for x in ethr_args]

    command_line = ' '.join(ethr_args)
    logger.debug('Client: Running command: %s' % command_line)

    # DEFAULT_WAIT_SLEEP is not part of the timeout because that wait
    # is over before ethr is started.
    ethr_timeout = test_duration
    ethr_timeout += ethr_utils.setup_time(test_spec.get('link-rtt'))

    logger.debug('Client: timeout for client is %d' % ethr_timeout)
    diags.append(command_line)

    try:
        start_at = input['schedule']['start']
        logger.debug('Client: Sleeping until %s', start_at)
        pscheduler.sleep_until(start_at)
        logger.debug('Client: Starting')
    except KeyError:
        pscheduler.fail('Unable to find start time in input')

    logger.debug('Client: Waiting %s sec for server on other side to start' % DEFAULT_WAIT_SLEEP)
    time.sleep(DEFAULT_WAIT_SLEEP)  # wait for server to start on other side

    try:
        status, stdout, stderr = pscheduler.run_program(ethr_args,
                                                        timeout=ethr_timeout)
    except Exception as ex:
        logger.error('ethr failed to complete execution: %s' % ex)
        return {'succeeded': False,
                'diags': '\n'.join(diags),
                'error': 'The ethr command failed during execution. See server logs for more details.'}

    return _make_result('\n'.join(diags), status, stdout, stderr)

    

def run_server():
    """Run ethr in server mode and return a result dict.

    Builds the server command line, sleeps until the scheduled start
    time, runs ethr with a timeout covering the test duration, setup
    time and DEFAULT_WAIT_SLEEP, and hands the output to
    _make_result().

    Returns a dict that always carries 'succeeded'; failures also carry
    'error'.
    """

    ethr_args = ethr_first_args.copy()

    ethr_args.extend([
        '-s',
        '-port', server_port
    ])

    if source is not None:
        _, dest_ip = pscheduler.ip_normalize_version(source, destination)
        # Only bind when an address was actually produced.  Passing
        # None along would put the literal text "None" on the command
        # line once the arguments are converted to strings below.
        if dest_ip is not None:
            ethr_args.extend(['-ip', dest_ip])

    ethr_args = [str(x) for x in ethr_args]
    command_line = ' '.join(ethr_args)
    logger.debug('Server: Running command: %s' % command_line)

    # The server has to stay up through the client's start-up wait as
    # well as the test itself.
    ethr_timeout = test_duration
    ethr_timeout += ethr_utils.setup_time(test_spec.get('link-rtt'))
    ethr_timeout += DEFAULT_WAIT_SLEEP

    logger.debug('Server: Timeout for server is %d' % ethr_timeout)
    diags = [command_line]

    try:
        start_at = input['schedule']['start']
        logger.debug('Server: Sleeping until %s', start_at)
        pscheduler.sleep_until(start_at)
        logger.debug('Server: Starting')
    except KeyError:
        pscheduler.fail('Unable to find start time in input')

    # Report a failure to launch or run ethr as a failed result, the
    # same way run_client() does, instead of letting the exception
    # escape.
    try:
        status, stdout, stderr = pscheduler.run_program(ethr_args,
                                                        timeout=ethr_timeout)
    except Exception as ex:
        logger.error('ethr failed to complete execution: %s' % ex)
        return {'succeeded': False,
                'diags': '\n'.join(diags),
                'error': 'The ethr command failed during execution. See server logs for more details.'}

    return _make_result('\n'.join(diags), status, stdout, stderr)

    

def _make_result(diags, status, stdout, stderr):
    """Turn the outcome of an ethr run into a result dict.

    diags  -- diagnostic text to carry in the result
    status -- exit status of ethr; any truthy value is a failure
    stdout -- what ethr wrote to standard output
    stderr -- what ethr wrote to standard error

    Returns a dict with 'succeeded' and 'diags', plus 'error' on
    failure or 'intervals' and 'summary' on success.
    """

    if status:
        # Use the 'error' member of JSON found on stdout when there is
        # one; otherwise fall back to the raw output.
        try:
            error_text = pscheduler.json_load(stdout)['error']
        except (ValueError, KeyError):
            error_text = ''

        error_text = error_text or '%s\n\n%s\n' % (stdout, stderr)

        return {'succeeded': False,
                'diags': diags,
                'error': 'ethr returned an error: %s' % (error_text)}

    # Report lines carry a bracketed stream number or SUM, the
    # protocol, a start-end range in seconds and a rate.
    line_pattern = re.compile(r'^\[\s*(\d+|SUM)\s*\]\s+(TCP|UDP)\s+(\d+)-(\d+)\s+sec\s+([^\s]+)\s*$')

    def make_block(label, first, last, rate):
        """Build one throughput record; 'SUM' records get no stream-id."""
        record = {
            'start': int(first),
            'end': int(last),
            'omitted': last <= omit,
            'throughput-bits': rate,
            'throughput-bytes': int(rate / 8)
        }
        if label != 'SUM':
            record['stream-id'] = int(label)
        return record

    # The JSON output doesn't provide enough, so the human-readable
    # report is picked apart line by line instead.

    intervals = []
    pending = []    # Stream records collected for the current interval

    for text in stdout.split('\n'):

        found = line_pattern.match(text)
        if found is None:
            continue

        label, _protocol, first, last, rate = found.groups()
        block = make_block(label, int(first), int(last),
                           int(pscheduler.si_as_number(rate)))

        # Per-stream lines always join the current interval; with a
        # single stream, every matched line does.
        if label != 'SUM' or parallel == 1:
            pending.append(block)

        # A SUM line closes out an interval, as does every matched
        # line when there is only one stream.
        if label == 'SUM' or parallel == 1:
            intervals.append({
                'streams': pending,
                'summary': block
            })
            pending = []

    # Boil the intervals down into per-stream and overall summaries.

    interval_count = len(intervals)
    counted = interval_count - omit

    per_stream = {}
    total_bits = 0

    for interval in intervals:

        for block in interval['streams']:

            if block['omitted']:
                continue

            stream_id = block['stream-id']

            # Each interval contributes its rate divided by the number
            # of counted intervals, so the running total ends up as the
            # stream's average rate rather than a sum of rates.
            share = int(block['throughput-bits'] / counted) if counted else 0

            if stream_id in per_stream:
                entry = per_stream[stream_id]
                entry['throughput-bits'] += share
                entry['throughput-bytes'] += int(share / 8)
            else:
                per_stream[stream_id] = make_block(stream_id, 0, interval_count, share)

            total_bits += block['throughput-bits']

    average_bits = int(total_bits / counted) if counted else 0

    overall = {
        'start': 0,
        'end': interval_count,
        'throughput-bits': average_bits,
        'throughput-bytes': int(average_bits / 8)
    }

    return {
        'succeeded': True,
        'diags': f'{diags}\n{stdout}',
        'intervals': intervals,
        'summary': {
            'streams': list(per_stream.values()),
            'summary': overall
        }
    }



    


# Determine whether we are the client or the server for ethr (or both,
# for loopback), run that side and emit its result.
results = {}
try:
    if participant == 0:

        if loopback:

            # Loopback - Run both.

            logger.debug('Running loopback')

            server_thread = pscheduler.ThreadWithReturnValue(target=run_server)
            server_thread.start()
            client_results = run_client()
            server_results = server_thread.join()

            # Guard against the server thread handing back something
            # other than a result dict (e.g., if it ended abnormally),
            # which would otherwise break the lookups below.
            if not isinstance(server_results, dict):
                server_results = {
                    'succeeded': False,
                    'error': 'Server side ended without producing a result.'
                }

            if client_results['succeeded']:

                logger.debug('Loopback client succeeded; using that result.')
                results = client_results

            elif server_results['succeeded']:

                logger.debug('Loopback server succeeded; using that result.')
                results = server_results

            else:

                logger.debug('Nothing succeeded.')
                client_error = client_results.get('error', 'No error.')
                server_error = server_results.get('error', 'No error.')
                results = {
                    'succeeded': False,
                    'error': f'Client:\n\n{client_error}\n\nServer:\n\n{server_error}'
                }

            # Whichever result was chosen, report diagnostics from
            # both sides.
            client_diags = client_results.get('diags', 'No diagnostics.')
            server_diags = server_results.get('diags', 'No diagnostics.')
            results['diags'] = f'Client:\n\n{client_diags}\n\nServer:\n\n{server_diags}'

        else:

            # Non-loopback client

            logger.debug('Running client')
            results = run_client()

    elif participant == 1:

        # Non-loopback server

        logger.debug('Running server')
        results = run_server()

    else:
        pscheduler.fail('Invalid participant.')
except Exception as ex:
    logger.exception()
    # Never emit a result without a 'succeeded' member: if nothing
    # usable was produced before the exception, report a failure.
    if 'succeeded' not in results:
        results = {
            'succeeded': False,
            'error': 'Unexpected failure while running ethr: %s' % ex
        }

logger.debug('Results: %s', results)

pscheduler.succeed_json(results)
