#!/usr/bin/env python3
#
# Run a nuttcp test
#

import datetime
import logging
import json
import pscheduler
import re
import shutil
import sys
import time
import threading
import nuttcp_parser
import traceback
import nuttcp_utils
from nuttcp_defaults import *



# Record when this run starts; run_server() subtracts the time elapsed
# since this point from the scheduled duration to bound its timeout.
start_time = datetime.datetime.now()

logger = pscheduler.Log(prefix='tool-nuttcp', quiet=True)

logger.debug("starting nuttcp tool")

# Parse the JSON run input.
# NOTE: this name shadows the builtin input(); it is kept because
# run_client() and run_server() read it as a module-level global.
input = pscheduler.json_load(exit_on_error=True)

logger.debug("Input is %s" % input)

# Pull out the parts of the run input that the rest of the script needs.
try:
    participant = input['participant']
    participant_data = input['participant-data']
    test_spec = input['test']['spec']
    duration = pscheduler.iso8601_as_timedelta(input['schedule']['duration'])
except KeyError as ex:
    pscheduler.fail("Missing required key in run input: %s" % ex)
except Exception as ex:
    # Catch Exception rather than using a bare except so interpreter-exit
    # signals (SystemExit, KeyboardInterrupt) are never swallowed, and
    # report the error text instead of only the exception class.
    pscheduler.fail("Error parsing run input: %s"
                    % ("%s: %s" % (type(ex).__name__, ex)))

# A single participant is only acceptable for single-ended or loopback
# tests; every other test needs exactly two.
single_ended = test_spec.get('single-ended', False)
loopback = test_spec.get('loopback', False)
participants = len(participant_data)
if not (participants == 2 or (participants == 1 and (single_ended or loopback))):
    pscheduler.fail("nuttcp requires 1 or 2 participants, got %s" % (participants))

# The nuttcp command to execute comes from the tool configuration.
config = nuttcp_utils.get_config()
nuttcp_cmd = config["nuttcp_cmd"]


# Choose the server port and the first data port.  Single-ended and
# loopback tests have no second participant to ask, so they fall back
# to the defaults (single-ended may override the server port in the spec).
if single_ended:
    server_port = test_spec.get('single-ended-port', DEFAULT_SERVER_PORT)
    data_port_start = DEFAULT_DATA_PORT_START
elif loopback:
    server_port = DEFAULT_SERVER_PORT
    data_port_start = DEFAULT_DATA_PORT_START
else:
    # Two-participant test: use what the second participant advertised.
    _server_data = participant_data[1]
    server_port = _server_data['server_port']
    data_port_start = _server_data.get('data_port_start', DEFAULT_DATA_PORT_START)



def run_client():
    """Build the nuttcp client command line from the test spec and run it.

    Reads the module-level ``test_spec``, ``input``, ``nuttcp_cmd``,
    ``server_port`` and ``data_port_start``.  Sleeps until the scheduled
    start time plus ``DEFAULT_WAIT_SLEEP`` seconds, then runs nuttcp
    against the spec's ``dest``.

    Returns:
        The result dictionary produced by ``_make_result()`` with
        parsing enabled.
    """

    diags = []
    nuttcp_args = [nuttcp_cmd]

    if data_port_start is not None:
        nuttcp_args.extend(['-p', data_port_start])

    nuttcp_args.extend(['-P', server_port])

    # Determine if we need to specify a specific ip-version if we have
    # enough info to do so intelligently
    ip_version = test_spec.get('ip-version', None)
    source = test_spec.get('source', None)

    if ip_version is None and source is not None:
        source_ip, dest_ip = pscheduler.ip_normalize_version(source, test_spec['dest'])
        if source_ip is not None and dest_ip is not None:
            ip_version = pscheduler.ip_addr_version(source_ip)[0]

    if ip_version:
        nuttcp_args.append('-%s' % ip_version)

    # duration: whole seconds from the spec, or the tool default
    test_duration = test_spec.get('duration')
    if test_duration:
        delta = pscheduler.iso8601_as_timedelta(test_duration)
        test_duration = int(pscheduler.timedelta_as_seconds(delta))
    else:
        test_duration = DEFAULT_DURATION

    nuttcp_args.extend(['-T', test_duration])

    # big list of optional arguments to nuttcp, map from test spec
    interval = test_spec.get('interval')
    if interval is not None:
        delta = pscheduler.iso8601_as_timedelta(interval)
        nuttcp_args.extend(['-i', int(pscheduler.timedelta_as_seconds(delta))])

    # Options whose spec value is passed through unchanged.
    for key, switch in (('parallel', '-N'),
                        ('window-size', '-w'),
                        ('mss', '-M')):
        value = test_spec.get(key)
        if value is not None:
            nuttcp_args.extend([switch, value])

    # An explicit null bandwidth is treated the same as an absent one,
    # matching how the other nullable options above are handled; without
    # this the division below would raise TypeError on None.
    max_bandwidth = test_spec.get('bandwidth') or 0
    # presumably converts the spec's bits/sec into the unit nuttcp's -R
    # switch expects -- TODO confirm against the nuttcp manual
    nuttcp_bandwidth = max_bandwidth / 1000.0

    # Nuttcp can exceed the bandwidth limit if a burst size is in
    # effect, so don't honor burst size unless there's no bandwidth limit.
    #
    # NOTE(review): with 'bandwidth-strict' present but neither a
    # bandwidth nor a burst size, no rate switch is added at all, and
    # with a bandwidth present 'bandwidth-strict' falls back to plain
    # -R.  Behavior is preserved as found; confirm it is intended.

    if ('bandwidth-strict' in test_spec) or ('burst-size' in test_spec):

        if not max_bandwidth:
            bandwidth_switch = '-Ri' + str(nuttcp_bandwidth)
            if 'burst-size' in test_spec:
                bandwidth_switch += '/{}'.format(test_spec['burst-size'])
                nuttcp_args.append(bandwidth_switch)
        else:
            diags.append("Ignoring burst size in favor of bandwidth limit.")
            nuttcp_args.append('-R{}'.format(str(nuttcp_bandwidth)))

    elif max_bandwidth:

        nuttcp_args.append('-R{}'.format(str(nuttcp_bandwidth)))

    # Test the value, not just key presence, so an explicit
    # "udp": false in the spec does not switch the test to UDP.
    if test_spec.get('udp'):
        nuttcp_args.append('-u')
        nuttcp_args.append('-j')  # add jitter reporting

    buffer_length = test_spec.get('buffer-length')
    if buffer_length is not None:
        nuttcp_args.extend(['-l', buffer_length])

    ip_tos = test_spec.get('ip-tos')
    if ip_tos is not None:
        nuttcp_args.append('-c')
        # The 'T' suffix treats the TOS value as the whole octet
        # instead of just the DSCP.
        nuttcp_args.append("%sT" % (str(ip_tos)))

    if test_spec.get('reverse'):
        nuttcp_args.append('-F')

    if test_spec.get('reverse-connections'):
        nuttcp_args.append('-r')

    # who to connect to
    destination = test_spec['dest']

    # CPU placement: an explicit core list from the spec wins; otherwise
    # try to derive a NUMA node from the source interface or, failing a
    # source, from the route to the destination.  Either way the nuttcp
    # command gets wrapped in numactl when the check passes.
    affinity = test_spec.get('client-cpu-affinity')
    if affinity is not None:
        numa_ok, numa_diags = pscheduler.numa_cpu_core_ok(affinity)
        if numa_ok:
            logger.debug("Selected CPU affinity %s" % affinity)
            nuttcp_args.insert(0, 'numactl')
            nuttcp_args.insert(1, '-C')
            nuttcp_args.insert(2, affinity)
        else:
            logger.debug("NUMA doesn't work for source.  Throwing caution to the wind.")
            diags.append("Unable to use NUMA for this test.  Disabling it.")
    else:
        if source:
            interface = pscheduler.address_interface(source, ip_version=ip_version)
            if interface is not None:
                logger.debug("Affinity (if any) will be based on source")
                affinity = pscheduler.interface_affinity(interface)
        else:
            affinity = pscheduler.source_affinity(destination, ip_version=ip_version)
            logger.debug("Affinity (if any) will be based on destination")

        if affinity is not None:
            numa_ok, numa_diags = pscheduler.numa_node_ok(affinity)
            if numa_ok:
                logger.debug("Selected CPU affinity %s" % affinity)
                nuttcp_args.insert(0, 'numactl')
                nuttcp_args.insert(1, '-N')
                nuttcp_args.insert(2, affinity)
            else:
                logger.debug("NUMA doesn't work for source.  Throwing caution to the wind.")
                diags.append("Unable to use NUMA for this test.  Disabling it.")

    nuttcp_args.append(destination)

    # join and run_program want these all to be string types, so
    # just to be safe cast everything in the list to a string
    nuttcp_args = [str(x) for x in nuttcp_args]

    joined_args = " ".join(nuttcp_args)
    logger.debug("Running command: %s" % (joined_args))
    diags.append(joined_args)

    try:
        start_at = input['schedule']['start']
        logger.debug("Sleeping until %s", start_at)
        pscheduler.sleep_until(start_at)
        logger.debug("Starting")
    except KeyError:
        pscheduler.fail("Unable to find start time in input")

    logger.debug("Waiting %s sec for server on other side to start" % DEFAULT_WAIT_SLEEP)
    time.sleep(DEFAULT_WAIT_SLEEP)  # wait for server to start on other side

    try:
        status, stdout, stderr = pscheduler.run_program(nuttcp_args)
    except Exception as ex:
        logger.error("nuttcp failed to complete execution: %s" % ex)
        # presumably succeed_json() ends the process; the return below
        # relies on that since status/stdout/stderr are unset here
        pscheduler.succeed_json({"succeeded": False,
                                 "diags": "\n".join(diags),
                                 "error": "The nuttcp command failed during execution. See server logs for more details."})

    return _make_result(joined_args, diags, status, stdout, stderr)


def run_server():
    """Build the nuttcp server command line, run it, and return the result.

    Reads the module-level ``test_spec``, ``input``, ``nuttcp_cmd``,
    ``server_port``, ``data_port_start``, ``duration`` and
    ``start_time``.  Sleeps until the scheduled start time, then runs
    nuttcp with a timeout equal to the scheduled duration minus the
    time elapsed since this script started.

    Returns:
        The result dictionary produced by ``_make_result()`` with
        parsing disabled.
    """

    diags = []

    # Base command line for the server side
    command = [nuttcp_cmd, '-S', '-1', '--nofork']

    if data_port_start is not None:
        command += ["-p", data_port_start]

    command += ["-P", server_port]

    # Only force an IP version when the spec names one or when it can be
    # worked out from the source and destination together.
    ip_version = test_spec.get('ip-version', None)
    source = test_spec.get('source', None)
    if ip_version is None and source is not None:
        source_ip, dest_ip = pscheduler.ip_normalize_version(source, test_spec['dest'])
        if source_ip is not None and dest_ip is not None:
            ip_version = pscheduler.ip_addr_version(source_ip)[0]

    if ip_version:
        command.append('-%s' % ip_version)

    # CPU placement: explicit cores from the spec take priority;
    # otherwise fall back to the affinity of the receiving interface.
    requested_cores = test_spec.get('server-cpu-affinity')
    if requested_cores is not None:
        cores_usable, _ = pscheduler.numa_cpu_core_ok(requested_cores)
        if cores_usable:
            logger.debug("Selected CPU affinity %s" % requested_cores)
            # Prefix the command so nuttcp runs under numactl
            command[0:0] = ['numactl', '-C', requested_cores]
        else:
            logger.debug("NUMA doesn't work for destination.  Throwing caution to the wind.")
            diags.append("Unable to use NUMA for this test.  Disabling it.")
    else:
        # Find the interface the destination address lives on
        interface = pscheduler.address_interface(test_spec['dest'], ip_version=ip_version)
        if interface:
            node = pscheduler.interface_affinity(interface)
            logger.debug("CPU affinity for interface is %s" % (node))

            if node is not None:
                node_usable, _ = pscheduler.numa_node_ok(node)
                if node_usable:
                    logger.debug("Selected CPU affinity %s" % node)
                    command[0:0] = ['numactl', '-N', node]
                else:
                    logger.debug("NUMA doesn't work for %s.  Throwing caution to the wind." % (interface))
                    diags.append("Unable to use NUMA for %s.  Disabling it." % (interface))

    # Everything handed to run_program is stringified first
    command = [str(arg) for arg in command]
    command_text = " ".join(command)
    logger.debug("Running command: %s" % (command_text))
    diags.append(command_text)

    try:
        start_at = input['schedule']['start']
        logger.debug("Sleeping until %s", start_at)
        pscheduler.sleep_until(start_at)
        logger.debug("Starting")
    except KeyError:
        pscheduler.fail("Unable to find start time in input")

    try:
        # Time left in the scheduled run, measured from script start
        elapsed = datetime.datetime.now() - start_time
        remaining = pscheduler.timedelta_as_seconds(duration - elapsed)
        status, stdout, stderr = pscheduler.run_program(command,
                                                        timeout=remaining,
                                                        timeout_ok=True)
    except Exception as ex:
        logger.error("nuttcp failed to complete execution: %s" % ex)
        failure = {
            "succeeded": False,
            "diags": "\n".join(diags),
            "error": "The nuttcp command failed during execution. See server logs for more details.",
        }
        pscheduler.succeed_json(failure)

    # in nuttcp, the client always reports the results even in reverse
    # mode so we never need to parse the server's output
    return _make_result(command_text, diags, status, stdout, stderr, parse=False)



def _make_result(command_line, diags, status, stdout, stderr, parse=True):
    """Turn the outcome of a nuttcp invocation into a result dictionary.

    Args:
        command_line: The command that was run, as a single string.
        diags: List of diagnostic strings gathered while building the
            command.
        status: Exit status of the program; any truthy value is
            reported as a failure through ``pscheduler.succeed_json()``.
        stdout: Captured standard output text.
        stderr: Captured standard error text.
        parse: When true, feed the output lines to
            ``nuttcp_parser.parse_output()``; otherwise start from a
            bare ``{"succeeded": True}``.

    Returns:
        The result dictionary with a ``diags`` entry holding the
        diagnostics, the command line and both output streams.
    """
    logger.debug("Stdout = %s" % stdout)
    logger.debug("Stderr = %s" % stderr)

    if status:
        failure = {
            "succeeded": False,
            "diags": command_line,
            "error": "nuttcp returned an error: %s\n%s" % (stdout, stderr),
        }
        pscheduler.succeed_json(failure)

    lines = stdout.split("\n")
    logger.debug("Lines are %s " % lines)

    results = nuttcp_parser.parse_output(lines) if parse else {"succeeded": True}

    # Diagnostics, when there are any, go ahead of the command line
    if diags:
        diags_text = "%s\n\n" % ("\n".join(diags))
    else:
        diags_text = ""

    results['diags'] = "%s%s\n\n%s\n%s" % (diags_text, command_line, stdout, stderr)

    return results


# Determine whether we are the client or server for nuttcp and run
# that role.  In loopback mode participant 0 plays both: the server
# runs in a background thread while the client runs in this one.
results = {}
try:
    if participant == 0:
        if loopback:
            server_thread = threading.Thread(target=run_server)
            server_thread.start()
            results = run_client()
            server_thread.join()  # Wait until the server thread terminates
        else:
            results = run_client()
    elif participant == 1:
        results = run_server()
    else:
        pscheduler.fail("Invalid participant.")
except Exception as ex:
    tb_lines = [line.rstrip('\n') for line in
                traceback.format_exception(ex.__class__, ex, ex.__traceback__)]
    logger.debug(tb_lines)
    logger.error("Exception %s" % ex)
    # Report an explicit failure instead of emitting the empty
    # placeholder, which carries no "succeeded" flag at all.  The shape
    # matches the other failure results produced by this tool.
    results = {
        "succeeded": False,
        "diags": "\n".join(tb_lines),
        "error": "The nuttcp tool failed with an unexpected error: %s" % ex,
    }

logger.debug("Results: %s" % results)

pscheduler.succeed_json(results)
