#!/usr/bin/env python3
#
# Run a test.  Just the test spec is provided on stdin.
#

from pysnmp.hlapi import *
import pysnmp
import datetime
import json
import re
import sys
import time
import pscheduler

log = pscheduler.Log(prefix="tool-pysnmp", quiet=True)

# check for missing required fields
def missing_input(spec, test_type, level):
    """Report whether the test spec lacks a field this tool requires.

    Args:
        spec: Test specification mapping taken from the tool input.
        test_type: Accepted but not consulted by this check.
        level: Accepted but not consulted by this check; the security
            level is read from ``spec`` directly.

    Returns:
        True if a required key is absent from ``spec``, otherwise False.
    """

    # additional keys each SNMPv3 security level needs on top of the
    # security name; any level not listed here needs nothing more
    v3_extras = {
        'authpriv': ('auth-protocol', '_auth-key',
                     'priv-protocol', '_priv-key'),
        'authnopriv': ('auth-protocol', '_auth-key'),
    }

    try:
        # general fields
        for key in ('version', 'dest', 'oid', 'polls'):
            _ = spec[key]

        if spec['version'] in ('1', '2c'):
            # community-based versions only need the community string
            _ = spec['_community']
        else:
            _ = spec['security-name']
            for key in v3_extras.get(spec['security-level'].lower(), ()):
                _ = spec[key]

        # snmpdelta fields: a period is only needed when polling repeatedly
        if spec['polls'] > 1:
            _ = spec['period']

    except KeyError:
        return True

    return False

# build object for snmp engine
def build_object(oid):
    """Build an ObjectIdentity from an OID string.

    A string made up only of digits and dots is handed to
    ObjectIdentity unchanged.  Anything else is treated as
    ``MIB::symbol`` or ``MIB::symbol.index``: the part before ``::`` is
    the first argument, and the part after it is split on ``.`` to give
    the symbol and, if present, one index.  Only the first index
    component is passed on; further dotted components are ignored.

    Args:
        oid: OID string, numeric or in MIB notation.

    Returns:
        The ObjectIdentity built for ``oid``.

    A string that is neither numeric nor contains ``::`` is reported
    through pscheduler.fail('Incomplete/Invalid OID').
    """

    if re.match(r'^((\.\d)|\d)+(\.\d+)*$', oid) is None:
        try:
            temp = oid.split('::')
            args = [temp[0]]
            # temp[1] raises IndexError when there is no '::' separator
            temp = temp[1].split('.')
            args.extend(temp)
            try:
                obj_id = ObjectIdentity(args[0], args[1], args[2])
            except IndexError:
                # no index component was given: MIB and symbol only
                obj_id = ObjectIdentity(args[0], args[1])
        except IndexError:
            pscheduler.fail('Incomplete/Invalid OID')

    else:
        # Use the function's own argument here.  This previously read a
        # module-level loop variable, which only worked when the caller
        # happened to name its loop variable the same way.
        obj_id = ObjectIdentity(oid)

    return obj_id

def get_generator(spec, oid_tuple, level):
    """Create the SNMP GET command for one poll.

    Args:
        spec: Test specification mapping; supplies the destination and
            whichever credentials the security level calls for.
        oid_tuple: ObjectType arguments to request, expanded into the
            getCmd() call.
        level: Lower-cased SNMPv3 security level, or None to use
            community-based credentials.

    Returns:
        The value returned by getCmd() for a new SnmpEngine, the chosen
        credentials and a UDP transport to port 161 of spec['dest'].

    An unknown security level, or a KeyError while assembling the
    SNMPv3 credentials, is reported through pscheduler.fail().
    """

    if level is None:
        # community string credentials
        credentials = CommunityData(spec['_community'])

    else:
        # lookup tables from spec protocol names to pysnmp protocol ids
        auth_protocols = {
            'sha': usmHMACSHAAuthProtocol,
            'md5': usmHMACMD5AuthProtocol,
        }
        priv_protocols = {
            'des': usmDESPrivProtocol,
            '3des': usm3DESEDEPrivProtocol,
            'aes': usmAesCfb128Protocol,
            'aes128': usmAesCfb128Protocol,
            'aes192': usmAesCfb192Protocol,
            'aes256': usmAesCfb256Protocol,
        }

        if level == 'authpriv':
            # authenticated and encrypted request
            try:
                user = spec['security-name']
                auth_key = spec['_auth-key']
                priv_key = spec['_priv-key']
                auth_proto = auth_protocols[spec['auth-protocol'].lower()]
                priv_proto = priv_protocols[spec['priv-protocol'].lower()]
                credentials = UsmUserData(user, auth_key, priv_key,
                                          authProtocol=auth_proto,
                                          privProtocol=priv_proto)
            except KeyError:
                pscheduler.fail("Unrecognized authentication/privacy protocol.")

        elif level == 'authnopriv':
            # authenticated but unencrypted request
            try:
                user = spec['security-name']
                auth_key = spec['_auth-key']
                auth_proto = auth_protocols[spec['auth-protocol'].lower()]
                credentials = UsmUserData(user, auth_key,
                                          authProtocol=auth_proto)
            except KeyError:
                pscheduler.fail("Unrecognized authentication protocol: %s" % spec['auth-protocol'])

        elif level == 'noauthnopriv':
            # completely unauthenticated request
            credentials = UsmUserData(spec['security-name'])

        else:
            pscheduler.fail("Unrecognized security level: %s" % spec['security-level'])

    return getCmd(
        SnmpEngine(),
        credentials,
        UdpTransportTarget((spec['dest'], 161)),
        ContextData(),
        *oid_tuple
    )


# 
# Preliminaries
#


# Read the JSON task description and pull out the parts used below
input = pscheduler.json_load(exit_on_error=True)
spec = input['test']['spec']
test_type = input['test']['type']

# The security level is optional in the spec; None means it was not given
level = spec['security-level'].lower() if 'security-level' in spec else None

# Refuse to run on an incomplete spec
if missing_input(spec, test_type, level):
    pscheduler.fail('Missing data in input')

# Reference point for the elapsed time reported in the final result
start_time = datetime.datetime.now()


#
# Perform the test
#


# State shared with the polling loop that follows
succeeded = None   # stays None unless a poll reports a failure
this_poll = None   # latest poll time
num_polls = 0      # tracker for number of polls
data = []          # entries collected from the polls

# Collect one ObjectType per requested OID, then freeze them into the
# tuple that is expanded into the engine call.  This is kept as a plain
# module-level loop so that `item` remains a global name, as other code
# in this file may read it.
oid_objects = []
for item in spec['oid']:
    oid_objects.append(ObjectType(build_object(item)))
oid_tuple = tuple(oid_objects)

# Hold off until the scheduled start time
try:
    start_at = input['schedule']['start']
    log.debug("Sleeping until %s", start_at)
    pscheduler.sleep_until(start_at)
    log.debug("Starting")
except KeyError:
    pscheduler.fail("Unable to find start time in input")


# Emitter used for the per-poll results written to stdout
emitter = pscheduler.RFC7464Emitter(sys.stdout)


# Poll the target spec['polls'] times, recording one entry per OID per
# poll.  Any SNMP error ends the loop early and marks the run as failed.
# n comparisons require n+1 get requests
while num_polls < spec['polls']:

    try:
        g = get_generator(spec, oid_tuple, level)
        errorIndication, errorStatus, errorIndex, varBinds = next(g)
        timestamp = str(datetime.datetime.now())
        last_poll = this_poll
        this_poll = time.time()

    except (pysnmp.smi.error.MibNotFoundError, pysnmp.smi.error.SmiError, pysnmp.error.PySnmpError) as e:
        succeeded = False
        error = "snmpget returned an error: %s" % str(e).strip('\n')
        diags = ''
        data = None
        break

    # NOTE(review): only errorIndication is checked; errorStatus and
    # errorIndex are unpacked above but never examined -- confirm that
    # ignoring them is intended.
    if errorIndication:
        succeeded = False
        error = "snmpget returned an error: %s" % str(errorIndication).strip('\n')
        diags = ''
        data = None
        break

    for enum, result in enumerate(varBinds):
        # result[0] == object name
        # result[1] == value
        result_type = result[1].__class__.__name__
        # store as appropriate type
        try:
            value = int(result[1])
            entry = { "timestamp": str(timestamp), "type": result_type, "value": value }
        except (ValueError, TypeError):
            if result_type in ['NoSuchInstance', 'NoSuchObject']:
                # `value` is not assigned on this path
                entry = { "timestamp": str(timestamp), "value": result_type }
            else:
                value = str(result[1])
                entry = { "timestamp": str(timestamp), "type": result_type, "value": value }

        if spec['polls'] > 1:
            entry['seq_num'] = num_polls
            # no delta values for first pull
            if num_polls != 0:
                # NOTE(review): on the first poll each entry dict is
                # stored directly in `data` (see the IndexError branch
                # below), so on later polls data[enum] is a dict rather
                # than a list of entries.  The indexing here and the
                # .append() below look like they expect a list per OID
                # -- confirm the intended shape of `data`.
                try:
                    entry['delta'] = value - data[enum][num_polls - 1]['value']
                    entry['timedelta'] = this_poll - last_poll
                # string types will raise this exception, only integer values have delta values
                except (TypeError, NameError):
                    pass
                except IndexError:
                    entry['delta'] = value - data[enum][0]['value']
                    entry['timedelta'] = this_poll - last_poll
        
        try:
            data[enum].append(entry)
        except IndexError:
            data.append(entry)

    num_polls += 1

    if test_type == 'snmpgetbgm':
        if len(data) > 1:
            del data[0]

        # NOTE(review): the inner 'succeeded' is still None at this
        # point on a healthy run, because it is only set to True after
        # the loop ends -- confirm consumers accept that.
        results = {
            'succeeded': True,
            'result': {
                'schema': 1,
                'succeeded' : succeeded,
                'data': data
            },
            'error': '',
            'diags': 'Run #%d' % num_polls
        }
        emitter(results)

    if spec['polls'] == 1:
        break

    # Wait out the rest of the polling period.  Sleeping for the
    # remaining time, instead of spinning on time.time(), leaves the CPU
    # idle between polls.  The loop re-checks the clock so the wait
    # never ends before the full period has elapsed.
    while True:
        remaining = spec['period'] - (time.time() - this_poll)
        if remaining <= 0:
            break
        time.sleep(remaining)
    
end_time = datetime.datetime.now()

# No poll flagged a failure, so the run counts as successful
if succeeded is None:
    succeeded, error, diags = True, None, ''


#
# Produce results
#


# Inner result body: schema version, elapsed time, outcome and the data
result_body = {
    'schema': 1,
    'time': pscheduler.timedelta_as_iso8601(end_time - start_time),
    'succeeded': succeeded,
    'data': data,
}

results = {
    'succeeded': succeeded,
    'result': result_body,
    'error': error,
    'diags': diags,
}

pscheduler.succeed_json(results)

