#!/usr/bin/env python3
#
# Run an iperf3 test
#

import datetime
import logging
import json
import os
import pscheduler
import re
import shutil
import stat
import sys
import time
import threading
import iperf3_parser
import packaging.version
import tempfile
import traceback
import iperf3_utils
from iperf3_defaults import *

# Note when this run began.
start_time = datetime.datetime.now()

logger = pscheduler.Log(prefix='tool-iperf3', quiet=True)
logger.debug("starting iperf3 tool")

# Load the run input as JSON.  The name shadows the builtin, but
# run_client() and run_server() read it, so it has to stay.
input = pscheduler.json_load(exit_on_error=True)

logger.debug("Input is %s" % input)

# Pull out the parts of the input everything else depends on.
try:
    participant = input['participant']
    participant_data = input['participant-data']
    test_spec = input['test']['spec']
except KeyError as ex:
    pscheduler.fail("Missing required key in run input: %s" % ex)
except Exception as ex:
    pscheduler.fail("Error parsing run input: %s" % type(ex))

single_ended = test_spec.get('single-ended', False)
loopback = test_spec.get('loopback', False)
participants = len(participant_data)

# Two participants is the normal case; one is acceptable only for
# single-ended or loopback tests.
participant_count_ok = (
    participants == 2
    or (participants == 1 and (single_ended or loopback))
)
if not participant_count_ok:
    pscheduler.fail("iperf3 requires exactly 2 participants, got %s" % (participants))

config = iperf3_utils.get_config()

# Path to the local iperf3 program.
iperf3_cmd = config["iperf3_cmd"]

# Work out which port the server side uses.
if single_ended:
    server_port = test_spec.get('single-ended-port', DEFAULT_SERVER_PORT)
elif loopback:
    server_port = DEFAULT_SERVER_PORT
else:
    server_port = participant_data[1]['server_port']

# The omit time pads the run timeout; convert it from ISO 8601 to
# whole seconds when one was given.
omit_secs = test_spec.get('omit', 0)
if omit_secs:
    omit_delta = pscheduler.iso8601_as_timedelta(omit_secs)
    omit_secs = pscheduler.timedelta_as_seconds(omit_delta)
omit_secs = int(omit_secs)

# Test duration in whole seconds, falling back to the default when
# the spec doesn't give one.
test_duration = test_spec.get('duration')
if not test_duration:
    test_duration = DEFAULT_DURATION
else:
    test_duration = int(
        pscheduler.timedelta_as_seconds(
            pscheduler.iso8601_as_timedelta(test_duration)
        )
    )

# Whether the test runs in the reverse direction (adds -R on the client).
reverse = test_spec.get('reverse', False)

# Arguments that lead every iperf3 invocation built below.
iperf3_first_args = [iperf3_cmd]


if participants <= 1:

    # Where there's one participant, we don't care.  Loopback will
    # work correctly because it's the same binary on both ends.  For
    # single-ended, opt for whatever the default behavior is because
    # we can't tell what's on the other end.  This may cause those
    # tests to break.

    logger.debug(f'Only one participant; running in iperf3 default mode')

else:

    # iperf3 3.17 broke compatibility with 3.16 to fix a security
    # vulnerability but has a switch to allow interoperability with
    # older versions.

    compat_threshold = packaging.version.Version('3.17')

    other_participant = 0 if participant != 0 else 1

    # Any pScheduler that doesn't provide iperf3-version in its
    # participant data is presumed to be using an iperf3 earlier than
    # 3.17 that we'll call 0.0.

    my_version = packaging.version.Version(
        participant_data[participant].get('iperf3-version', '0.0'))
    other_version = packaging.version.Version(
        participant_data[other_participant].get('iperf3-version', '0.0'))
    logger.debug(f'Local iperf3 is {my_version}, other is {other_version}')

    # If the other end is believed-vulnerable and the local iperf3 has
    # the compatibility switch, use it.

    other_is_old = other_version < compat_threshold
    if other_is_old and my_version >= compat_threshold:
        logger.debug(f'Other side is using an old iperf3 ({other_version}); running in vulnerable mode')
        iperf3_first_args.append('--use-pkcs1-padding')


def run_client():
    """Run the iperf3 client side of the test.

    Builds the client command line from the module-level ``test_spec``
    (destination, bind address, duration and the optional tuning
    switches), optionally prefixes it with ``numactl``, sets up client
    authentication when the server participant's data carries
    ``_auth``, sleeps until the scheduled start time plus
    ``DEFAULT_WAIT_SLEEP`` seconds and then runs the program.

    Returns:
        dict: The value of ``_make_result()`` for the finished run, or
        a dict with ``"succeeded": False`` and an ``"error"`` string
        when no common IP version could be found or the program could
        not be run.

    Calls ``pscheduler.fail()`` when the auth files cannot be written
    or the input has no scheduled start time.
    """

    diags = []

    # Work on a copy so the shared leading arguments stay untouched.
    iperf3_args = iperf3_first_args.copy()

    iperf3_args.extend(['-p', server_port])

    # Determine if we need to bind to an address and have enough info
    # to do so intelligently.
    ip_version = test_spec.get('ip-version', None)
    source = test_spec.get('source', None)
    destination = test_spec['dest']  # required
    normalized_dest = destination
    local_address = test_spec.get('local-address', None)

    if ip_version is not None:
        # Force the IP version that was asked for.
        iperf3_args.append('-4' if ip_version == 4 else '-6')
        # Use whatever was provided as the bind, local-address first.
        bind_address = local_address if local_address is not None else source
        if bind_address is not None:
            iperf3_args.extend(['-B', bind_address])
        iperf3_args.extend(['-c', destination])
    elif local_address is not None or source is not None:
        # With a bind candidate and the dest we can determine what IP
        # version they have in common.  local-address takes precedence
        # over source.
        if local_address is not None:
            bind_label, bind_candidate = 'local-address', local_address
        else:
            bind_label, bind_candidate = 'source', source
        bind_ip, dest_ip = pscheduler.ip_normalize_version(bind_candidate, destination)
        if dest_ip is None:
            return {"succeeded": False,
                    "error": "Unable to find common IP version between %s %s and dest %s"
                             % (bind_label, bind_candidate, destination)}
        if bind_ip is not None:
            iperf3_args.extend(['-B', bind_ip])
        iperf3_args.extend(['-c', dest_ip])
        normalized_dest = dest_ip
        ip_version = pscheduler.ip_addr_version(dest_ip)[0]
    else:
        # Nothing to bind to; just set the destination.
        iperf3_args.extend(['-c', destination])

    # Set duration
    iperf3_args.extend(['-t', test_duration])

    # Always ask for JSON output, a lot more consistent to parse
    iperf3_args.append("--json")

    def add_option(flag, key):
        """Append FLAG and the spec's value for KEY if it has one."""
        value = test_spec.get(key)
        if value is not None:
            iperf3_args.extend([flag, value])

    # Big list of optional arguments to iperf3, mapped from the test
    # spec.  The order is kept stable because the command line is
    # reported in the diagnostics.
    interval = test_spec.get('interval')
    if interval is not None:
        delta = pscheduler.iso8601_as_timedelta(interval)
        iperf3_args.extend(['-i', int(pscheduler.timedelta_as_seconds(delta))])

    add_option('-P', 'parallel')
    add_option('-w', 'window-size')
    add_option('-M', 'mss')

    if omit_secs:
        iperf3_args.extend(['-O', omit_secs])

    add_option('-b', 'bandwidth')
    add_option('--fq-rate', 'fq-rate')

    if test_spec.get('udp', False):
        iperf3_args.append('-u')

    add_option('-l', 'buffer-length')
    add_option('-S', 'ip-tos')
    add_option('-C', 'congestion')

    # Only enable zero-copy when the spec turns it on; the key merely
    # being present (e.g. set to false) must not add the switch.
    if test_spec.get('zero-copy', False):
        iperf3_args.append('-Z')

    add_option('-L', 'flow-label')

    if reverse:
        iperf3_args.append('-R')

    # Figure out CPU affinity, either use what was passed in or try to
    # auto detect it.
    affinity = test_spec.get('client-cpu-affinity')
    if affinity is not None:
        numa_ok, _numa_diags = pscheduler.numa_cpu_core_ok(affinity)
        if numa_ok:
            logger.debug("Client: Selected CPU affinity %s" % affinity)
            iperf3_args[0:0] = ['numactl', '-C', affinity]
        else:
            logger.debug("Client: NUMA doesn't work for source.  Throwing caution to the wind.")
            diags.append("Unable to use NUMA for this test.  Disabling it.")
    else:
        if source:
            interface = pscheduler.address_interface(source, ip_version=ip_version)
            if interface is not None:
                logger.debug("Client: Affinity (if any) will be based on source")
                affinity = pscheduler.interface_affinity(interface)
        else:
            affinity = pscheduler.source_affinity(normalized_dest, ip_version=ip_version)
            logger.debug("Client: Affinity (if any) will be based on destination")

        if affinity is not None:
            numa_ok, _numa_diags = pscheduler.numa_node_ok(affinity)
            if numa_ok:
                logger.debug("Client: Selected CPU affinity %s" % affinity)
                iperf3_args[0:0] = ['numactl', '-N', affinity]
            else:
                logger.debug("Client: NUMA doesn't work for source.  Throwing caution to the wind.")
                diags.append("Unable to use NUMA for this test.  Disabling it.")

    # Set up for authorization if that's present.  The context manager
    # removes the directory (and the key written into it) on every way
    # out of this block, including early returns and exceptions.
    with tempfile.TemporaryDirectory() as auth_path:

        os.chmod(auth_path, stat.S_IRWXU)

        auth_env = {}

        if not (single_ended or loopback) and (participant_data[1].get("_auth") is not None):

            try:

                auth = participant_data[1]["_auth"]

                # TODO: There should be a better way to do this.
                auth_env["IPERF3_PASSWORD"] = auth["password"]

                public_key_path = "{}/public-key".format(auth_path)
                with open(public_key_path, "w") as writer:
                    writer.write(auth["public"])
                iperf3_args.extend(["--rsa-public-key-path", public_key_path])

                iperf3_args.extend(["--username", auth["username"]])

            except Exception as ex:

                pscheduler.fail("Unable to write client auth files: %s" % (str(ex)))

        # join and run_program want these all to be string types, so
        # just to be safe cast everything in the list to a string
        command = [str(x) for x in iperf3_args]

        command_line = " ".join(command)
        logger.debug("Client: Running command: %s" % command_line)

        # DEFAULT_WAIT_SLEEP is not part of the client's timeout: that
        # wait happens below, before the program is started.
        iperf_timeout = test_duration
        iperf_timeout += omit_secs
        iperf_timeout += iperf3_utils.setup_time(test_spec.get("link-rtt"))

        logger.debug("Client: timeout for client is %d" % iperf_timeout)
        diags.append(command_line)

        try:
            start_at = input['schedule']['start']
            logger.debug("Client: Sleeping until %s", start_at)
            pscheduler.sleep_until(start_at)
            logger.debug("Client: Starting")
        except KeyError:
            pscheduler.fail("Unable to find start time in input")

        # Give the server on the other side time to start.
        logger.debug("Client: Waiting %s sec for server on other side to start" % DEFAULT_WAIT_SLEEP)
        time.sleep(DEFAULT_WAIT_SLEEP)

        try:
            status, stdout, stderr = pscheduler.run_program(command,
                                                            env_add=auth_env,
                                                            timeout=iperf_timeout)
        except Exception as ex:
            logger.error("iperf3 failed to complete execution: %s" % ex)
            return {"succeeded": False,
                    "diags": "\n".join(diags),
                    "error": "The iperf3 command failed during execution. See server logs for more details."}

        return _make_result("\n".join(diags), status, stdout, stderr)

    

def run_server():
    """Run the iperf3 server side of the test.

    Builds a one-shot (``-s -1``) iperf3 server command line from the
    module-level ``test_spec``, optionally binds it to the destination
    address and prefixes it with ``numactl``, sets up server-side
    authentication when the participant data calls for it, sleeps
    until the scheduled start time and then runs the program.

    Returns:
        dict: The value of ``_make_result()`` for the finished run, or
        a dict with ``"succeeded": False`` and an ``"error"`` string
        when the program could not be run.

    Calls ``pscheduler.fail()`` when the auth files cannot be written
    or the input has no scheduled start time.
    """

    diags = []

    # Work on a copy so the shared leading arguments stay untouched.
    iperf3_args = iperf3_first_args.copy()

    iperf3_args.extend([
        '-s',
        '-1',
        '--idle-timeout', str(test_duration),
        '--json',
        '-p', server_port,
    ])

    # Determine if we need to bind to an address and have enough info
    # to do so intelligently.
    ip_version = test_spec.get('ip-version', None)
    source = test_spec.get('source', None)
    normalized_dest = test_spec['dest']
    if ip_version is not None:
        # If the IP version is given we just use the dest as the bind
        # address and have iperf3 force IPv4 or IPv6.
        iperf3_args.extend(['-B', test_spec['dest']])
        iperf3_args.append('-4' if ip_version == 4 else '-6')
    elif (source is not None) and (not loopback):
        # With the source and dest we can determine what IP version
        # they have in common.  If there is none we don't bind at all
        # but don't throw an error in case there is some external
        # factor we don't know about (maybe should change this?)
        _source_ip, bind_address = pscheduler.ip_normalize_version(source, test_spec['dest'])
        if bind_address is not None:
            iperf3_args.extend(['-B', bind_address])
            normalized_dest = bind_address
            ip_version = pscheduler.ip_addr_version(bind_address)[0]
        else:
            logger.debug("Server: Skipping bind. Source %s and destination %s don't have any common IPs of the same version as far as tool can tell." % (source, test_spec['dest']))
    else:
        # No usable source and no ip-version, so we can't determine
        # what to bind to.  Don't specify and let iperf3 decide.
        logger.debug("Server: Skipping bind. No source or ip-version given so can't make an intelligent decision about where to bind. iperf3 will listen on all interfaces")

    # Try to grab our default affinity if one wasn't passed in
    affinity = test_spec.get('server-cpu-affinity')
    if affinity is not None:
        numa_ok, _numa_diags = pscheduler.numa_cpu_core_ok(affinity)
        if numa_ok:
            logger.debug("Server: Selected CPU affinity %s" % affinity)
            iperf3_args[0:0] = ['numactl', '-C', affinity]
        else:
            logger.debug("Server: NUMA doesn't work for destination.  Throwing caution to the wind.")
            diags.append("Unable to use NUMA for this test.  Disabling it.")
    else:
        # Look up what interface we're going to be receiving on
        interface = pscheduler.address_interface(normalized_dest, ip_version=ip_version)
        if interface:
            affinity = pscheduler.interface_affinity(interface)
            logger.debug("Server: CPU affinity for interface is %s" % (affinity))

            if affinity is not None:
                numa_ok, _numa_diags = pscheduler.numa_node_ok(affinity)
                if numa_ok:
                    logger.debug("Server: Selected CPU affinity %s" % affinity)
                    iperf3_args[0:0] = ['numactl', '-N', affinity]
                else:
                    logger.debug("Server: NUMA doesn't work for %s.  Throwing caution to the wind." % (interface))
                    diags.append("Unable to use NUMA for %s.  Disabling it." % (interface))

    # Set up for authentication if that's present.  The context
    # manager removes the directory (and the credentials and private
    # key written into it) on every way out of this block, including
    # early returns and exceptions.
    with tempfile.TemporaryDirectory() as auth_path:

        os.chmod(auth_path, stat.S_IRWXU)

        # Run auth if the lead participant is and there's auth data
        if not (single_ended or loopback) and (participant_data[1].get("_auth") is not None and participant_data[0].get("schema", 1) >= 2):

            try:

                auth = participant_data[1]["_auth"]

                credentials_path = "{}/credentials".format(auth_path)
                with open(credentials_path, "w") as writer:
                    writer.write(auth["credentials"])
                iperf3_args.extend(["--authorized-users-path", credentials_path])

                private_key_path = "{}/private-key".format(auth_path)
                with open(private_key_path, "w") as writer:
                    writer.write(auth["private"])
                iperf3_args.extend(["--rsa-private-key-path", private_key_path])

            except Exception as ex:

                pscheduler.fail("Unable to write auth files: %s" % (str(ex)))

        # run_program wants these all to be string types
        command = [str(x) for x in iperf3_args]
        command_line = " ".join(command)
        logger.debug("Server: Running command: %s" % command_line)

        # The server starts at the scheduled time but the client waits
        # DEFAULT_WAIT_SLEEP first, so pad the server's timeout by it.
        iperf_timeout = test_duration
        iperf_timeout += omit_secs
        iperf_timeout += iperf3_utils.setup_time(test_spec.get("link-rtt"))
        iperf_timeout += DEFAULT_WAIT_SLEEP

        logger.debug("Server: Timeout for server is %d" % iperf_timeout)
        diags.append(command_line)

        try:
            start_at = input['schedule']['start']
            logger.debug("Server: Sleeping until %s", start_at)
            pscheduler.sleep_until(start_at)
            logger.debug("Server: Starting")
        except KeyError:
            pscheduler.fail("Unable to find start time in input")

        # Report a failure to run the same way the client does rather
        # than letting the exception escape.
        try:
            status, stdout, stderr = pscheduler.run_program(command,
                                                            timeout=iperf_timeout)
        except Exception as ex:
            logger.error("iperf3 failed to complete execution: %s" % ex)
            return {"succeeded": False,
                    "diags": "\n".join(diags),
                    "error": "The iperf3 command failed during execution. See server logs for more details."}

    # There's a bug in the server side that dumps a diagnostic into
    # stdout.  Correct for that by dropping anything before the first
    # opening brace.
    # See https://github.com/esnet/iperf/issues/1086
    # TODO: Remove when the bug is fixed.

    pos = stdout.find("{")
    if pos > 0:
        stdout = stdout[pos:]

    return _make_result("\n".join(diags), status, stdout, stderr)

    

def _make_result(diags, status, stdout, stderr):
    """Turn the outcome of an iperf3 run into a result dictionary.

    Args:
        diags: Diagnostic text gathered while building the command.
        status: Exit status of the program; any true value is treated
            as a failure.
        stdout: The program's standard output.
        stderr: The program's standard error.

    Returns:
        dict: On success, whatever ``iperf3_parser.parse_output()``
        produced with ``diags`` set to the diagnostics followed by the
        stripped output.  On failure, a dict with ``"succeeded":
        False``, the diagnostics and an error message.
    """

    # Remove anything sensitive before the output goes anywhere else.
    cleaned = iperf3_utils.strip_output(stdout)

    if not status:
        parsed = iperf3_parser.parse_output(cleaned.split("\n"))
        parsed['diags'] = "%s\n\n%s" % (diags, cleaned)
        return parsed

    # The program failed.  Prefer the error in its JSON output and
    # fall back to the raw output when there isn't a usable one.
    try:
        reason = pscheduler.json_load(cleaned)["error"]
    except (ValueError, KeyError):
        reason = ""

    if not reason:
        reason = "%s\n\n%s\n" % (cleaned, stderr)

    return {"succeeded": False,
            "diags": diags,
            "error": "iperf3 returned an error: %s" % (reason)}


# Determine whether we are in client or server mode for iperf3, run
# the matching side(s) and emit the result.
results = {}
try:
    if participant == 0:

        if loopback:

            # Loopback - Run both.

            logger.debug("Running loopback")

            server_thread = pscheduler.ThreadWithReturnValue(target=run_server)
            server_thread.start()
            client_results = run_client()
            server_results = server_thread.join()

            # Guard against the server thread handing back nothing
            # usable (presumably what happens if it died on an
            # exception) so a good client result isn't thrown away.
            if not isinstance(server_results, dict):
                server_results = {
                    "succeeded": False,
                    "error": "Server produced no result."
                }

            if client_results["succeeded"]:

                logger.debug("Loopback client succeeded; using that result.")
                results = client_results

            elif server_results["succeeded"]:

                logger.debug("Loopback server succeeded; using that result.")
                results = server_results

            else:

                logger.debug("Nothing succeeded.")
                results = {
                    "succeeded": False,
                    "error": f'Client:\n\n{client_results.get("error", "No error.")}' \
                             f'\n\nServer:\n\n{server_results.get("error", "No error.")}'
                }

            # Whichever result was picked, report both sides' diagnostics.
            results["diags"] = f'Client:\n\n{client_results.get("diags", "No diagnostics.")}' \
                    f'\n\nServer:\n\n{server_results.get("diags", "No diagnostics.")}'

        else:

            # Non-loopback client

            logger.debug("Running client")
            results = run_client()

    elif participant == 1:

        # Non-loopback server

        logger.debug("Running server")
        results = run_server()

    else:
        pscheduler.fail("Invalid participant.")
except Exception as ex:
    logger.exception()
    # Report the failure explicitly instead of emitting whatever was
    # in results (usually an empty object with no "succeeded" key).
    results = {
        "succeeded": False,
        "error": "Exception while running iperf3: %s" % (ex)
    }

logger.debug("Results: %s", results)

pscheduler.succeed_json(results)
