#!/usr/bin/env python3

#
# Development Order #5:
#
# This is the core of the tool, where the actual desired commands
# or operations are run. The results are then recorded and added
# to the 'results' JSON data, which is then sent back to the test.
# Both system and api are able to be used here.
#

import datetime
import logging
import json
import pscheduler
import re
import shutil
import sys
import time
import traceback
import threading
from ethr_defaults import *

# Note the moment at which this run began.
start_time = datetime.datetime.now()

logger = pscheduler.Log(prefix='tool-ethr', quiet=True)

# Parse the JSON input describing this run.
input = pscheduler.json_load(exit_on_error=True)

logger.debug("Input is %s" % input)

# Pull out the pieces of the input that the rest of the script reads:
# which participant we are, the per-participant data and the test spec.
try:
    participant, participant_data, test_spec = (
        input['participant'],
        input['participant-data'],
        input['test']['spec'],
    )
except KeyError as e:
    pscheduler.fail("Missing required key in run input: %s" % e)
except Exception:
    # Anything other than a missing key: log the exception, then fail
    # with the exception type in the message.
    logger.exception()
    pscheduler.fail("Error parsing run input: %s" % sys.exc_info()[0])

# Transport protocol: UDP only when the spec asks for it, else TCP.
if test_spec.get("udp", False):
    protocol = "udp"
else:
    protocol = "tcp"

# In loopback mode the default port is used; otherwise the port comes
# from the second participant's data.
loopback = test_spec.get('loopback', False)
server_port = (
    DEFAULT_SERVER_PORT if loopback else participant_data[1]['server_port']
)

# Number of parallel streams; a single stream when the spec has none.
# TODO: Enable this or get rid of it.
#parallel = min(
#    participant_data[0].get("cores", 1),
#    participant_data[1].get("cores", 1)
#)
if "parallel" in test_spec:
    parallel = int(test_spec["parallel"])
else:
    parallel = 1

# Time to omit from the start of the test, as whole seconds.
omit_delta = pscheduler.iso8601_as_timedelta(test_spec.get("omit", "P0D"))
omit = int(pscheduler.timedelta_as_seconds(omit_delta))


# Test duration: converted from ISO 8601 to whole seconds when the
# spec provides one, otherwise the default.
test_duration = test_spec.get('duration')
if not test_duration:
    test_duration = DEFAULT_DURATION
else:
    delta = pscheduler.iso8601_as_timedelta(test_duration)
    test_duration = int(pscheduler.timedelta_as_seconds(delta))

# The omitted time is added on top of the requested duration.
test_duration += omit

# if we're doing a reverse test, have to change a few things
reverse = test_spec.get('reverse', False)


def run_client():
    """Run the Ethr client and turn its output into a result.

    Builds an ``ethr0`` client command line from the module-level test
    parameters (protocol, port, duration, parallel streams, buffer
    length, IP version, reverse), sleeps until the scheduled start
    time plus a short delay for the server, runs the program and then
    parses the matching lines of its standard output into per-interval
    blocks and an averaged summary.

    Returns:
        A dict with "succeeded" (always True on this path), "diags"
        (the command line, exit status, stdout and stderr joined with
        newlines), "intervals" and "summary".

    If the schedule has no start time, pscheduler.fail() is called.
    If running the program raises, the error is logged and a result
    with "succeeded": False is emitted through
    pscheduler.succeed_json().
    """

    # Seconds to wait after the scheduled start before launching the
    # client, so the server on the other side has time to start.
    # TODO: The time for this should be DEFAULT_WAIT_SLEEP
    server_wait = 3

    diags = [
        'Plugin is using Ethr 0.x',
        ''
    ]

    ethr_args = [

        "ethr0",

        # No logs
        "-no",

        # Client
        "-c", test_spec['dest'],

        # Port
        "-ports", "%s=%d" % (protocol, server_port),

        # Duration
        "-d", "%ss" % (test_duration),

        # Test type (bandwidth)
        "-t", "b",

        # Protocol
        "-p", protocol
    ]

    # Parallel threads
    if parallel > 1:
        ethr_args.append('-n')
        ethr_args.append(parallel)

    # Buffer length
    if test_spec.get('buffer-length') is not None:
        ethr_args.append('-l')
        ethr_args.append(test_spec['buffer-length'])

    # IP version: anything other than 4 is treated as 6.
    ip_version = test_spec.get('ip-version', None)
    if ip_version is not None:
        if ip_version == 4:
            ethr_args.append('-4')
        else:
            ethr_args.append('-6')

    # Reverse
    if reverse:
        ethr_args.append('-r')

    # join and run_program want these all to be string types, so
    # just to be safe cast everything in the list to a string
    ethr_args = [str(x) for x in ethr_args]

    command_line = " ".join(ethr_args)

    ethr_timeout = test_duration + 2
    logger.debug("timeout for client is %d" % ethr_timeout)

    try:
        start_at = input['schedule']['start']
        logger.debug("Sleeping until %s", start_at)
        pscheduler.sleep_until(start_at)
        logger.debug("Starting")
    except KeyError:
        pscheduler.fail("Unable to find start time in input")

    logger.debug("Waiting %s sec for server on other side to start" % server_wait)
    time.sleep(server_wait)

    logger.debug("Running command: %s" % command_line)
    diags.append(command_line)

    try:
        status, stdout, stderr = pscheduler.run_program(ethr_args,
                                                        timeout = ethr_timeout)
    except Exception as e:
        logger.exception()
        logger.error("ethr failed to complete execution: %s" % e)
        pscheduler.succeed_json({"succeeded": False,
                                 "diags": "\n".join(diags),
                                 "error": "The ethr command failed during execution. See server logs for more details."})

    # NOTE(review): the exit status is recorded in the diagnostics but
    # is not otherwise acted on; a non-zero exit still produces a
    # "succeeded": True result.  Confirm that this is intended.
    diags.append("\nExited: %s\n" % (status))
    diags.append("\nStandard output:\n")
    diags.append(stdout)
    diags.append("\nStandard error:\n")
    diags.append(stderr)


    # Make the result and return it

    # Matches one report line.  Captured groups are:
    #   1 - the text inside the leading square brackets (a stream
    #       identifier or "SUM")
    #   2 - the protocol, TCP or UDP
    #   3 - the interval start
    #   4 - the interval end
    #   5 - the last token on the line, the rate with an SI suffix
    # The token between the interval and the rate is skipped
    # (presumably the time unit).
    matcher = re.compile(r'^\[([^]]+)\]\s+(TCP|UDP)\s+([^\s]+)-([^\s]+)\s+[^\s]+\s*([^\s]+)$')


    def group_block(stream, start, end, bps):
        """Build one result block for a stream (or "SUM") over an interval."""

        result = {
            "start": int(start),
            "end": int(end),
            # Intervals ending within the omit period are flagged.
            "omitted": end <= omit,
            "throughput-bits": bps,
            "throughput-bytes": int(bps / 8)
        }

        # The aggregate line does not get a stream ID.
        if stream != "SUM":
            result["stream-id"] = int(stream)

        return result


    intervals = []
    streams = []

    # Split up the lines and look for things that are interesting.
    # Note that the JSON output doesn't provide enough, so we're stuck
    # picking through the human-readable output.

    for line in stdout.split("\n"):

        match = matcher.match(line)

        if match:
            stream = match.group(1)
            start = int(match.group(3))
            end = int(match.group(4))
            bps = int(pscheduler.si_as_number(match.group(5)))

            block = group_block(stream, start, end, bps)

            # An interval is closed by its SUM line or, when there is
            # only one stream, by every matching line.
            if stream == "SUM" or parallel == 1:

                if parallel == 1:
                    streams.append(block)

                intervals.append({
                    "streams": streams,
                    "summary": block
                })

                # Reset for next time.
                streams = []
            else:
                streams.append(block)


    # Make all of the lines into interval blocks and summaries

    stream_summary = {}
    full_summary = group_block("SUM", 0, 0, 0)
    end = len(intervals)
    # Number of intervals that count toward the averages.  Clamped at
    # zero so that an omit period longer than the parsed output can
    # never produce a negative divisor.
    used_intervals = max(end - omit, 0)
    full_summary["end"] = end
    del full_summary["omitted"]

    full_bits = 0

    for interval in intervals:

        for stream in interval["streams"]:

            if stream["omitted"]:
                continue

            stream_id = stream["stream-id"]


            # Note that what we add here is a fraction of the number
            # of intervals so what we get is the average rate for the
            # stream rather than the full sum of how much data was
            # pushed.

            bits = int(stream["throughput-bits"] / used_intervals) if used_intervals else 0

            try:
                stream_summary[stream_id]["throughput-bits"] += bits
                stream_summary[stream_id]["throughput-bytes"] += int(bits / 8)
            except KeyError:
                # First time this stream has been seen.
                stream_summary[stream_id] = group_block(stream_id, 0, end, bits)

            full_bits += stream["throughput-bits"]

    full_bits = int(full_bits / used_intervals) if used_intervals else 0

    full_summary["throughput-bits"] = full_bits
    full_summary["throughput-bytes"] = int(full_bits / 8)

    return {
        "succeeded": True,
        "diags": "\n".join(diags),
        "intervals": intervals,
        "summary": {
            "streams": list(stream_summary.values()),
            "summary": full_summary
        }
    }

    

def run_server():
    """Run the Ethr server side of the test and return its result.

    Builds an ``ethr0`` server command line from the module-level
    protocol, port and IP version, sleeps until the scheduled start
    time and runs the program with a timeout that is treated as
    acceptable (timeout_ok=True).

    Returns:
        A dict with "succeeded" set to True and "diags" holding the
        command line, exit status, standard output and, for a
        non-zero exit, any standard error.

    If the schedule has no start time, pscheduler.fail() is called.
    If running the program raises, the error is logged and a result
    with "succeeded": False is emitted through
    pscheduler.succeed_json().
    """

    # Server mode on the agreed port for the chosen protocol.
    command = ["ethr0", "-s", "-ports", "%s=%d" % (protocol, server_port)]

    # IP version, when the spec provides one: anything other than 4
    # is treated as 6.
    family = test_spec.get('ip-version', None)
    if family is not None:
        command.append('-4' if family == 4 else '-6')

    # join and run_program want strings, so convert every argument.
    command = list(map(str, command))

    command_line = " ".join(command)
    logger.debug("Running command: %s" % command_line)
    diags = [command_line]

    # TODO: The time (3) for this should be DEFAULT_WAIT_SLEEP
    ethr_timeout = test_duration + 3 + 1
    logger.debug("timeout for server is %d" % ethr_timeout)

    try:
        start_at = input['schedule']['start']
        logger.debug("Sleeping until %s", start_at)
        pscheduler.sleep_until(start_at)
        logger.debug("Starting")
    except KeyError:
        pscheduler.fail("Unable to find start time in input")

    try:
        status, stdout, stderr = pscheduler.run_program(
            command,
            timeout=ethr_timeout,
            timeout_ok=True
        )
    except Exception as e:
        logger.exception()
        logger.error("Ethr server failed to complete execution: %s" % e)
        pscheduler.succeed_json({"succeeded": False,
                                 "diags": "\n".join(diags),
                                 "error": "The ethr command failed during execution. See server logs for more details."})

    diags.extend([
        "\nExited: %s\n" % (status),
        "\nStandard output:\n",
        stdout,
    ])

    # Don't show complaints if we exited zero, because it was a timeout.
    if status != 0 and stderr:
        diags.extend(["\nStandard error:\n", stderr])

    return {"succeeded": True, "diags": "\n".join(diags)}




# Dispatch on which participant this is: participant 0 runs the
# client (plus the server in a thread when in loopback mode) and
# participant 1 runs the server.
results = {}

if participant == 0:
    if not loopback:
        logger.debug("Running client")
        results = run_client()
    else:
        logger.debug("Running client and server")
        server_thread = threading.Thread(target=run_server)
        server_thread.start()
        results = run_client()
        # Wait until the server thread terminates.  Only the client's
        # result is kept; the server's return value is not collected.
        server_thread.join()
elif participant == 1:
    logger.debug("Running server")
    results = run_server()
else:
    pscheduler.fail("Invalid participant.")


logger.debug("Results: %s" % results)

pscheduler.succeed_json(results)
