#!/usr/bin/env python3
#
# This file is licensed under LGPL v2.1 to be compatible
# with the parts used from WineVulkan in vulkan_helpers.py.
# See the top of that file for Wine's license information.
#
#   Copyright (C) 2022 Joshua Ashton
#   
#   This library is free software; you can redistribute it and/or
#   modify it under the terms of the GNU Lesser General Public
#   License as published by the Free Software Foundation; either
#   version 2.1 of the License, or (at your option) any later version.
#   
#   This library is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#   Lesser General Public License for more details.
#   
#   You should have received a copy of the GNU Lesser General Public
#   License along with this library; if not, write to the Free Software
#   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
#

import argparse
import os
import urllib.request

# Vulkan-Docs release tag used to build the vk.xml download URL and the
# default local file name ("vk-<version>.xml") when no --xml path is given.
VK_XML_VERSION = "1.3.253"

from vulkan_helpers import *

def remove_vk_prefix(name):
    """Return *name* without a leading "vk"; names lacking it are returned as-is."""
    prefix = "vk"
    return name[len(prefix):] if name.startswith(prefix) else name

def is_proc_addr_func(name):
    """Tell whether *name* is one of the three Vulkan Get*ProcAddr entry points."""
    proc_addr_names = (
        "vkGetInstanceProcAddr",
        "vkGetPhysicalDeviceProcAddr",
        "vkGetDeviceProcAddr",
    )
    return name in proc_addr_names

def write_include(out, filename):
    """Copy ``inc/<filename>`` into *out*, ensuring the copy ends with a newline.

    Args:
        out: Writable text stream receiving the contents.
        filename: Name of the file inside the ``inc/`` directory, resolved
            against the current working directory.

    Raises:
        OSError: If the include file cannot be opened or read.
    """
    with open("inc/" + filename, "r") as f:
        contents = f.read()
    out.write(contents)
    # An empty include has no last character to inspect (indexing [-1] would
    # raise IndexError); it contributes nothing, so no newline is added for it.
    if contents and not contents.endswith('\n'):
        out.write('\n')

class VkGenerator:
    """Writes the generated C++ sections of vkroots.h from a parsed registry.

    The registry is only accessed through these attributes: ``extensions``
    (iterable of mappings with "name" and "platform" keys), ``platforms``
    (mapping of platform name to a mapping with a "protect" key), ``funcs``
    and ``enums`` (mappings whose values are iterated) and ``structs``
    (iterable).
    """

    def __init__(self, registry):
        self.registry = registry

    def find_ext_info(self, lookup):
        """Return the first extension entry whose "name" equals *lookup*, or None."""
        return next(
            (ext for ext in self.registry.extensions if ext["name"] == lookup),
            None,
        )

    def get_object_platform(self, func):
        """Work out which platform, if any, a registry object is tied to.

        Args:
            func: Any registry object exposing an ``extensions`` iterable of
                extension names (functions, enums and structs are all passed
                here by the writers below).

        Returns:
            None when the object has no extensions or when an extension with
            no platform is reached; "unsupported" when an extension name is
            not present in the registry; otherwise the platform of the last
            extension visited.
        """
        platform = None
        for ext in func.extensions:
            ext_info = self.find_ext_info(ext)
            if ext_info is None:
                return "unsupported"
            platform = ext_info["platform"]
            if platform is None:
                return None
        return platform

    def get_object_platform_define(self, func):
        """Return the preprocessor symbol guarding *func*, or None if unguarded.

        Platforms missing from ``registry.platforms`` (including the
        "unsupported" marker) map to ``VKROOTS_PLATFORM_UNSUPPORTED``.
        """
        platform = self.get_object_platform(func)
        if platform is None:
            return None
        if platform in self.registry.platforms:
            return self.registry.platforms[platform]["protect"]
        return "VKROOTS_PLATFORM_UNSUPPORTED"

    def print_object_platform_ifdef(self, f, func):
        """Write ``#ifdef <define>`` to *f* when *func* needs a platform guard."""
        plat_define = self.get_object_platform_define(func)
        if plat_define is not None:
            f.write(f"#ifdef {plat_define}\n")

    def print_object_platform_endif(self, f, func):
        """Write the ``#endif`` matching print_object_platform_ifdef, if any."""
        plat_define = self.get_object_platform_define(func)
        if plat_define is not None:
            f.write("#endif\n")

    def _required_funcs(self, dispatch_type):
        """Yield required registry functions whose func type equals *dispatch_type*."""
        for func in self.registry.funcs.values():
            if func.is_required() and func.get_func_type() == dispatch_type:
                yield func

    def write_dispatch_class(self, f, dispatch_type, dispatch_name, procaddr_type, procaddr_name):
        """Write the ``Vk<dispatch_type>Dispatch`` C++ class to *f*.

        Args:
            f: Writable text stream.
            dispatch_type: "Instance", "PhysicalDevice" or "Device".
            dispatch_name: Unused; kept so all writers share one signature.
            procaddr_type: Handle type suffix of the constructor's handle
                parameter (emitted as ``Vk<procaddr_type>``).
            procaddr_name: Name of that handle parameter in the emitted code.
        """
        procaddr_normalized = f"Get{dispatch_type}ProcAddr"
        procaddr = f"vk{procaddr_normalized}"

        f.write(f"  class Vk{dispatch_type}Dispatch {{\n")
        f.write("  public:\n")

        # Extra constructor parameters beyond the proc-addr function and handle.
        additional_args = ""
        if dispatch_type == "PhysicalDevice":
            additional_args += ", const VkInstanceDispatch* pInstanceDispatch"
        if dispatch_type == "Device":
            additional_args += ", VkPhysicalDevice PhysicalDevice, const VkPhysicalDeviceDispatch* pPhysicalDeviceDispatch, const VkDeviceCreateInfo* pCreateInfo"
        f.write(f"    Vk{dispatch_type}Dispatch(PFN_{procaddr} Next{procaddr_normalized}, Vk{procaddr_type} {procaddr_name}{additional_args}) {{\n")

        # Constructor body: fixed member initialisation per dispatch type.
        if dispatch_type == "Instance":
            f.write("      this->Instance = instance;\n")
        if dispatch_type == "PhysicalDevice":
            f.write("      this->Instance = instance;\n")
            f.write("      this->pInstanceDispatch = pInstanceDispatch;\n")
            f.write("      this->GetPhysicalDeviceProcAddr = NextGetPhysicalDeviceProcAddr;\n")
        if dispatch_type == "Device":
            f.write("      this->PhysicalDevice = PhysicalDevice;\n")
            f.write("      this->Device = device;\n")
            f.write("      this->pPhysicalDeviceDispatch = pPhysicalDeviceDispatch;\n")
            f.write("      for (uint32_t i = 0; i < pCreateInfo->queueCreateInfoCount; i++) {\n")
            f.write("        VkDeviceQueueCreateInfo queueInfo = pCreateInfo->pQueueCreateInfos[i];\n")
            f.write("        queueInfo.pNext = nullptr;\n")
            f.write("        DeviceQueueInfos.push_back(queueInfo);\n")
            f.write("      }\n")

        # Constructor body: resolve every function pointer through the next layer.
        for func in self._required_funcs(dispatch_type):
            func_name_normalized = remove_vk_prefix(func.name)
            self.print_object_platform_ifdef(f, func)
            if func.name in ("vkGetInstanceProcAddr", "vkGetDeviceProcAddr"):
                # The proc-addr getter is the constructor argument itself.
                f.write(f"      {func_name_normalized} = Next{procaddr_normalized};\n")
            elif func.name in ("vkDestroyInstance", "vkDestroyDevice"):
                # The real destroy pointer is kept privately; the public slot
                # points at the wrapper emitted in the private section below.
                f.write(f"      {func_name_normalized}Real = (PFN_{func.name}) Next{procaddr_normalized}({procaddr_name}, \"{func.name}\");\n")
                f.write(f"      {func_name_normalized} = (PFN_{func.name}) {func_name_normalized}Wrapper;\n")
            else:
                f.write(f"      {func_name_normalized} = (PFN_{func.name}) Next{procaddr_normalized}({procaddr_name}, \"{func.name}\");\n")
            self.print_object_platform_endif(f, func)
        f.write("    }\n\n")

        # Public data members.
        f.write("    mutable uint64_t UserData = 0;\n")
        if dispatch_type == "Instance":
            f.write("    VkInstance Instance;\n")
        if dispatch_type == "PhysicalDevice":
            f.write("    VkInstance Instance;\n")
            f.write("    const VkInstanceDispatch* pInstanceDispatch;\n")
            f.write("    PFN_GetPhysicalDeviceProcAddr GetPhysicalDeviceProcAddr;\n")
        if dispatch_type == "Device":
            f.write("    VkDevice Device;\n")
            f.write("    VkPhysicalDevice PhysicalDevice;\n")
            f.write("    const VkPhysicalDeviceDispatch* pPhysicalDeviceDispatch;\n")
            f.write("    std::vector<VkDeviceQueueCreateInfo> DeviceQueueInfos;\n") # mutable hack TODO: remove
        for func in self._required_funcs(dispatch_type):
            func_name_normalized = remove_vk_prefix(func.name)
            self.print_object_platform_ifdef(f, func)
            f.write(f"    PFN_{func.name} {func_name_normalized};\n")
            self.print_object_platform_endif(f, func)

        # Private section: destroy wrapper that looks up the table, drops it,
        # then calls the real destroy function.
        f.write("  private:\n")
        if dispatch_type in ("Instance", "Device"):
            f.write(f"    PFN_vkDestroy{dispatch_type} Destroy{dispatch_type}Real;\n")
            f.write(f"    static void Destroy{dispatch_type}Wrapper(Vk{dispatch_type} object, const VkAllocationCallbacks* pAllocator) {{\n")
            f.write(f"      auto dispatch = vkroots::tables::Lookup{dispatch_type}Dispatch(object);\n")
            f.write(f"      auto destroyFunc = dispatch->Destroy{dispatch_type}Real;\n")
            f.write("      vkroots::tables::DestroyDispatchTable(object);\n")
            f.write("      destroyFunc(object, pAllocator);\n")
            f.write("    }\n")
        f.write("  };\n\n")

        if dispatch_type == "PhysicalDevice":
            f.write("  namespace tables {\n")
            f.write("    static inline const VkInstanceDispatch* LookupInstanceDispatch(VkPhysicalDevice physicalDevice) { return PhysicalDeviceDispatches.find(physicalDevice)->pInstanceDispatch; }\n")
            f.write("  }\n")

    def write_dispatch_funcs(self, f, dispatch_type, dispatch_name, procaddr_type, procaddr_name):
        """Write one ``wrap_<Func>`` template per required function, then the
        hand-written implicit wrapper includes for *dispatch_type*.

        Args:
            f: Writable text stream.
            dispatch_type: "Instance", "PhysicalDevice" or "Device".
            dispatch_name, procaddr_type, procaddr_name: Unused; kept so all
                writers share one signature.
        """
        for func in self._required_funcs(dispatch_type):
            if func.name == "vkGetDeviceProcAddr":
                continue
            self.print_object_platform_ifdef(f, func)
            func_name_normalized = remove_vk_prefix(func.name)
            params = ", ".join(p.definition() for p in func.params)
            args = ", ".join(p.name for p in func.params)
            # Non-void functions capture the override's result in `ret`.
            return_v = f"{func.type} ret = " if func.type != "void" else ""
            f.write("  template <typename InstanceOverrides, typename PhysicalDeviceOverrides, typename DeviceOverrides>\n")
            f.write(f"  static {func.type} wrap_{func_name_normalized}({params}) {{\n")
            if func.name == "vkCreateInstance":
                # No dispatch table exists yet: the override receives the next
                # layer's vkCreateInstance, and the table is created on success.
                f.write("    VkInstanceProcAddrFuncs instanceProcAddrFuncs;\n")
                f.write("    VkResult procAddrRes = GetProcAddrs(pCreateInfo, &instanceProcAddrFuncs);\n")
                f.write("    if (procAddrRes != VK_SUCCESS)\n")
                f.write("      return procAddrRes;\n")
                f.write("    PFN_vkCreateInstance createInstanceProc = (PFN_vkCreateInstance) instanceProcAddrFuncs.NextGetInstanceProcAddr(NULL, \"vkCreateInstance\");\n")
                f.write(f"    {return_v}{dispatch_type}Overrides::{func_name_normalized}(createInstanceProc, {args});\n")
                f.write("    if (ret == VK_SUCCESS)\n")
                f.write("      tables::CreateDispatchTable(instanceProcAddrFuncs.NextGetInstanceProcAddr, instanceProcAddrFuncs.NextGetPhysicalDeviceProcAddr, *pInstance);\n")
            elif func.name == "vkCreateDevice":
                f.write(f"    const Vk{dispatch_type}Dispatch* dispatch = tables::Lookup{dispatch_type}Dispatch({func.params[0].name});\n")
                f.write("    PFN_vkGetDeviceProcAddr deviceProcAddr;\n")
                f.write("    VkResult procAddrRes = GetProcAddrs(pCreateInfo, &deviceProcAddr);\n")
                f.write("    if (procAddrRes != VK_SUCCESS)\n")
                f.write("      return procAddrRes;\n")
                f.write(f"    {return_v}{dispatch_type}Overrides::{func_name_normalized}(dispatch, {args});\n")
                f.write("    if (ret == VK_SUCCESS)\n")
                f.write("      tables::CreateDispatchTable(pCreateInfo, deviceProcAddr, physicalDevice, *pDevice);\n")
            else:
                # The dispatch table is looked up from the first parameter.
                f.write(f"    const Vk{dispatch_type}Dispatch* dispatch = tables::Lookup{dispatch_type}Dispatch({func.params[0].name});\n")
                f.write(f"    {return_v}{dispatch_type}Overrides::{func_name_normalized}(dispatch, {args});\n")
            if func.type != "void":
                f.write("    return ret;\n")
            f.write("  }\n\n")
            self.print_object_platform_endif(f, func)

        implicit_includes = {
            "Instance": ("vkroots_implicit_createinstance.h", "vkroots_implicit_destroyinstance.h"),
            "PhysicalDevice": ("vkroots_implicit_createdevice.h",),
            "Device": ("vkroots_implicit_destroydevice.h",),
        }
        for header in implicit_includes.get(dispatch_type, ()):
            write_include(f, header)

    def write_dispatch_impls(self, f, dispatch_type, dispatch_name, procaddr_type, procaddr_name):
        """Write the ``Get<dispatch_type>ProcAddr`` template that maps a
        function name to the matching wrapper, falling back to the dispatch
        table's own proc-addr getter.

        Args:
            f: Writable text stream.
            dispatch_type: "Instance", "PhysicalDevice" or "Device".
            dispatch_name: Unused; kept so all writers share one signature.
            procaddr_type: Handle type suffix of the emitted function's first
                parameter (emitted as ``Vk<procaddr_type>``).
            procaddr_name: Name of that parameter in the emitted code.
        """
        procaddr_normalized = f"Get{dispatch_type}ProcAddr"

        # Functions that still get hooked when the user supplies no override:
        # name -> (implicit wrapper, extra condition placed before the strcmp).
        implicit_wrappers = {
            "vkCreateInstance": ("implicit_wrap_CreateInstance", ""),
            # If we don't have any overrides, don't hook this to make a dispatch table.
            "vkCreateDevice": ("implicit_wrap_CreateDevice", "!std::is_base_of<NoOverrides, DeviceOverrides>::value && "),
            "vkDestroyInstance": ("implicit_wrap_DestroyInstance", ""),
            "vkDestroyDevice": ("implicit_wrap_DestroyDevice", ""),
        }

        f.write("  template <typename InstanceOverrides, typename PhysicalDeviceOverrides, typename DeviceOverrides>\n")
        f.write(f"  static PFN_vkVoidFunction Get{dispatch_type}ProcAddr(Vk{procaddr_type} {procaddr_name}, const char* name) {{\n")
        f.write(f"    const Vk{dispatch_type}Dispatch* dispatch = tables::Lookup{dispatch_type}Dispatch({procaddr_name});\n")
        for func in self._required_funcs(dispatch_type):
            self.print_object_platform_ifdef(f, func)
            func_name_normalized = remove_vk_prefix(func.name)
            if is_proc_addr_func(func.name):
                f.write(f"    if (!std::strcmp(\"{func.name}\", name))\n")
                f.write(f"      return (PFN_vkVoidFunction) &{func_name_normalized}<InstanceOverrides, PhysicalDeviceOverrides, DeviceOverrides>;\n")
            else:
                f.write(f"    constexpr bool Has{func_name_normalized} = requires(const {dispatch_type}Overrides& t) {{ &{dispatch_type}Overrides::{func_name_normalized}; }};\n")
                f.write(f"    if constexpr (Has{func_name_normalized}) {{\n")
                # VS is smart enough to make stateless lambdas with the right calling conventions.
                # if you simply just cast them to the right function pointer type!
                # Versions which are unused are eliminated by the linker.
                # TODO: Is this enough for MinGW?
                f.write(f"      if (!std::strcmp(\"{func.name}\", name))\n")
                f.write(f"        return (PFN_vkVoidFunction) &wrap_{func_name_normalized}<InstanceOverrides, PhysicalDeviceOverrides, DeviceOverrides>;\n")
                f.write("    }\n")
                implicit = implicit_wrappers.get(func.name)
                if implicit is not None:
                    wrapper, extra_condition = implicit
                    f.write("    else {\n")
                    f.write(f"      if ({extra_condition}!std::strcmp(\"{func.name}\", name))\n")
                    f.write(f"        return (PFN_vkVoidFunction) &{wrapper}<InstanceOverrides, PhysicalDeviceOverrides, DeviceOverrides>;\n")
                    f.write("    }\n")
            self.print_object_platform_endif(f, func)
            f.write("\n")

        # Also answer by-name lookups of "vk_layerGetPhysicalDeviceProcAddr".
        # Per the original author's note, this entry point is already handed
        # over in vkNegotiateLoaderLayerInterfaceVersion, so it is unclear why
        # a by-name lookup is needed as well.
        if dispatch_type == "PhysicalDevice":
            f.write("    if constexpr (!std::is_base_of<NoOverrides, PhysicalDeviceOverrides>::value || !std::is_base_of<NoOverrides, DeviceOverrides>::value) {\n")
            f.write("      if (!std::strcmp(\"vk_layerGetPhysicalDeviceProcAddr\", name))\n")
            f.write("        return (PFN_vkVoidFunction) &GetPhysicalDeviceProcAddr<InstanceOverrides, PhysicalDeviceOverrides, DeviceOverrides>;\n")
            f.write("    }\n")
            f.write("\n")
        f.write("    if (dispatch)\n")
        f.write(f"      return dispatch->{procaddr_normalized}({procaddr_name}, name);\n")
        f.write("    else\n")
        f.write("      return NULL;\n")
        f.write("  }\n\n")

    def write_dispatch_classes(self, f):
        """Write dispatch classes, then wrappers, then proc-addr functions,
        each for Instance, PhysicalDevice and Device in that order."""
        targets = (
            ("Instance", "instance", "Instance", "instance"),
            ("PhysicalDevice", "physicalDevice", "Instance", "instance"),
            ("Device", "device", "Device", "device"),
        )
        writers = (
            self.write_dispatch_class,
            self.write_dispatch_funcs,
            self.write_dispatch_impls,
        )
        for writer in writers:
            for target in targets:
                writer(f, *target)

    def write_enum_string_helpers(self, f):
        """Write the ``helpers::enumString`` template and one specialisation
        per required, non-alias, 32-bit enum that has more than one value."""
        f.write("  namespace helpers {\n")
        f.write("    template <typename EnumType>\n")
        f.write("    constexpr const char* enumString(EnumType type);\n")
        for enum_def in self.registry.enums.values():
            if not enum_def.required:
                continue
            # A single value is presumably just the *_MAX_ENUM sentinel, so
            # such enums get no specialisation.
            if enum_def.is_alias() or enum_def.bitwidth != 32 or len(enum_def.values) <= 1:
                continue
            f.write("\n")
            self.print_object_platform_ifdef(f, enum_def)
            f.write(f"    template <> constexpr const char* enumString<{enum_def.name}>({enum_def.name} type) {{\n")
            f.write("      switch(static_cast<uint64_t>(type)) {\n")
            for enum_value in enum_def.values:
                if not enum_value.is_alias():
                    f.write(f"        case static_cast<uint64_t>({enum_value.value}): return \"{enum_value.name}\";\n")
            f.write(f"        default: return \"{enum_def.name}_UNKNOWN\";\n")
            f.write("      }\n")
            f.write("    }\n")
            self.print_object_platform_endif(f, enum_def)
        f.write("  }\n")

    def write_stype_helpers(self, f):
        """Write the ``ResolveSType`` template and, for every required
        non-alias struct whose ``sType`` member has a value, specialisations
        for both the struct and its const-qualified form."""
        f.write("  template <typename Type>\n")
        f.write("  constexpr VkStructureType ResolveSType();\n")
        for struct_def in self.registry.structs:
            if not struct_def.required or struct_def.alias:
                continue

            for member in struct_def.members:
                if member.name == 'sType' and member.values is not None and member.values != "":
                    f.write("\n")
                    self.print_object_platform_ifdef(f, struct_def)
                    f.write(f"  template <> constexpr VkStructureType ResolveSType<{struct_def.name}>() {{ return {member.values}; }}\n")
                    f.write(f"  template <> constexpr VkStructureType ResolveSType<const {struct_def.name}>() {{ return {member.values}; }}\n")
                    self.print_object_platform_endif(f, struct_def)

def download_vk_xml(filename):
    """Download the vk.xml for VK_XML_VERSION to *filename* unless it exists.

    The download goes to a temporary ``<filename>.part`` file that is renamed
    into place only after the transfer completes, so an interrupted download
    cannot leave a truncated file that later runs would treat as a valid
    cached copy.

    Args:
        filename: Destination path of the registry XML.

    Raises:
        urllib.error.URLError: If the download fails.
        OSError: If the destination cannot be written.
    """
    if os.path.isfile(filename):
        return

    url = "https://raw.githubusercontent.com/KhronosGroup/Vulkan-Docs/v{0}/xml/vk.xml".format(VK_XML_VERSION)
    partial = os.fspath(filename) + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, filename)
    finally:
        # After a successful rename nothing is left; after a failure this
        # discards the incomplete download.
        if os.path.exists(partial):
            os.remove(partial)

def set_working_directory():
    """Make the directory containing this script the current working directory."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)

def main():
    """Parse the command line and regenerate ``../vkroots.h``.

    The working directory is switched to this script's directory first, so
    the ``inc/`` includes, the output path and the XML path are all resolved
    relative to it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", action="count", default=0, help="increase output verbosity")
    parser.add_argument("-x", "--xml", default=None, type=str, help="path to specification XML file")
    args = parser.parse_args()

    # No -v: warnings only; one -v: info; two or more: debug.
    verbosity_levels = {0: logging.WARNING, 1: logging.INFO}
    LOGGER.setLevel(verbosity_levels.get(args.verbose, logging.DEBUG))

    set_working_directory()

    vk_xml = args.xml
    if not vk_xml:
        # No registry given: use (and download if missing) the pinned version.
        vk_xml = "vk-{0}.xml".format(VK_XML_VERSION)
        download_vk_xml(vk_xml)

    generator = VkGenerator(VkRegistry(vk_xml))

    leading_includes = ("vkroots_includes.h", "vkroots_forwarders.h")
    trailing_includes = (
        "vkroots_dispatches.h",
        "vkroots_helpers.h",
        # Implementations
        "vkroots_loader_layer_interface.h",
    )

    with open("../vkroots.h", "w") as header:
        for include_name in leading_includes:
            write_include(header, include_name)

        header.write("namespace vkroots {\n")
        generator.write_dispatch_classes(header)
        generator.write_enum_string_helpers(header)
        generator.write_stype_helpers(header)
        header.write("}\n")
        header.write("\n")

        for include_name in trailing_includes:
            write_include(header, include_name)

# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()