#!/usr/bin/env python3
"""
Send a result to RabbitMQ.
"""

import amqp
import datetime
import pscheduler
import sys
import time
import urllib.parse


# Highest "schema" value in the archiver data that this program accepts.
MAX_SCHEMA = 2

# Prefix handed to the pScheduler logger for this archiver's messages.
log_prefix = "archiver-rabbitmq"

log = pscheduler.Log(quiet=True, prefix=log_prefix)


class AMQPExpiringConnection(object):

    """
    Maintains an expiring connection to RabbitMQ

    The connection is established when the object is constructed.
    Every call to publish() re-establishes it first if it has expired
    or was dropped after an earlier failure.
    """

    def expired(self):
        """
        Determine if the connection is expired

        Returns False when no expiration is in force, which is the
        case when there is no live connection or no expire_time was
        given.
        """
        return (self.connection_expires is not None) and (datetime.datetime.now() > self.connection_expires)


    def __disconnect(self):
        """
        INTERNAL: Close down the connection
        """
        # getattr() because __del__ can call this on an object whose
        # __init__ raised before the attribute was assigned.
        connection = getattr(self, "connection", None)
        if connection is not None:
            try:
                connection.close()
            except Exception:
                pass  # This is best-effort.
        self.channel = None
        self.connection = None
        self.connection_expires = None


    def __connect(self):
        """
        INTERNAL: Establish a connection if one is needed

        Any exception raised while connecting or opening the channel
        is propagated after the partial connection has been torn down.
        """

        # If the connection has expired, drop it.
        if self.expired():
            self.__disconnect()

        if self.channel is not None:
            # Already connected.
            return

        try:
            self.connection = amqp.connection.Connection(
                host=self.host,
                ssl=self.ssl,
                virtual_host=self.virtual_host,

                login_method=self.auth_method,
                userid=self.userid,
                password=self.password,

                connect_timeout=self.timeout,
                read_timeout=self.timeout,
                write_timeout=self.timeout,
                confirm_publish=True
            )

            self.connection.connect()

            self.channel = amqp.channel.Channel(self.connection)
            self.channel.open()
        except Exception:
            # Don't leave a half-built connection or an unopened
            # channel behind; a non-None channel would make the next
            # call believe it was already connected.
            self.__disconnect()
            raise

        if self.expire_time is not None:
            self.connection_expires = datetime.datetime.now() + self.expire_time



    def __init__(self, url, key,
                 exchange='',
                 timeout=None,
                 expire_time=None
    ):
        """
        Construct a connection to AMQP

        Arguments:
        url - amqp://... or amqps://... URL of the broker.
        key - Routing key used for every published message.
        exchange - Name of the exchange to publish to (a string).
        timeout - datetime.timedelta applied to connecting, reading,
            writing and publish confirmation, or None for no timeout.
        expire_time - datetime.timedelta after which the connection
            is considered expired, or None to never expire.

        Raises ValueError for an invalid exchange, timeout,
        expiration time or URL scheme.  Errors raised while
        connecting are propagated.
        """

        self.url = url
        self.key = key
        if not isinstance(exchange, str):
            raise ValueError("Invalid exchange.")
        self.exchange = exchange

        if timeout is not None and not isinstance(timeout, datetime.timedelta):
            raise ValueError("Invalid timeout")
        # The amqp calls take seconds.  None (no timeout) passes through.
        self.timeout = timeout.total_seconds() if timeout is not None else None

        if expire_time is not None and not isinstance(expire_time, datetime.timedelta):
            raise ValueError("Invalid expiration time")
        self.expire_time = expire_time

        parsed_url = urllib.parse.urlparse(url)

        # Set default port and SSL flag based on URL scheme
        if (parsed_url.scheme == "amqp"):
            port = 5672
            self.ssl = False
        elif (parsed_url.scheme == "amqps"):
            port = 5671
            self.ssl = True
        else:
            raise ValueError("URL must be amqp[s]://...")

        # Use port if specified in URL
        if parsed_url.port:
            port = parsed_url.port

        self.host = "%s:%s" % (parsed_url.hostname, port)
        # Remove leading slash from path to match pika parsing convention
        self.virtual_host = parsed_url.path[1:] or ""

        # These are the amqp module's defaults
        self.userid = parsed_url.username or "guest"
        self.password = parsed_url.password or "guest"
        # Both values above always end up non-None, so this currently
        # always selects AMQPLAIN.
        self.auth_method = "AMQPLAIN" if (self.userid is not None or self.password is not None) else None

        self.connection = None
        self.channel = None

        self.connection_expires = None

        self.__connect()


    def __del__(self):
        """
        Destroy the connection
        """
        self.__disconnect()


    def publish(self, message):
        """
        Publish a message to the connection

        The message is sent to the configured exchange with the
        configured routing key and waits for confirmation.  On any
        failure the connection is dropped so the next call starts
        with a fresh one, and the exception is re-raised.
        """
        self.__connect()
        try:
            self.channel.basic_publish_confirm(
                amqp.Message(message),
                exchange=self.exchange,
                routing_key=self.key,
                mandatory=True,
                immediate=False,
                timeout=self.timeout,
                confirm_timeout=self.timeout)
        except Exception:
            # Any error means we start over next time.
            self.__disconnect()
            raise



# Cache of AMQPExpiringConnection objects, indexed by the key string
# that archive() builds.
connections = dict()


# Minimum time between sweeps done by groom_connections(), and the
# time at which the first sweep becomes due.
GROOM_INTERVAL = datetime.timedelta(seconds=20)
next_groom = GROOM_INTERVAL + datetime.datetime.now()

def groom_connections():
    """
    Get rid of expired connections.  This is intended to sweep up
    connections that are no longer used.  Those in continuous use will
    self-expire and re-create themselves.

    Does nothing unless the time in next_groom has been reached.
    After a sweep, the next one is scheduled GROOM_INTERVAL from the
    current time.
    """

    global next_groom

    now = datetime.datetime.now()
    if now < next_groom:
        # Not time yet.
        return

    log.debug("Grooming connections")

    # Collect the keys first; the dictionary can't be changed while
    # it is being iterated.
    expired_keys = [
        key
        for key, connection in connections.items()
        if connection.expired()
    ]
    for groom in expired_keys:
        log.debug("Dropping expired connection {}".format(groom))
        del connections[groom]

    # Schedule relative to now rather than the previous deadline so a
    # long idle period doesn't cause a sweep on every call until the
    # schedule catches up.
    next_groom = now + GROOM_INTERVAL
    log.debug("Next groom at {}".format(next_groom))



def archive(json):
    """
    Publish one result to RabbitMQ.

    json is a dictionary holding the archiver configuration in
    "data" and the item to publish in "result".  "attempts" is read
    only when publishing fails and the data includes a "retry-policy".

    Returns a dictionary with "succeeded" (boolean), "error" (string)
    on failure and "retry" when the retry policy provides a time for
    another attempt.
    """

    data = json["data"]

    log.debug("Archiving: %s" % data)

    schema = data.get("schema", 1)
    if schema > MAX_SCHEMA:
        return {
            "succeeded": False,
            "error": "Unsupported schema version %d; max is %d" % (
                schema, MAX_SCHEMA)
        }


    # Figure out the routing key
    routing_key_raw = data.get("routing-key", "")
    if isinstance(routing_key_raw, str):
        data["routing-key"] = routing_key_raw
    else:
        # JQ Transform
        log.debug("Using transform for routing key.")
        # This will already have been validated.
        transform = pscheduler.JQFilter(routing_key_raw)
        try:
            data["routing-key"] = str(transform(json["result"])[0])
        except pscheduler.JQRuntimeError as ex:
            return {
                "succeeded": False,
                "error": "Routing key transform failed: %s" % (str(ex))
            }
        except IndexError:
            # The transform ran but produced nothing to use as a key.
            return {
                "succeeded": False,
                "error": "Routing key transform failed: no output produced"
            }
    log.debug("Routing key is '%s'" % (data["routing-key"]))


    try:

        # Connections are cached by everything that distinguishes one
        # from another.
        # NOTE(review): the timeout is not part of this key, so
        # requests differing only in timeout share a connection.
        key = "%s```%s```%s```%s" % (
            data["_url"],
            data.get("exchange", ""),
            data.get("routing-key", ""),
            data.get("connection-expires", "")
        )

        expires = pscheduler.iso8601_as_timedelta(data.get("connection-expires","PT1H"))
        timeout = pscheduler.iso8601_as_timedelta(data.get("timeout","PT10S"))

        try:
            connection = connections[key]
        except KeyError:
            connection = AMQPExpiringConnection(data["_url"],
                                                data.get("routing-key", ""),
                                                exchange=data.get("exchange", ""),
                                                expire_time=expires,
                                                timeout=timeout
            )
            connections[key] = connection

        connection.publish(pscheduler.json_dump(json["result"]))

        result = {'succeeded': True}

    except Exception as ex:

        # The connection will self-recover from failures, so no need
        # to do anything other than complain about it.

        result = {
            "succeeded": False,
            "error": str(ex)
        }

        if "retry-policy" in data:
            policy = pscheduler.RetryPolicy(data["retry-policy"], iso8601=True)
            retry_time = policy.retry(json["attempts"])
            if retry_time is not None:
                result["retry"] = retry_time

    groom_connections()

    return result



# Main program: read records from standard input, archive each one and
# write the outcome of each to standard output.
PARSER = pscheduler.RFC7464Parser(sys.stdin)
EMITTER = pscheduler.RFC7464Emitter(sys.stdout)

for request in PARSER:
    try:
        outcome = archive(request)
        EMITTER(outcome)
    except BrokenPipeError:
        log.warning("Broken pipe during archiving; parent must have exited.")
        pscheduler.succeed()

pscheduler.succeed()
