#!/usr/bin/env python3
"""psi2log - PSI metrics monitor and logger"""

from time import sleep, monotonic
from ctypes import CDLL
from sys import stdout, exit
from argparse import ArgumentParser
from signal import signal, SIGTERM, SIGINT, SIGQUIT, SIGHUP


def read_path(path):
    """
    Return the decoded contents of `path`, or None on failure.

    Open file objects are cached in the module-level `fd` dict (keyed by
    path) and rewound with seek(0) on every call instead of being reopened.
    At most 99999 bytes are read per call.

    Any OSError raised while opening or reading is passed to log() and
    reported to the caller as None.
    """
    try:
        fd[path].seek(0)
    except (KeyError, ValueError):
        # KeyError: the path has not been opened yet.
        # ValueError: the cached handle was closed after a failed read.
        try:
            fd[path] = open(path, 'rb', buffering=0)
        except OSError as e:
            # OSError also covers PermissionError and similar failures,
            # not only a missing file.
            log(e)
            return None
    try:
        return fd[path].read(99999).decode()
    except OSError as e:
        log(e)
        # Close the handle so that the next call reopens the file.
        fd[path].close()
        return None


def form1(num):
    """
    Format `num` with exactly two digits after the decimal point.

    Fixed-point formatting is used instead of splitting str(num) on '.',
    so integers and floats whose repr uses exponent notation (for example
    1e-05) are formatted rather than raising IndexError.
    """
    return '{:.2f}'.format(num)


def form2(num):
    """
    Format `num` rounded to exactly one digit after the decimal point.

    Fixed-point formatting is used instead of splitting str(round(num, 1))
    on '.', so integers (for which round() returns an int without a
    fractional part) and very large floats are formatted rather than
    raising IndexError.
    """
    return '{:.1f}'.format(num)


def signal_handler(signum, frame):
    """
    Handle a termination signal: print the recorded peaks and exit.

    Every signal in `sig_list` is first rebound to a no-op handler, then
    all cached files in `fd` are closed. The report layout is chosen from
    the number of entries in `peaks_dict`: 15 entries produce the mode 1
    table, 5 entries produce the mode 2 line, any other count prints no
    peaks. The function always ends by calling exit().
    """
    def ignore_signal(signum, frame):
        pass

    # Further signals must not interrupt the shutdown sequence
    for sig in sig_list:
        signal(sig, ignore_signal)

    for opened_file in fd.values():
        opened_file.close()

    # Blank line after the terminal's echoed ^C
    if signum == SIGINT:
        print('')

    peak_count = len(peaks_dict)

    if peak_count == 15:  # mode 1

        # Each group is preceded by a separator line; every row pairs its
        # line template with the three peak keys it displays.
        groups = (
            (
                ('some cpu     {:>6} {:>6} {:>6}',
                 ('c_some_avg10', 'c_some_avg60', 'c_some_avg300')),
            ),
            (
                ('some io      {:>6} {:>6} {:>6}',
                 ('i_some_avg10', 'i_some_avg60', 'i_some_avg300')),
                ('full io      {:>6} {:>6} {:>6}',
                 ('i_full_avg10', 'i_full_avg60', 'i_full_avg300')),
            ),
            (
                ('some memory  {:>6} {:>6} {:>6}',
                 ('m_some_avg10', 'm_some_avg60', 'm_some_avg300')),
                ('full memory  {:>6} {:>6} {:>6}',
                 ('m_full_avg10', 'm_full_avg60', 'm_full_avg300')),
            ),
        )

        log('=================================')
        log('Peak values:  avg10  avg60 avg300')

        for group in groups:
            log('-----------  ------ ------ ------')
            for template, keys in group:
                cells = [form1(peaks_dict[key]) for key in keys]
                log(template.format(*cells))

    if peak_count == 5:  # mode 2

        log('----- | ----- ----- | ----- ----- | --------')

        cells = [
            form2(peaks_dict[key])
            for key in ('avg_cs', 'avg_is', 'avg_if', 'avg_ms', 'avg_mf')
        ]
        log('{:>5} | {:>5} {:>5} | {:>5} {:>5} | peaks'.format(*cells))

    # An empty record in the log file marks the end of this run
    if separate_log:
        logging.info('')

    exit()


def cgroup2_root():
    """
    Return the mount point of the cgroup2 hierarchy, or None if not found.

    The file named by the module-level `mounts` is scanned line by line.
    For the first line containing `cgroup2_separator`, the text before the
    separator is taken and everything after its first space is returned.
    """
    with open(mounts) as mounts_file:
        for entry in mounts_file:
            head, matched, _ = entry.partition(cgroup2_separator)
            if not matched:
                continue
            # head is "<device> <mountpoint>"; drop the device field
            _, _, mountpoint = head.partition(' ')
            return mountpoint
    return None


def mlockall():
    """
    Try to lock all of the process memory with libc's mlockall().

    The call is first made with MCL_CURRENT | MCL_FUTURE | MCL_ONFAULT.
    If that fails it is retried without MCL_ONFAULT (see mlockall(2) for
    the kernels that accept this flag). The outcome is reported through
    log(); failure is never fatal.
    """
    from ctypes import get_errno

    MCL_CURRENT = 1
    MCL_FUTURE = 2
    MCL_ONFAULT = 4

    try:
        libc = CDLL('libc.so.6', use_errno=True)
    except OSError as e:
        # Memory locking is best-effort: a system without libc.so.6 must
        # not prevent the monitor from running.
        log('WARNING: cannot lock all memory: {}'.format(e))
        return

    if libc.mlockall(MCL_CURRENT | MCL_FUTURE | MCL_ONFAULT) == 0:
        log('All memory locked with MCL_CURRENT | MCL_FUTURE | MCL_ONFAULT')
        return

    if libc.mlockall(MCL_CURRENT | MCL_FUTURE) == 0:
        log('All memory locked with MCL_CURRENT | MCL_FUTURE')
        return

    # mlockall() itself only returns -1 on failure; the real error code is
    # the errno value captured by ctypes because of use_errno=True.
    log('WARNING: cannot lock all memory: [Errno {}]'.format(get_errno()))


def psi_file_mem_to_metrics0(psi_path):
    """
    Parse a two-line PSI file without using the cached file handles.

    The file is opened directly, and fields 1-3 (avg10, avg60, avg300) of
    the first ("some") and second ("full") lines are extracted. Returns
    a 6-tuple of strings:
    (some_avg10, some_avg60, some_avg300,
     full_avg10, full_avg60, full_avg300).
    Errors are not handled here; they propagate to the caller.
    """
    with open(psi_path) as psi_file:
        lines = psi_file.readlines()

    some_fields = lines[0].split(' ')
    full_fields = lines[1].split(' ')

    # Each field looks like "name=value"; keep only the value part
    return tuple(
        fields[position].split('=')[1]
        for fields in (some_fields, full_fields)
        for position in (1, 2, 3)
    )


def psi_file_mem_to_metrics(psi_path):
    """
    Read a two-line PSI file via read_path() and extract its averages.

    Returns a 6-tuple of strings
    (some_avg10, some_avg60, some_avg300,
     full_avg10, full_avg60, full_avg300),
    or None when the file cannot be read or parsed (parse errors are
    passed to log()).
    """
    content = read_path(psi_path)

    if content is None:
        return None

    try:

        lines = content.split('\n')
        some_fields = lines[0].split(' ')
        full_fields = lines[1].split(' ')

        # Fields 1-3 are "avg10=..", "avg60=..", "avg300=.."
        return tuple(
            fields[position].split('=')[1]
            for fields in (some_fields, full_fields)
            for position in (1, 2, 3)
        )

    except Exception as e:
        log('{}'.format(e))
        return None


def psi_file_cpu_to_metrics(psi_path):
    """
    Read a PSI file via read_path() and extract the "some" averages.

    Only the first line is used. Returns a 3-tuple of strings
    (some_avg10, some_avg60, some_avg300), or None when the file cannot
    be read or parsed (parse errors are passed to log()).
    """
    content = read_path(psi_path)

    if content is None:
        return None

    try:

        first_line = content.split('\n')[0]
        fields = first_line.split(' ')

        # Fields 1-3 are "avg10=..", "avg60=..", "avg300=.."
        return tuple(
            fields[position].split('=')[1] for position in (1, 2, 3)
        )

    except Exception as e:
        log('{}'.format(e))
        return None


def psi_file_mem_to_total(psi_path):
    """
    Read a two-line PSI file via read_path() and extract both totals.

    Field 4 ("total=..") of the first and second lines is converted to
    int. Returns (some_total, full_total), or None when the file cannot be
    read or parsed (parse errors are passed to log()).
    """
    content = read_path(psi_path)

    if content is None:
        return None

    try:

        lines = content.split('\n')
        some_fields = lines[0].split(' ')
        full_fields = lines[1].split(' ')

        # Extract both raw values first, then convert them
        raw_totals = [
            fields[4].split('=')[1] for fields in (some_fields, full_fields)
        ]

        return tuple(int(raw) for raw in raw_totals)

    except Exception as e:
        log('{}'.format(e))
        return None


def psi_file_cpu_to_total(psi_path):
    """
    Read a PSI file via read_path() and extract the "some" total.

    Field 4 ("total=..") of the first line is converted to int and
    returned. Returns None when the file cannot be read or parsed (parse
    errors are passed to log()).
    """
    content = read_path(psi_path)

    if content is None:
        return None

    try:

        fields = content.split('\n')[0].split(' ')
        name_and_value = fields[4].split('=')

        return int(name_and_value[1])

    except Exception as e:
        log('{}'.format(e))
        return None


def print_head_1():
    """
    Emit the seven-line table header for mode 1 through log().

    The header labels the cpu / io / memory column groups and their
    avg10, avg60 and avg300 columns.
    """
    header_lines = (
        '===================================================================='
        '==============================================',
        '        cpu          ||                     io                      '
        '||                   memory',
        '==================== || =========================================== '
        '|| ===========================================',
        '        some         ||         some         |         full         '
        '||         some         |         full',
        '-------------------- || -------------------- | -------------------- '
        '|| -------------------- | --------------------',
        ' avg10  avg60 avg300 ||  avg10  avg60 avg300 |  avg10  avg60 avg300 '
        '||  avg10  avg60 avg300 |  avg10  avg60 avg300',
        '------ ------ ------ || ------ ------ ------ | ------ ------ ------ '
        '|| ------ ------ ------ | ------ ------ ------',
    )
    for header_line in header_lines:
        log(header_line)


def print_head_2():
    """
    Emit the five-line table header for mode 2 through log().

    The header labels the cpu / io / memory columns and the trailing
    interval column.
    """
    header_lines = (
        '======|=============|=============|',
        ' cpu  |      io     |    memory   |',
        '----- | ----------- | ----------- |',
        ' some |  some  full |  some  full | interval',
        '----- | ----- ----- | ----- ----- | --------',
    )
    for header_line in header_lines:
        log(header_line)


def log(*msg):
    """
    Print `msg` unless output is suppressed, and mirror it to the log file.

    Printing is skipped when SUPPRESS_OUTPUT is true. When `separate_log`
    is true, the same text is also written with logging.info().
    """
    if not SUPPRESS_OUTPUT:
        print(*msg)
    if separate_log:
        # logging.info() treats extra positional arguments as %-format
        # operands and requires at least one argument, so the parts are
        # joined the same way print() joins them.
        logging.info(' '.join(str(part) for part in msg))


def log_head(*msg):
    """
    Always print `msg`, and mirror it to the log file.

    Unlike log(), this ignores SUPPRESS_OUTPUT. When `separate_log` is
    true, the same text is also written with logging.info().
    """
    print(*msg)
    if separate_log:
        # logging.info() treats extra positional arguments as %-format
        # operands and requires at least one argument, so the parts are
        # joined the same way print() joins them.
        logging.info(' '.join(str(part) for part in msg))


##############################################################################


parser = ArgumentParser()

# (short flag, long flag, help text, default, type) for every command-line
# option, listed in the order in which they are registered.
option_specs = (
    ('-t', '--target', 'target (cgroup_v2 or SYSTEM_WIDE)', 'SYSTEM_WIDE',
     str),
    ('-i', '--interval', 'interval in sec', 2, float),
    ('-l', '--log', 'path to log file', None, str),
    ('-m', '--mode', 'mode (1 or 2)', '1', str),
    ('-s', '--suppress-output', 'suppress output', 'False', str),
)

for short_opt, long_opt, help_text, fallback, caster in option_specs:
    parser.add_argument(
        short_opt, long_opt, help=help_text, default=fallback, type=caster)


args = parser.parse_args()
target, mode, interval = args.target, args.mode, args.interval
log_file, suppress_output = args.log, args.suppress_output

# A cgroup target is normalised to exactly one leading slash and no
# trailing slash, so it can be appended to the cgroup2 mount point later.
if target != 'SYSTEM_WIDE':
    target = '/' + target.strip('/')


##############################################################################


# File logging is optional: the logging module is imported only when a log
# file was requested on the command line.
if log_file is None:
    separate_log = False
else:
    separate_log = True
    import logging


# State read by signal_handler(): cached PSI file objects and recorded peak
# values. Both are created before the handlers are installed, so a signal
# that arrives during start-up cannot hit a NameError inside the handler.
fd = dict()
peaks_dict = dict()


sig_list = [SIGTERM, SIGINT, SIGQUIT, SIGHUP]

for i in sig_list:
    signal(i, signal_handler)

if separate_log:
    try:
        logging.basicConfig(
            filename=log_file,
            level=logging.INFO,
            format="%(asctime)s: %(message)s")
    except Exception as e:
        print(e)
        exit(1)


# -s/--suppress-output is passed as the literal strings 'True' / 'False'
if suppress_output == 'False':
    SUPPRESS_OUTPUT = False
elif suppress_output == 'True':
    SUPPRESS_OUTPUT = True
else:
    log_head('error: argument -s/--suppress-output: valid values are '
             'False and True')
    exit(1)


if log_file is not None:
    logstring = 'log file: {}, '.format(log_file)
else:
    logstring = 'log file is not set, '


if interval < 1:
    log_head('error: argument -i/--interval: the value must be greater than or'
             ' equal to 1')
    exit(1)


if not (mode == '1' or mode == '2'):
    log_head('ERROR: invalid mode. Valid values are 1 and 2. Exit.')
    exit(1)


# Probe the system-wide memory pressure file once; any failure to read or
# parse it is taken to mean that the kernel does not provide PSI.
try:
    psi_file_mem_to_metrics0('/proc/pressure/memory')
except Exception as e:
    log_head('ERROR: {}'.format(e))
    log_head('PSI metrics are not provided by the kernel. Exit.')
    exit(1)


log_head('Starting psi2log, target: {}, mode: {}, interval: {} sec, {}suppress'
         ' output: {}'.format(
             target, mode, round(interval, 3), logstring, suppress_output))


# Select the three pressure files: /proc/pressure/* for the whole system, or
# <cgroup2 mount point><target>/*.pressure for a single cgroup.
if target == 'SYSTEM_WIDE':
    system_wide = True
    source_dir = '/proc/pressure'
    cpu_file = '/proc/pressure/cpu'
    io_file = '/proc/pressure/io'
    memory_file = '/proc/pressure/memory'
    log_head('PSI source dir: /proc/pressure/, source files: cpu, io, memory')
else:
    system_wide = False
    mounts = '/proc/mounts'
    cgroup2_separator = ' cgroup2 rw,'
    cgroup2_mountpoint = cgroup2_root()

    if cgroup2_mountpoint is None:
        log('ERROR: unified cgroup hierarchy is not mounted, exit')
        exit(1)

    source_dir = cgroup2_mountpoint + target
    cpu_file = source_dir + '/cpu.pressure'
    io_file = source_dir + '/io.pressure'
    memory_file = source_dir + '/memory.pressure'
    log_head('PSI source dir: {}{}/, source files: cpu.pressure, io.pressure,'
             ' memory.pressure'.format(cgroup2_mountpoint, target))


# Mode 2 warns when the measured time between samples exceeds this value
abnormal_interval = 1.01 * interval


mlockall()


if mode == '2':

    # Mode 2: report the average pressure over each interval, computed from
    # the growth of the cumulative "total" counters between two samples.

    print_head_2()

    # monotonic0 stays None until one complete sample has been taken. This
    # keeps the baseline variables from being used before they exist when
    # the very first read fails.
    monotonic0 = None

    while True:

        total_cs1 = psi_file_cpu_to_total(cpu_file)
        io_totals = psi_file_mem_to_total(io_file)
        memory_totals = psi_file_mem_to_total(memory_file)

        if total_cs1 is None or io_totals is None or memory_totals is None:
            # At least one file could not be read or parsed (the reader has
            # already logged why). Keep the previous baseline and retry.
            stdout.flush()
            sleep(interval)
            continue

        total_is1, total_if1 = io_totals
        total_ms1, total_mf1 = memory_totals
        monotonic1 = monotonic()

        if monotonic0 is None:
            # First complete sample: there is nothing to diff against yet,
            # so it only becomes the baseline.
            total_cs0 = total_cs1
            total_is0, total_if0 = total_is1, total_if1
            total_ms0, total_mf0 = total_ms1, total_mf1
            monotonic0 = monotonic1
            stdout.flush()
            sleep(interval)
            continue

        dm = monotonic1 - monotonic0

        if dm > abnormal_interval and dm - interval > 0.05:
            log('WARNING: abnormal interval ({} sec), metrics may be prov'
                'ided incorrect'.format(round(dm, 3)))

        monotonic0 = monotonic1

        # The kernel PSI documentation gives "total" in microseconds, so
        # delta / seconds / 10000 is the stalled share of time in percent.
        avg_cs = (total_cs1 - total_cs0) / dm / 10000
        avg_is = (total_is1 - total_is0) / dm / 10000
        avg_if = (total_if1 - total_if0) / dm / 10000
        avg_ms = (total_ms1 - total_ms0) / dm / 10000
        avg_mf = (total_mf1 - total_mf0) / dm / 10000

        # The current sample becomes the baseline for the next interval
        total_cs0 = total_cs1
        total_is0, total_if0 = total_is1, total_if1
        total_ms0, total_mf0 = total_ms1, total_mf1

        # Remember the highest value seen so far for every column
        for peak_key, peak_candidate in (
                ('avg_cs', avg_cs),
                ('avg_is', avg_is),
                ('avg_if', avg_if),
                ('avg_ms', avg_ms),
                ('avg_mf', avg_mf)):
            if (peak_key not in peaks_dict or
                    peaks_dict[peak_key] < peak_candidate):
                peaks_dict[peak_key] = peak_candidate

        log('{:>5} | {:>5} {:>5} | {:>5} {:>5} | {}'.format(

            round(avg_cs, 1),

            round(avg_is, 1),
            round(avg_if, 1),

            round(avg_ms, 1),
            round(avg_mf, 1),

            round(dm, 3)

        ))

        stdout.flush()
        sleep(interval)


# Mode 1: print the avg10/avg60/avg300 values exactly as the kernel reports
# them and remember the highest value seen for each of the 15 columns.

print_head_1()

# Peak-table keys, in the same order as the values collected on every pass
metric_names = (
    'c_some_avg10', 'c_some_avg60', 'c_some_avg300',
    'i_some_avg10', 'i_some_avg60', 'i_some_avg300',
    'i_full_avg10', 'i_full_avg60', 'i_full_avg300',
    'm_some_avg10', 'm_some_avg60', 'm_some_avg300',
    'm_full_avg10', 'm_full_avg60', 'm_full_avg300',
)

while True:

    metric_values = []

    try:
        # A reader that failed returns None; extending with None raises
        # TypeError, which skips this pass and the remaining reads.
        metric_values.extend(psi_file_cpu_to_metrics(cpu_file))
        metric_values.extend(psi_file_mem_to_metrics(io_file))
        metric_values.extend(psi_file_mem_to_metrics(memory_file))
    except TypeError:
        stdout.flush()
        sleep(interval)
        continue

    log('{:>6} {:>6} {:>6} || {:>6} {:>6} {:>6} | {:>6} {:>6} {:>6} || {:>6}'
        ' {:>6} {:>6} | {:>6} {:>6} {:>6}'.format(*metric_values))

    # Values arrive as strings; compare them numerically for the peaks
    for metric_name, metric_text in zip(metric_names, metric_values):
        metric_value = float(metric_text)
        if (metric_name not in peaks_dict or
                peaks_dict[metric_name] < metric_value):
            peaks_dict[metric_name] = metric_value

    stdout.flush()
    sleep(interval)
