#!/usr/bin/env python

# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import errno
import os
import re
import sys
import tokenize as std_tokenize
from pathlib import Path

import isort
import isort.files
import unasync


# Repository root, i.e. the grandparent of this file.  The path is made
# absolute *before* walking up: a relative ``__file__`` with a single component
# (Python < 3.9 when the script is started from its own directory) has no
# ``parents[1]`` and would raise IndexError.
ROOT_DIR = Path(__file__).absolute().parents[1]
# Async source trees and the sync trees that are generated from them.
ASYNC_DIR = ROOT_DIR / "neo4j" / "_async"
SYNC_DIR = ROOT_DIR / "neo4j" / "_sync"
ASYNC_UNIT_TEST_DIR = ROOT_DIR / "tests" / "unit" / "async_"
SYNC_UNIT_TEST_DIR = ROOT_DIR / "tests" / "unit" / "sync"
ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "tests" / "integration" / "async_"
SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "tests" / "integration" / "sync"
ASYNC_TESTKIT_BACKEND_DIR = ROOT_DIR / "testkitbackend" / "_async"
SYNC_TESTKIT_BACKEND_DIR = ROOT_DIR / "testkitbackend" / "_sync"
# Suffix appended to freshly generated files until they have been compared
# with (and possibly moved over) the existing sync files.
UNASYNC_SUFFIX = ".unasync"

# Only files with one of these extensions are converted.
PY_FILE_EXTENSIONS = {".py", ".pyi"}


# copy from unasync for customization -----------------------------------------
# https://github.com/python-trio/unasync
# License: MIT and Apache2


# Plain record with the same five fields, in the same order, as the tuples
# produced by the standard library tokenizer.
Token = collections.namedtuple("Token", "type string start end line")


def _makedirs_existok(dir):
    """Create directory ``dir`` including all missing parent directories.

    An already existing directory is not an error.  If ``dir`` exists but is
    not a directory, ``FileExistsError`` (a subclass of ``OSError``) is
    raised instead of being silently ignored, so the problem surfaces here
    and not later when a file is written below ``dir``.

    :param dir: path of the directory to create (the name shadows the
        builtin, but is kept to leave the signature untouched).
    :raises OSError: if the directory cannot be created.
    """
    os.makedirs(dir, exist_ok=True)


def _get_tokens(f):
    """Yield the tokens of a Python source file.

    The leading ENCODING pseudo-token emitted by ``tokenize.tokenize`` is
    skipped because it has no counterpart in the source text.

    :param f: file-like object opened in binary mode; its ``readline`` is
        handed to ``tokenize.tokenize``.
    :returns: generator of ``tokenize.TokenInfo`` tuples.
    """
    # The former Python 2 code path was dropped: this script cannot run on
    # Python 2 anyway (it relies on pathlib and the print() function).
    for tok in std_tokenize.tokenize(f.readline):
        if tok.type != std_tokenize.ENCODING:
            yield tok


def _tokenize(f):
    """Yield ``(whitespace, token type, token text)`` triples for a file.

    ``whitespace`` is the run of spaces needed between the position where
    the previous token ended and the start of the current one.  If a token
    starts on a later line than expected (no NEWLINE/NL token accounted for
    the line change), an extra triple carrying a backslash line continuation
    is emitted first.

    :param f: binary file-like object, passed on to ``_get_tokens``.
    """
    line_end_types = (std_tokenize.NEWLINE, std_tokenize.NL)
    # (row, column) at which the next token is expected to begin
    expected = (1, 0)
    for token in _get_tokens(f):
        begin = token.start
        if expected[0] < begin[0]:
            # the source moved on to another line without a newline token
            yield "", std_tokenize.STRING, " \\\n"
            expected = (begin[0], 0)

        if begin > expected:
            assert begin[0] == expected[0]
            gap = " " * (begin[1] - expected[1])
        else:
            gap = ""
        yield gap, token.type, token.string

        if token.type in line_end_types:
            # whatever follows a line end starts in column 0 of the next row
            expected = (token.end[0] + 1, 0)
        else:
            expected = token.end


def _untokenize(tokens):
    """Join ``(whitespace, token text)`` pairs back into source text.

    :param tokens: iterable of 2-tuples of strings.
    :returns: the concatenation of all pairs, in order.
    """
    pieces = []
    for leading, text in tokens:
        pieces.append(leading + text)
    return "".join(pieces)


# end of copy -----------------------------------------------------------------


class CustomRule(unasync.Rule):
    """``unasync.Rule`` with customized renaming, string handling and output.

    What this subclass implements itself:

    * identifiers inside string literals are renamed as well
      (``_unasync_string``),
    * ``AsyncXyz``, ``_AsyncXyz`` and ``async_xyz`` lose their prefix
      (``_unasync_prefix``),
    * generated files get ``UNASYNC_SUFFIX`` appended to their name and
      their paths are collected in ``out_files`` (``_unasync_file``).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``unasync.Rule`` and add local state."""
        super().__init__(*args, **kwargs)
        # paths (str) of all files written by ``_unasync_file``
        self.out_files = []
        # it's not pretty, but it works
        # typing.Awaitable[...] -> typing.Union[...]
        self.token_replacements["Awaitable"] = "Union"

    def _unasync_tokens(self, tokens):
        """Yield ``(whitespace, text)`` pairs of the converted token stream.

        ``async`` and ``await`` tokens are dropped, NAME tokens are renamed
        and the content of STRING tokens is run through ``_unasync_string``.

        :param tokens: iterable of ``(whitespace, token type, token text)``
            triples as produced by ``_tokenize``.
        """
        # copy from unasync to hook into string handling
        # https://github.com/python-trio/unasync
        # License: MIT and Apache2
        # TODO __await__, ...?
        used_space = None
        for space, toknum, tokval in tokens:
            if tokval in ("async", "await"):
                # When removing async or await, we want to use the whitespace
                # that was before async/await before the next token so that
                # `print(await stuff)` becomes `print(stuff)` and not
                # `print( stuff)`
                used_space = space
                continue

            if toknum == std_tokenize.NAME:
                tokval = self._unasync_name(tokval)
            elif toknum == std_tokenize.STRING:
                # multiline string (`"""..."""` or `'''...'''`) vs.
                # simple string (`"..."` or `'...'`)
                if tokval[0] == tokval[1] and len(tokval) > 2:
                    quote_len = 3
                else:
                    quote_len = 1
                tokval = (
                    tokval[:quote_len]
                    + self._unasync_string(tokval[quote_len:-quote_len])
                    + tokval[-quote_len:]
                )
            if used_space is None:
                used_space = space
            yield (used_space, tokval)
            used_space = None

    def _unasync_string(self, name):
        """Convert the content of a string literal.

        The text is scanned left to right.  Every maximal run that forms a
        valid identifier is passed through ``_unasync_name``; all other
        characters are copied verbatim.  Afterwards ``async``/``await``
        keywords followed by whitespace and the ``async-`` prefix of
        ``:ref:`` targets are removed.

        :param name: string content without the surrounding quotes.
        :returns: the converted content.
        """
        parts = []
        start = 0
        end = 1
        # ``<=`` (not ``<``): the slice that reaches the very end of the
        # string has to be examined too.  Otherwise an identifier followed by
        # exactly one trailing non-identifier character (e.g. "AsyncFoo." or
        # "Optional[AsyncFoo]") would be copied without being renamed.
        while end <= len(name):
            sub_name = name[start:end]
            if sub_name.isidentifier():
                # keep growing the identifier candidate
                end += 1
            elif end == start + 1:
                # a single character that cannot start an identifier
                parts.append(sub_name)
                start += 1
                end += 1
            else:
                # the last character ended the identifier collected so far
                parts.append(self._unasync_name(name[start:end - 1]))
                start = end - 1

        # whatever is left is either empty or an identifier that runs up to
        # the end of the string
        sub_name = name[start:]
        if sub_name.isidentifier():
            parts.append(self._unasync_name(sub_name))
        else:
            parts.append(sub_name)
        out = "".join(parts)

        # very boiled down unasync version that removes "async" and "await"
        # substrings.
        out = re.sub(r"(^|\s+|(?<=\W))(?:async|await)\s+", r"\1", out,
                     flags=re.MULTILINE)
        # Convert doc-reference names from 'async-xyz' to 'xyz'
        out = re.sub(r":ref:`async-", ":ref:`", out)
        return out

    def _unasync_prefix(self, name):
        """Strip an ``Async``/``_Async``/``async_`` prefix from ``name``.

        Names that match none of the patterns are returned unchanged.
        """
        # Convert class names from 'AsyncXyz' to 'Xyz'
        if len(name) > 5 and name.startswith("Async") and name[5].isupper():
            return name[5:]
        # Convert class names from '_AsyncXyz' to '_Xyz'
        if len(name) > 6 and name.startswith("_Async") and name[6].isupper():
            return "_" + name[6:]
        # Convert variable/method/function names from 'async_xyz' to 'xyz'
        if len(name) > 6 and name.startswith("async_"):
            return name[6:]
        return name

    def _unasync_name(self, name):
        """Return the sync counterpart of the identifier ``name``.

        An entry in ``token_replacements`` takes precedence over the prefix
        rules of ``_unasync_prefix``.
        """
        # copy from unasync to customize renaming rules
        # https://github.com/python-trio/unasync
        # License: MIT and Apache2
        if name in self.token_replacements:
            return self.token_replacements[name]
        return self._unasync_prefix(name)

    def _unasync_file(self, filepath):
        """Convert one file and write the result next to its sync target.

        The output path is ``filepath`` with ``self.fromdir`` replaced by
        ``self.todir`` (both provided by ``unasync.Rule``) plus
        ``UNASYNC_SUFFIX``.  It is appended to ``out_files``.

        :param filepath: path (str) of the async source file.
        """
        # copy from unasync to append file suffix to out path
        # https://github.com/python-trio/unasync
        # License: MIT and Apache2
        with open(filepath, "rb") as f_in:
            encoding, _ = std_tokenize.detect_encoding(f_in.readline)
            f_in.seek(0)
            # the token pipeline is lazy, so it has to be drained while the
            # input file is still open
            result = _untokenize(self._unasync_tokens(_tokenize(f_in)))

        outfile_path = filepath.replace(self.fromdir, self.todir)
        outfile_path += UNASYNC_SUFFIX
        self.out_files.append(outfile_path)
        _makedirs_existok(os.path.dirname(outfile_path))
        with open(outfile_path, "w", encoding=encoding) as f_out:
            f_out.write(result)


def apply_unasync(files):
    """Generate sync code from async code.

    :param files: file names to convert, joined onto ``ROOT_DIR``.  If empty
        or ``None``, everything below the async source directories is
        considered.
    :returns: list of ``Path`` objects of all files the rules have written.
    """
    main_replacements = {}
    test_replacements = {
        "_async": "_sync",
        "mark_async_test": "mark_sync_test",
        "assert_awaited_once": "assert_called_once",
        "assert_awaited_once_with": "assert_called_once_with",
        "await_count": "call_count",
    }
    testkit_backend_replacements = {}

    # (async source dir, sync target dir, extra token replacements)
    rule_specs = (
        (ASYNC_DIR, SYNC_DIR, main_replacements),
        (ASYNC_UNIT_TEST_DIR, SYNC_UNIT_TEST_DIR, test_replacements),
        (ASYNC_INTEGRATION_TEST_DIR, SYNC_INTEGRATION_TEST_DIR,
         test_replacements),
        (ASYNC_TESTKIT_BACKEND_DIR, SYNC_TESTKIT_BACKEND_DIR,
         testkit_backend_replacements),
    )
    rules = [
        CustomRule(
            fromdir=str(source_dir),
            todir=str(target_dir),
            additional_replacements=replacements,
        )
        for source_dir, target_dir, replacements in rule_specs
    ]

    if files:
        candidates = [ROOT_DIR / Path(name) for name in files]
    else:
        # nothing requested explicitly: walk all async source trees
        candidates = [
            found
            for source_dir, _, _ in rule_specs
            for found in source_dir.rglob("*")
        ]

    selected = [
        str(candidate)
        for candidate in candidates
        if candidate.suffix in PY_FILE_EXTENSIONS
    ]
    unasync.unasync_files(selected, rules)

    return [Path(written) for rule in rules for written in rule.out_files]


def apply_isort(paths):
    """Re-sort the imports of the generated sync files.

    Renaming classes from AsyncXyz to Xyz can change the alphabetical
    position of a name inside an import block, so the generated code is run
    through isort, configured from the settings found under ``ROOT_DIR``.

    :param paths: paths of the files to sort in place.
    :returns: ``paths``, unchanged.
    """
    settings = isort.Config(quiet=True, settings_path=str(ROOT_DIR))

    for generated_file in paths:
        isort.file(str(generated_file), config=settings)

    return paths


def apply_changes(paths):
    """Move generated files into place where they differ from their target.

    The target of each generated file is its own path with the trailing
    ``UNASYNC_SUFFIX`` removed.  A generated file whose content equals an
    existing target is deleted; all other generated files replace (or
    create) their target.

    :param paths: ``Path`` objects of generated files, each ending in
        ``UNASYNC_SUFFIX``.  Duplicates are processed only once.
    :returns: list of target paths that were created or overwritten.
    """
    def files_equal(path1, path2):
        """Return True if both files contain exactly the same bytes."""
        with open(path1, "rb") as f1, open(path2, "rb") as f2:
            while True:
                data1 = f1.read(1024)
                data2 = f2.read(1024)
                if data1 != data2:
                    return False
                if not data1:
                    # both files are exhausted without any difference
                    return True

    changed_paths = {}

    # dict.fromkeys drops duplicates while keeping the order; a path listed
    # twice would otherwise be opened again after it has been unlinked.
    for in_path in dict.fromkeys(paths):
        out_path = Path(str(in_path)[:-len(UNASYNC_SUFFIX)])
        if out_path.is_file() and files_equal(in_path, out_path):
            # nothing changed: discard the freshly generated copy
            in_path.unlink()
        else:
            changed_paths[in_path] = out_path

    # Replace the targets only after all comparisons are done.
    for in_path, out_path in changed_paths.items():
        in_path.replace(out_path)

    return list(changed_paths.values())


def main():
    """Regenerate the sync code and report every file that was updated.

    Command line arguments, if any, are passed to ``apply_unasync`` as the
    files to convert.  For each created or overwritten sync file a line
    ``Updated <path>`` is printed, and the process then exits with status 1.
    If nothing changed, the function returns normally.
    """
    # ``sys.argv[1:]`` is simply empty when no arguments were given; the
    # former ``len(sys.argv) >= 1`` guard was always true.
    files = sys.argv[1:]
    paths = apply_unasync(files)
    paths = apply_isort(paths)
    changed_paths = apply_changes(paths)

    if changed_paths:
        for path in changed_paths:
            print("Updated " + str(path))
        # sys.exit rather than the builtin ``exit``: the latter is injected
        # by the ``site`` module and is not guaranteed to exist.
        sys.exit(1)


# Run the generator only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
