﻿/****************************************************************/
/*

    ██    ██ ██    ██ ██      ██   ██  █████  ███    ██
    ██    ██ ██    ██ ██      ██  ██  ██   ██ ████   ██
    ██    ██ ██    ██ ██      █████   ███████ ██ ██  ██
     ██  ██  ██    ██ ██      ██  ██  ██   ██ ██  ██ ██
      ████    ██████  ███████ ██   ██ ██   ██ ██   ████

    Vulkan Programming (by Kenwright)

    
    https://xbdev.net/vulkan/
	https://vulkanlab.xbdev.net/
*/
/****************************************************************/

// Windows version of the Vulkan API
//#define VK_USE_PLATFORM_WIN32_KHR


#define GLFW_INCLUDE_VULKAN
#include <GLFW/glfw3.h>
#pragma comment(lib, "glfw3.lib")
#include <iostream>

//#define VK_ENABLE_BETA_EXTENSIONS

#define VK_NO_PROTOTYPES
#include <vulkan/vulkan.h>

#pragma comment(lib, "vulkan-1.lib") 

// Do we want to enable the additional Vulkan debug output?
#define ENABLE_VULKAN_DEBUG_CALLBACK

//--------------------------------------------------------------//
//--------------------------------------------------------------//

// For variable argument functions, e.g., dprintf(..)
#include <stdlib.h>
#include <stdarg.h>
#include <stdio.h>  // vsprintf_s

// Formats a printf-style message and writes it to stdout (debug log).
//
//   fmt : printf-style format string (a null pointer is ignored)
//   ... : arguments consumed by the conversion specifiers in fmt
//
// The text is built in a fixed 2048 byte static buffer - longer messages
// are truncated, and the shared buffer means the function is not safe to
// call from several threads at the same time.
inline 
void dprintf(const char *fmt, ...) 
{
	static char buf[2048] = {0};

	if (fmt == nullptr)
	{
		return;
	}

	// Print into the fixed buffer - vsnprintf never writes past the end
	// of the buffer and always null-terminates the result
	va_list parms;
	va_start(parms, fmt);
	vsnprintf(buf, sizeof(buf), fmt, parms);
	va_end(parms);

	// Write the information out to a txt file
	#if 0
	FILE *fp = fopen("output.txt", "a+");
	if (fp)
	{
		fprintf(fp, "%s",  buf);
		fclose(fp);
	}
	#endif

	// Output to the visual studio window
	//OutputDebugStringA( buf );

	// The formatted text is data, not a format string - any '%' left in
	// it has to be printed as-is
	fputs( buf, stdout );

}// End dprintf(..)


// Debug defines (custom asserts)
// DBG_ASSERT(f) halts the program when the condition f is false:
//   - Windows         : breaks into the debugger (__debugbreak)
//   - other platforms : terminates through abort() (from <stdlib.h>)
// The do { } while(0) wrapper makes the macro behave like one statement,
// so 'if (x) DBG_ASSERT(y); else ...' compiles as expected.
#if defined(_WIN32)
  #define DBG_ASSERT(f) do { if(!(f)){ __debugbreak(); } } while(0)
#else
  #define DBG_ASSERT(f) do { if(!(f)){ abort(); } } while(0)
#endif

//--------------------------------------------------------------//
//--------------------------------------------------------------//

    
#ifdef ENABLE_VULKAN_DEBUG_CALLBACK // Debug callback
// Set this function as a debug callback when we initialize
// Vulkan to let us know if something went wrong.
//
// Only pLayerPrefix and pMessage are used, the remaining arguments are
// ignored. The report is logged as "<prefix> <message>" followed by a
// newline, then DBG_ASSERT(false) is hit so the problem is noticed
// immediately.
//
// Always returns VK_FALSE.
VKAPI_ATTR VkBool32 VKAPI_CALL 
  MyDebugReportCallback( VkDebugReportFlagsEXT flags, 
               VkDebugReportObjectTypeEXT objectType, 
               uint64_t object,
               size_t location, 
               int32_t messageCode, 
               const char* pLayerPrefix, 
               const char* pMessage, 
               void* pUserData ) 
{
    // The strings are supplied by the layers and may contain '%'
    // characters, so they are handed over as arguments instead of being
    // used as the format string (null pointers are logged as empty text)
    dprintf( "%s %s\n", 
             pLayerPrefix ? pLayerPrefix : "", 
             pMessage     ? pMessage     : "" );
    DBG_ASSERT(false);
    
    return VK_FALSE;
}// End MyDebugReportCallback(..)
#endif


//--------------------------------------------------------------//
//--------------------------------------------------------------//

// Function pointers specific for the **ray tracing** - they are grabbed from the device with
// vkGetDeviceProcAddr(..) in createVulkanInstance(..) once the logical device exists.
// (must remember to 'enable' the matching extensions when setting up the device - otherwise
// the lookups will fail)
// NOTE(review): the trailing 'A' presumably keeps these names apart from the identically
// named functions declared by the Vulkan header - confirm
PFN_vkGetDeviceQueue                             vkGetDeviceQueueA;
PFN_vkGetBufferDeviceAddressKHR                  vkGetBufferDeviceAddressKHRA;
PFN_vkCreateAccelerationStructureKHR             vkCreateAccelerationStructureKHRA;
PFN_vkDestroyAccelerationStructureKHR            vkDestroyAccelerationStructureKHRA;
PFN_vkGetAccelerationStructureBuildSizesKHR      vkGetAccelerationStructureBuildSizesKHRA;
PFN_vkGetAccelerationStructureDeviceAddressKHR   vkGetAccelerationStructureDeviceAddressKHRA;
PFN_vkCmdBuildAccelerationStructuresKHR          vkCmdBuildAccelerationStructuresKHRA;
PFN_vkBuildAccelerationStructuresKHR             vkBuildAccelerationStructuresKHRA;
PFN_vkCmdTraceRaysKHR                            vkCmdTraceRaysKHRA;
PFN_vkGetRayTracingShaderGroupHandlesKHR         vkGetRayTracingShaderGroupHandlesKHRA;
PFN_vkCreateRayTracingPipelinesKHR               vkCreateRayTracingPipelinesKHRA;

//--------------------------------------------------------------//
//--------------------------------------------------------------//


#include <array>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <set>
#include <stdexcept>
#include <vector>

// Define a type for a vertex (position, normal, color).
// createBottomLevelAS(..) passes sizeof(Vertex) as the vertex stride and
// reads the positions as VK_FORMAT_R32G32B32_SFLOAT from the start of the
// vertex buffer, so 'position' has to stay the first member.
struct Vertex {
    float position[3]; // x, y, z
    float normal[3];   // nx, ny, nz
    float color[3];    // r, g, b
};

// Vulkan Structures for Ray Tracing

// Core Vulkan objects shared by the whole sample; every member is filled
// in by createVulkanInstance(..).
struct VulkanInstance {
    VkInstance instance = VK_NULL_HANDLE;
    VkPhysicalDevice physicalDevice = VK_NULL_HANDLE; // first GPU reported by the instance
    VkDevice device = VK_NULL_HANDLE;
    VkQueue graphicsQueue = VK_NULL_HANDLE;           // queue 0 of queue family 0
    VkQueue computeQueue = VK_NULL_HANDLE;            // retrieved from the same family/index as graphicsQueue
    VkSurfaceKHR surface = VK_NULL_HANDLE;            // window surface created through GLFW

    // Structures (filled by the ray-tracing methods)
    VkPhysicalDeviceRayTracingPipelinePropertiesKHR  rayTracingPipelineProperties{};
    VkPhysicalDeviceAccelerationStructureFeaturesKHR accelerationStructureFeatures{};
};

// Swapchain handle together with the properties chosen for it.
// 'swapchain', 'imageFormat', 'extent' and 'images' are written by
// createSwapchain(..).
// Every member has a defined default value, so a default constructed
// object never holds indeterminate data.
struct VulkanSwapchain {
    VkSwapchainKHR swapchain = VK_NULL_HANDLE;
    VkFormat imageFormat = VK_FORMAT_UNDEFINED; // format of the swapchain images
    VkExtent2D extent{};                        // size of the swapchain images in pixels
    std::vector<VkImage> images;                // images owned by the swapchain
    std::vector<VkImageView> imageViews;
};

// One acceleration structure (used for both the bottom level and the top
// level structure, see the globals 'blas' and 'tlas').
struct VulkanAccelerationStructure {
    uint64_t deviceAddress = 0;                         // address returned by vkGetAccelerationStructureDeviceAddressKHR
    VkAccelerationStructureKHR handle = VK_NULL_HANDLE;
    VkBuffer buffer = VK_NULL_HANDLE;                   // NOTE(review): presumably the buffer backing the structure - confirm
    VkDeviceMemory memory = VK_NULL_HANDLE;             // NOTE(review): presumably the memory bound to 'buffer' - confirm
};

// Storage image plus the descriptor objects of the sample.
// NOTE(review): the storage image is presumably the target the rays are
// traced into - confirm against the code that fills this structure.
// Every handle starts out as VK_NULL_HANDLE, so a default constructed
// object never holds indeterminate handles.
struct VulkanDescriptors {
    VkImage storageImage = VK_NULL_HANDLE;
    VkImageView storageImageView = VK_NULL_HANDLE;
    VkDeviceMemory storageImageMemory = VK_NULL_HANDLE;

    VkDescriptorSetLayout descriptorSetLayout = VK_NULL_HANDLE;
    VkDescriptorPool descriptorPool = VK_NULL_HANDLE;
    VkDescriptorSet descriptorSet = VK_NULL_HANDLE;
};

// Ray tracing pipeline object and the pipeline layout it was created with.
struct VulkanRayTracingPipeline {
    VkPipelineLayout layout = VK_NULL_HANDLE;
    VkPipeline pipeline = VK_NULL_HANDLE;
};

// Shader binding table data: the shader group descriptions, the buffer
// (and its memory) and one strided address region per shader stage type
// (ray generation, miss, hit).
struct VulkanShaderBindingTable {
    std::vector<VkRayTracingShaderGroupCreateInfoKHR>   shaderGroups{};

    VkBuffer buffer = VK_NULL_HANDLE;
    VkDeviceMemory memory = VK_NULL_HANDLE;
    VkStridedDeviceAddressRegionKHR raygenRegion{};
    VkStridedDeviceAddressRegionKHR missRegion{};
    VkStridedDeviceAddressRegionKHR hitRegion{};
};

// Command pool and the command buffers of the sample.
// createSwapchain(..) creates the pool and sizes 'commandBuffers' to one
// entry per swapchain image.
struct VulkanCommandBuffers {
    VkCommandPool commandPool = VK_NULL_HANDLE;
    std::vector<VkCommandBuffer> commandBuffers;
};

// Uniform data consisting of a single float.
// NOTE(review): presumably uploaded for the shaders as an animation time - units not visible here, confirm
struct UniformBufferObject {
    float time;
};

// Global state of the sample - one object per structure defined above.
// The functions in this file read and write these objects directly.
VulkanInstance vkInst{};                // filled by createVulkanInstance(..)
VulkanSwapchain swapchain{};            // filled by createSwapchain(..)
VulkanAccelerationStructure blas{};     // bottom level structure, see createBottomLevelAS(..)
VulkanAccelerationStructure tlas{};     // top level structure
VulkanDescriptors descriptors{};
VulkanRayTracingPipeline rtPipeline{};
VulkanShaderBindingTable sbt{};
VulkanCommandBuffers commandBuffers{};  // pool is created in createSwapchain(..)

//--------------------------------------------------------------//
//--------------------------------------------------------------//


// Finds a queue family that supports ALL of the requested capabilities.
//
//   device        : physical device whose queue families are inspected
//   requiredFlags : combination of VkQueueFlagBits that must all be present
//
// Returns the index of the first matching queue family, or -1 when no
// family provides every requested flag.
int findQueueFamilyIndex(VkPhysicalDevice device, VkQueueFlags requiredFlags) {
    uint32_t queueFamilyCount = 0;
    vkGetPhysicalDeviceQueueFamilyProperties(device, &queueFamilyCount, nullptr);
    std::vector<VkQueueFamilyProperties> queueFamilies(queueFamilyCount);
    vkGetPhysicalDeviceQueueFamilyProperties(device, &queueFamilyCount, queueFamilies.data());

    for (uint32_t i = 0; i < queueFamilyCount; ++i) {
        // Compare against the full mask - a plain '&' test would already
        // accept a family that offers only one of several requested flags
        if ((queueFamilies[i].queueFlags & requiredFlags) == requiredFlags) {
            return static_cast<int>(i);  // Found a queue family that supports required flags
        }
    }

    return -1;  // No suitable queue family found
}

//--------------------------------------------------------------//
//--------------------------------------------------------------//

// Creates the global Vulkan state stored in 'vkInst':
//   1. instance (validation layer + debug report extension)
//   2. window surface (through GLFW)
//   3. physical device (the first GPU that is reported)
//   4./5. logical device with the ray tracing feature chain and extensions
//   6. queues, ray tracing properties and the dynamically loaded
//      ray tracing / acceleration structure functions
//
//   window : GLFW window the presentation surface is created for
//
// Does nothing when vkInst.instance has already been created.
// Throws std::runtime_error when a Vulkan object cannot be created, no GPU
// is found or one of the required device functions cannot be loaded.
void createVulkanInstance(GLFWwindow* window) {
    if (vkInst.instance != VK_NULL_HANDLE) return;

    // Step 1: Create Vulkan Instance
    VkApplicationInfo appInfo{};
    appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
    appInfo.pApplicationName = "Vulkan Ray Tracing Example";
    appInfo.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
    appInfo.pEngineName = "No Engine";
    appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0);
    appInfo.apiVersion = VK_API_VERSION_1_2;

    VkInstanceCreateInfo createInfo{};
    createInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
    createInfo.pApplicationInfo = &appInfo;

    // Enable instance extensions required by GLFW for surface creation
    //uint32_t glfwExtensionCount = 0;
    //const char** glfwExtensions = glfwGetRequiredInstanceExtensions(&glfwExtensionCount);
    //createInfo.enabledExtensionCount = glfwExtensionCount;
    //createInfo.ppEnabledExtensionNames = glfwExtensions;

    // Surface extensions (the win32 one ties this list to Windows) plus the
    // debug report extension used for the callbacks below
    const char* extensions[] = { "VK_KHR_surface",
                              "VK_KHR_win32_surface",
                              "VK_EXT_debug_report" };
    // Counts are derived from the arrays so they cannot get out of sync
    createInfo.enabledExtensionCount = static_cast<uint32_t>(sizeof(extensions) / sizeof(extensions[0]));
    createInfo.ppEnabledExtensionNames = extensions;


    // Enable validation layers (optional but useful for debugging)
    const char* layers[] = { "VK_LAYER_KHRONOS_validation" };
    const uint32_t layerCount = static_cast<uint32_t>(sizeof(layers) / sizeof(layers[0]));
    createInfo.enabledLayerCount = layerCount;
    createInfo.ppEnabledLayerNames = layers;

    if (vkCreateInstance(&createInfo, nullptr, &vkInst.instance) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create Vulkan instance!");
    }

    // Optional: Setup debug callback if enabled
#ifdef ENABLE_VULKAN_DEBUG_CALLBACK
    {
    // Register our error logging function (defined at the top of the file)
    // The two handles are only kept locally
    VkDebugReportCallbackEXT error_callback   = VK_NULL_HANDLE;
    VkDebugReportCallbackEXT warning_callback = VK_NULL_HANDLE;

    // Extension function - has to be looked up through the instance
    PFN_vkCreateDebugReportCallbackEXT pfnCreateDebugReportCallback =
      reinterpret_cast<PFN_vkCreateDebugReportCallbackEXT>(
        vkGetInstanceProcAddr( vkInst.instance, "vkCreateDebugReportCallbackEXT" ) );
    DBG_ASSERT(pfnCreateDebugReportCallback);

    // Never call through a null pointer - without the entry point the
    // sample simply runs without the debug callbacks
    if (pfnCreateDebugReportCallback)
    {
        VkDebugReportCallbackCreateInfoEXT cb_create_info = {};
        cb_create_info.sType       = VK_STRUCTURE_TYPE_DEBUG_REPORT_CREATE_INFO_EXT;
        cb_create_info.flags       = VK_DEBUG_REPORT_ERROR_BIT_EXT;
        cb_create_info.pfnCallback = &MyDebugReportCallback;

        VkResult result = 
        pfnCreateDebugReportCallback(vkInst.instance, &cb_create_info,
                                     nullptr, &error_callback);
        DBG_ASSERT(result==VK_SUCCESS); // "vkCreateDebugReportCallbackEXT(ERROR) failed"

        // Capture warning as well as errors
        cb_create_info.flags = VK_DEBUG_REPORT_WARNING_BIT_EXT |
                                VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT;
        cb_create_info.pfnCallback = &MyDebugReportCallback;

        result = 
        pfnCreateDebugReportCallback(vkInst.instance, 
                                     &cb_create_info,
                                     nullptr, 
                                     &warning_callback);
        DBG_ASSERT(result==VK_SUCCESS); // "vkCreateDebugReportCallbackEXT(WARN) failed"
    }
    }
#endif

    // Step 2: Create Vulkan Surface
    if (glfwCreateWindowSurface(vkInst.instance, window, nullptr, &vkInst.surface) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create window surface!");
    }

    // Step 3: Select a Physical Device (GPU)
    uint32_t deviceCount = 0;
    vkEnumeratePhysicalDevices(vkInst.instance, &deviceCount, nullptr);
    if (deviceCount == 0) throw std::runtime_error("No Vulkan-compatible GPUs found!");

    std::vector<VkPhysicalDevice> devices(deviceCount);
    const VkResult enumerateResult =
        vkEnumeratePhysicalDevices(vkInst.instance, &deviceCount, devices.data());
    // VK_INCOMPLETE still delivers valid handles, everything else is an error
    if ((enumerateResult != VK_SUCCESS && enumerateResult != VK_INCOMPLETE) || deviceCount == 0) {
        throw std::runtime_error("No Vulkan-compatible GPUs found!");
    }
    vkInst.physicalDevice = devices[0]; // Pick the first GPU (simplified)

    // Step 4: Request the required device features (Ray Tracing, Buffer Device Address, etc.)
    // Ray tracing related extensions
    VkPhysicalDeviceBufferDeviceAddressFeatures
        buffer_device_address_features{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES };
    buffer_device_address_features.bufferDeviceAddress = VK_TRUE;

    VkPhysicalDeviceDynamicRenderingFeatures
        dynamic_rendering_features{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DYNAMIC_RENDERING_FEATURES };
    dynamic_rendering_features.dynamicRendering = VK_TRUE;

    VkPhysicalDeviceSynchronization2Features
        synchronization2_features{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SYNCHRONIZATION_2_FEATURES };
    synchronization2_features.synchronization2 = VK_TRUE;

    VkPhysicalDeviceDescriptorIndexingFeatures
        descriptor_indexing_features{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DESCRIPTOR_INDEXING_FEATURES };
    descriptor_indexing_features.runtimeDescriptorArray = VK_TRUE;

    VkPhysicalDeviceMaintenance4Features
        maintenance4_features{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES };
    maintenance4_features.maintenance4 = VK_TRUE;

    VkPhysicalDeviceAccelerationStructureFeaturesKHR
        acceleration_structure_features{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ACCELERATION_STRUCTURE_FEATURES_KHR };
    acceleration_structure_features.accelerationStructure = VK_TRUE;

    VkPhysicalDeviceRayTracingPipelineFeaturesKHR
        ray_tracing_pipeline_features{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_TRACING_PIPELINE_FEATURES_KHR };
    ray_tracing_pipeline_features.rayTracingPipeline = VK_TRUE;

    VkPhysicalDeviceRobustness2FeaturesEXT
        robustness2_features{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ROBUSTNESS_2_FEATURES_EXT };
    robustness2_features.nullDescriptor = VK_TRUE;

    // Add features to the chain (features2 -> buffer device address -> ... -> robustness2)
    buffer_device_address_features.pNext = &dynamic_rendering_features;
    dynamic_rendering_features.pNext = &synchronization2_features;
    synchronization2_features.pNext = &descriptor_indexing_features;
    descriptor_indexing_features.pNext = &maintenance4_features;
    maintenance4_features.pNext = &acceleration_structure_features;
    acceleration_structure_features.pNext = &ray_tracing_pipeline_features;
    ray_tracing_pipeline_features.pNext = &robustness2_features;

    
    VkPhysicalDeviceFeatures2 features2{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2 };
    features2.pNext = &buffer_device_address_features;

    // Step 5: Create Logical Device with Required Extensions and Features
    VkDeviceQueueCreateInfo queueCreateInfo{};
    queueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
    queueCreateInfo.queueFamilyIndex = 0; // Choose an appropriate queue family (0 for simplicity)
    queueCreateInfo.queueCount = 1;
    float queuePriority = 1.0f;
    queueCreateInfo.pQueuePriorities = &queuePriority;

    // List of device extensions required for Ray Tracing and other features
    std::vector<const char*> enabledDeviceExtensions = {
        VK_KHR_SWAPCHAIN_EXTENSION_NAME,
        VK_EXT_ROBUSTNESS_2_EXTENSION_NAME,
        VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME,
        VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME,
        VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME,
        VK_KHR_DEFERRED_HOST_OPERATIONS_EXTENSION_NAME,
        VK_EXT_DESCRIPTOR_INDEXING_EXTENSION_NAME,
        VK_KHR_SPIRV_1_4_EXTENSION_NAME,
        VK_KHR_SHADER_FLOAT_CONTROLS_EXTENSION_NAME,
        VK_KHR_DYNAMIC_RENDERING_EXTENSION_NAME,
        VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME,
        VK_KHR_MAINTENANCE_4_EXTENSION_NAME
    };

    // Create logical device
    VkDeviceCreateInfo deviceCreateInfo{};
    deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
    deviceCreateInfo.queueCreateInfoCount = 1;
    deviceCreateInfo.pQueueCreateInfos = &queueCreateInfo;
    deviceCreateInfo.enabledExtensionCount = static_cast<uint32_t>(enabledDeviceExtensions.size());
    deviceCreateInfo.ppEnabledExtensionNames = enabledDeviceExtensions.data();
    deviceCreateInfo.enabledLayerCount = layerCount;
    deviceCreateInfo.ppEnabledLayerNames = layers;
    deviceCreateInfo.pNext = &features2; // Enable all required device features

    if (vkCreateDevice(vkInst.physicalDevice, &deviceCreateInfo, nullptr, &vkInst.device) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create logical device!");
    }

    // Step 6: Retrieve Graphics Queue
    vkGetDeviceQueue(vkInst.device, 0, 0, &vkInst.graphicsQueue);

    // Optional: Retrieve Compute Queue (if needed for specific tasks)
    // (same family and index, so this is the same queue as the graphics queue)
    vkGetDeviceQueue(vkInst.device, 0, 0, &vkInst.computeQueue);


    // Get ray tracing pipeline properties, which will be used later on in the sample
    vkInst.rayTracingPipelineProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_TRACING_PIPELINE_PROPERTIES_KHR;
    VkPhysicalDeviceProperties2 deviceProperties2{};
    deviceProperties2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    deviceProperties2.pNext = &vkInst.rayTracingPipelineProperties;
    vkGetPhysicalDeviceProperties2(vkInst.physicalDevice, &deviceProperties2);

    // Get acceleration structure features, which will be used later on in the sample
    vkInst.accelerationStructureFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ACCELERATION_STRUCTURE_FEATURES_KHR;
    VkPhysicalDeviceFeatures2 deviceFeatures2{};
    deviceFeatures2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
    deviceFeatures2.pNext = &vkInst.accelerationStructureFeatures;
    vkGetPhysicalDeviceFeatures2(vkInst.physicalDevice, &deviceFeatures2);

    // Ray-tracing dynamic functions
    // Get the *****ray tracing and acceleration structure***** related function pointers required by this sample
    // A missing function is reported right here instead of crashing later
    // when the null pointer would be called
    const auto loadDeviceFunction = [](const char* name) -> PFN_vkVoidFunction {
        const PFN_vkVoidFunction fn = vkGetDeviceProcAddr(vkInst.device, name);
        if (fn == nullptr) {
            dprintf("Missing Vulkan device function: %s\n", name);
            throw std::runtime_error("Failed to load a required Vulkan device function!");
        }
        return fn;
    };

    vkGetDeviceQueueA = reinterpret_cast<PFN_vkGetDeviceQueue>(loadDeviceFunction("vkGetDeviceQueue"));
    vkGetBufferDeviceAddressKHRA = reinterpret_cast<PFN_vkGetBufferDeviceAddressKHR>(loadDeviceFunction("vkGetBufferDeviceAddressKHR"));
    vkCmdBuildAccelerationStructuresKHRA = reinterpret_cast<PFN_vkCmdBuildAccelerationStructuresKHR>(loadDeviceFunction("vkCmdBuildAccelerationStructuresKHR"));
    vkBuildAccelerationStructuresKHRA = reinterpret_cast<PFN_vkBuildAccelerationStructuresKHR>(loadDeviceFunction("vkBuildAccelerationStructuresKHR"));
    vkCreateAccelerationStructureKHRA = reinterpret_cast<PFN_vkCreateAccelerationStructureKHR>(loadDeviceFunction("vkCreateAccelerationStructureKHR"));
    vkDestroyAccelerationStructureKHRA = reinterpret_cast<PFN_vkDestroyAccelerationStructureKHR>(loadDeviceFunction("vkDestroyAccelerationStructureKHR"));
    vkGetAccelerationStructureBuildSizesKHRA = reinterpret_cast<PFN_vkGetAccelerationStructureBuildSizesKHR>(loadDeviceFunction("vkGetAccelerationStructureBuildSizesKHR"));
    vkGetAccelerationStructureDeviceAddressKHRA = reinterpret_cast<PFN_vkGetAccelerationStructureDeviceAddressKHR>(loadDeviceFunction("vkGetAccelerationStructureDeviceAddressKHR"));
    vkCmdTraceRaysKHRA = reinterpret_cast<PFN_vkCmdTraceRaysKHR>(loadDeviceFunction("vkCmdTraceRaysKHR"));
    vkGetRayTracingShaderGroupHandlesKHRA = reinterpret_cast<PFN_vkGetRayTracingShaderGroupHandlesKHR>(loadDeviceFunction("vkGetRayTracingShaderGroupHandlesKHR"));
    vkCreateRayTracingPipelinesKHRA = reinterpret_cast<PFN_vkCreateRayTracingPipelinesKHR>(loadDeviceFunction("vkCreateRayTracingPipelinesKHR"));
}




// Creates the swapchain for vkInst.surface and stores the result in the
// global 'swapchain' (handle, images, format, extent). It also creates the
// command pool in the global 'commandBuffers' and sizes its command buffer
// list to one entry per swapchain image.
//
//   window : GLFW window - its framebuffer size is used when the surface
//            does not dictate an extent
//
// Requires createVulkanInstance(..) to have run (uses vkInst.device,
// vkInst.physicalDevice and vkInst.surface).
// Throws std::runtime_error when a query or an object creation fails.
void createSwapchain(GLFWwindow* window) {
    // Step 1: Query available surface formats
    uint32_t formatCount = 0;
    vkGetPhysicalDeviceSurfaceFormatsKHR(vkInst.physicalDevice, vkInst.surface, &formatCount, nullptr);
    if (formatCount == 0) {
        throw std::runtime_error("Failed to find any surface formats!");
    }

    std::vector<VkSurfaceFormatKHR> availableFormats(formatCount);
    vkGetPhysicalDeviceSurfaceFormatsKHR(vkInst.physicalDevice, vkInst.surface, &formatCount, availableFormats.data());

    // Prefer VK_FORMAT_B8G8R8A8_UNORM (any color space); when it is not
    // offered fall back to the first format in the list
    VkSurfaceFormatKHR surfaceFormat = availableFormats[0];
    for (const auto& availableFormat : availableFormats) {
        if (availableFormat.format == VK_FORMAT_B8G8R8A8_UNORM) {
            surfaceFormat = availableFormat;
            break;
        }
    }

    // Step 2: Query presentation modes
    uint32_t presentModeCount = 0;
    vkGetPhysicalDeviceSurfacePresentModesKHR(vkInst.physicalDevice, vkInst.surface, &presentModeCount, nullptr);
    if (presentModeCount == 0) {
        throw std::runtime_error("Failed to find any present modes!");
    }

    std::vector<VkPresentModeKHR> availablePresentModes(presentModeCount);
    vkGetPhysicalDeviceSurfacePresentModesKHR(vkInst.physicalDevice, vkInst.surface, &presentModeCount, availablePresentModes.data());

    VkPresentModeKHR presentMode = VK_PRESENT_MODE_FIFO_KHR; // FIFO is supported everywhere
    for (const auto& availablePresentMode : availablePresentModes) {
        if (availablePresentMode == VK_PRESENT_MODE_MAILBOX_KHR) {
            presentMode = availablePresentMode; // Prefer mailbox if available (low latency)
            break;
        }
    }

    // Step 3: Query swap chain extent (size of the swapchain images)
    VkSurfaceCapabilitiesKHR capabilities{};
    if (vkGetPhysicalDeviceSurfaceCapabilitiesKHR(vkInst.physicalDevice, vkInst.surface, &capabilities) != VK_SUCCESS) {
        throw std::runtime_error("Failed to query surface capabilities!");
    }

    VkExtent2D swapchainExtent = capabilities.currentExtent;
    if (swapchainExtent.width == std::numeric_limits<uint32_t>::max()) {
        // The special value means the surface does not dictate the extent,
        // so the framebuffer size of the window is used ...
        int width = 0;
        int height = 0;
        glfwGetFramebufferSize(window, &width, &height);

        // ... limited to the range the surface supports
        const auto clampTo = [](uint32_t value, uint32_t lowest, uint32_t highest) {
            return value < lowest ? lowest : (value > highest ? highest : value);
        };
        swapchainExtent.width = clampTo(static_cast<uint32_t>(width),
            capabilities.minImageExtent.width, capabilities.maxImageExtent.width);
        swapchainExtent.height = clampTo(static_cast<uint32_t>(height),
            capabilities.minImageExtent.height, capabilities.maxImageExtent.height);
    }

    // Step 4: Determine the number of images in the swapchain
    // (one more than the minimum; maxImageCount == 0 means 'no upper limit')
    uint32_t imageCount = capabilities.minImageCount + 1;
    if (capabilities.maxImageCount > 0 && imageCount > capabilities.maxImageCount) {
        imageCount = capabilities.maxImageCount;
    }

    // Step 5: Create the swapchain
    VkSwapchainCreateInfoKHR createInfo{};
    createInfo.sType = VK_STRUCTURE_TYPE_SWAPCHAIN_CREATE_INFO_KHR;
    createInfo.surface = vkInst.surface;
    createInfo.minImageCount = imageCount;
    createInfo.imageFormat = surfaceFormat.format;
    createInfo.imageColorSpace = surfaceFormat.colorSpace;
    createInfo.imageExtent = swapchainExtent;
    createInfo.imageArrayLayers = 1;
    // Color attachment plus transfer source/destination usage
    createInfo.imageUsage = VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT |
                            VK_IMAGE_USAGE_TRANSFER_SRC_BIT |
                            VK_IMAGE_USAGE_TRANSFER_DST_BIT;

    // Only one queue family is used, so the images do not have to be shared
    createInfo.imageSharingMode = VK_SHARING_MODE_EXCLUSIVE;
    createInfo.queueFamilyIndexCount = 0; // Ignored for exclusive mode
    createInfo.pQueueFamilyIndices = nullptr;

    createInfo.preTransform = capabilities.currentTransform;
    createInfo.compositeAlpha = VK_COMPOSITE_ALPHA_OPAQUE_BIT_KHR;
    createInfo.presentMode = presentMode;
    createInfo.clipped = VK_TRUE;
    createInfo.oldSwapchain = VK_NULL_HANDLE;

    if (vkCreateSwapchainKHR(vkInst.device, &createInfo, nullptr, &swapchain.swapchain) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create swapchain!");
    }

    // Step 6: Get swapchain images (the driver may have created more
    // images than requested, so the count is queried again)
    vkGetSwapchainImagesKHR(vkInst.device, swapchain.swapchain, &imageCount, nullptr);
    swapchain.images.resize(imageCount);
    if (vkGetSwapchainImagesKHR(vkInst.device, swapchain.swapchain, &imageCount, swapchain.images.data()) != VK_SUCCESS) {
        throw std::runtime_error("Failed to retrieve swapchain images!");
    }

    swapchain.imageFormat = surfaceFormat.format;
    swapchain.extent = swapchainExtent;

    //--------

    // Command pool for queue family 0 (the family the device queue was created from)
    uint32_t queueFamilyIndex = 0;

    VkCommandPoolCreateInfo poolInfo{};
    poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
    poolInfo.queueFamilyIndex = queueFamilyIndex;
    poolInfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;  // Allow command buffers to be individually reset

    if (vkCreateCommandPool(vkInst.device, &poolInfo, nullptr, &commandBuffers.commandPool) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create command pool!");
    }


    // Reserve one command buffer slot per swapchain image - the entries are
    // only sized here, no command buffer is allocated from the pool yet
    commandBuffers.commandBuffers.resize(swapchain.images.size());

}

// Finds a memory type of vkInst.physicalDevice that is allowed by
// 'typeFilter' and offers all requested property flags.
//
//   typeFilter : bit mask of acceptable memory type indices (bit i set
//                means memory type i may be used), e.g. the
//                memoryTypeBits member of VkMemoryRequirements
//   properties : property flags the memory type must provide
//
// Returns the index of the first suitable memory type.
// Throws std::runtime_error when no memory type qualifies.
uint32_t findMemoryType(uint32_t typeFilter, VkMemoryPropertyFlags properties) {
    VkPhysicalDeviceMemoryProperties memProperties;
    vkGetPhysicalDeviceMemoryProperties(vkInst.physicalDevice, &memProperties);

    for (uint32_t i = 0; i < memProperties.memoryTypeCount; i++) {
        // Unsigned shift - a signed '1 << 31' would overflow for the last
        // possible memory type index
        const bool allowedByFilter = (typeFilter & (uint32_t{1} << i)) != 0;
        const bool hasAllProperties =
            (memProperties.memoryTypes[i].propertyFlags & properties) == properties;

        if (allowedByFilter && hasAllProperties) {
            return i;
        }
    }

    throw std::runtime_error("Failed to find suitable memory type!");
}

// Returns the device address of 'buffer', queried through the dynamically
// loaded vkGetBufferDeviceAddressKHR function on vkInst.device.
VkDeviceAddress getBufferDeviceAddress(VkBuffer buffer) {
    VkBufferDeviceAddressInfoKHR addressQuery{};
    addressQuery.buffer = buffer;
    addressQuery.sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO_KHR;

    const VkDeviceAddress address = vkGetBufferDeviceAddressKHRA(vkInst.device, &addressQuery);
    return address;
}


// Creates a buffer, allocates and binds memory for it and optionally
// uploads initial data.
//
//   size         : size of the buffer in bytes
//   usage        : buffer usage flags; when they contain
//                  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT the memory is
//                  allocated with VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT
//   properties   : required memory property flags
//   buffer       : [out] created buffer
//   bufferMemory : [out] memory bound to the buffer
//   dataPtr      : optional - 'size' bytes copied into the buffer by mapping
//                  the memory, which only works for host visible memory
//
// Throws std::runtime_error on failure; in that case everything created so
// far is released again and both output handles are VK_NULL_HANDLE.
void createBuffer(VkDeviceSize size, VkBufferUsageFlags usage, VkMemoryPropertyFlags properties,
    VkBuffer& buffer, VkDeviceMemory& bufferMemory, void* dataPtr = nullptr) {

    buffer = VK_NULL_HANDLE;
    bufferMemory = VK_NULL_HANDLE;

    // Releases whatever exists already, so a failure does not leak
    const auto releaseAll = [&]() {
        if (buffer != VK_NULL_HANDLE) {
            vkDestroyBuffer(vkInst.device, buffer, nullptr);
            buffer = VK_NULL_HANDLE;
        }
        if (bufferMemory != VK_NULL_HANDLE) {
            vkFreeMemory(vkInst.device, bufferMemory, nullptr);
            bufferMemory = VK_NULL_HANDLE;
        }
    };

    // Step 1: Create Buffer
    VkBufferCreateInfo bufferInfo{};
    bufferInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
    bufferInfo.size = size;
    bufferInfo.usage = usage;
    bufferInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;

    if (vkCreateBuffer(vkInst.device, &bufferInfo, nullptr, &buffer) != VK_SUCCESS) {
        buffer = VK_NULL_HANDLE;
        throw std::runtime_error("Failed to create buffer!");
    }

    // Step 2: Get Memory Requirements
    VkMemoryRequirements memRequirements;
    vkGetBufferMemoryRequirements(vkInst.device, buffer, &memRequirements);

    // Step 3: Allocate Memory with Device Address Flag (if required)
    VkMemoryAllocateFlagsInfo allocateFlagsInfo{};
    allocateFlagsInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO;
    allocateFlagsInfo.flags = (usage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT) ?
        VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT : 0;

    VkMemoryAllocateInfo allocInfo{};
    allocInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
    allocInfo.allocationSize = memRequirements.size;
    try {
        allocInfo.memoryTypeIndex = findMemoryType(memRequirements.memoryTypeBits, properties);
    }
    catch (...) {
        // No suitable memory type - do not leave the buffer behind
        releaseAll();
        throw;
    }

    // Attach allocation flags if required
    if (allocateFlagsInfo.flags) {
        allocInfo.pNext = &allocateFlagsInfo;
    }

    if (vkAllocateMemory(vkInst.device, &allocInfo, nullptr, &bufferMemory) != VK_SUCCESS) {
        bufferMemory = VK_NULL_HANDLE;
        releaseAll();
        throw std::runtime_error("Failed to allocate buffer memory!");
    }

    // Step 4: Bind Buffer with Memory
    if (vkBindBufferMemory(vkInst.device, buffer, bufferMemory, 0) != VK_SUCCESS) {
        releaseAll();
        throw std::runtime_error("Failed to bind buffer memory!");
    }

    // Step 5: Copy Data If Provided
    if (dataPtr) {
        void* mappedMemory = nullptr;
        // The mapping fails for memory that is not host visible - never
        // copy through the pointer in that case
        if (vkMapMemory(vkInst.device, bufferMemory, 0, size, 0, &mappedMemory) != VK_SUCCESS ||
            mappedMemory == nullptr) {
            releaseAll();
            throw std::runtime_error("Failed to map buffer memory!");
        }
        std::memcpy(mappedMemory, dataPtr, static_cast<size_t>(size));
        vkUnmapMemory(vkInst.device, bufferMemory);
    }
}



// Allocates a primary command buffer from commandBuffers.commandPool and
// starts recording it for a single submission
// (VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT).
//
// Returns the command buffer in the recording state; hand it to
// endSingleTimeCommands(..) to submit and free it.
// Throws std::runtime_error when the allocation or the begin call fails.
VkCommandBuffer beginSingleTimeCommands() {
    VkCommandBufferAllocateInfo allocInfo{};
    allocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
    allocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
    allocInfo.commandPool = commandBuffers.commandPool;  // Assume `commandPool` is created beforehand
    allocInfo.commandBufferCount = 1;

    VkCommandBuffer commandBuffer = VK_NULL_HANDLE;
    if (vkAllocateCommandBuffers(vkInst.device, &allocInfo, &commandBuffer) != VK_SUCCESS) {
        throw std::runtime_error("Failed to allocate single time command buffer!");
    }

    VkCommandBufferBeginInfo beginInfo{};
    beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
    beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;

    if (vkBeginCommandBuffer(commandBuffer, &beginInfo) != VK_SUCCESS) {
        // Nobody else knows about the buffer yet, so give it back here
        vkFreeCommandBuffers(vkInst.device, commandBuffers.commandPool, 1, &commandBuffer);
        throw std::runtime_error("Failed to begin single time command buffer!");
    }

    return commandBuffer;
}

// Finishes recording of a command buffer obtained from
// beginSingleTimeCommands(), submits it to vkInst.graphicsQueue, waits
// until the queue is idle and frees the command buffer.
//
//   commandBuffer : command buffer in the recording state
//
// The command buffer is freed in every case.
// Throws std::runtime_error when ending, submitting or waiting fails.
void endSingleTimeCommands(VkCommandBuffer commandBuffer) {
    VkResult result = vkEndCommandBuffer(commandBuffer);

    if (result == VK_SUCCESS) {
        VkSubmitInfo submitInfo{};
        submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
        submitInfo.commandBufferCount = 1;
        submitInfo.pCommandBuffers = &commandBuffer;

        result = vkQueueSubmit(vkInst.graphicsQueue, 1, &submitInfo, VK_NULL_HANDLE);
    }

    if (result == VK_SUCCESS) {
        // Block until the GPU is done, afterwards the buffer can be freed
        result = vkQueueWaitIdle(vkInst.graphicsQueue);
    }

    vkFreeCommandBuffers(vkInst.device, commandBuffers.commandPool, 1, &commandBuffer);

    if (result != VK_SUCCESS) {
        throw std::runtime_error("Failed to execute single time command buffer!");
    }
}



void createBottomLevelAS(const std::vector<Vertex>& vertices, const std::vector<uint16_t>& indices) {
    VkDeviceSize vertexBufferSize = sizeof(vertices[0]) * vertices.size();
    VkDeviceSize indexBufferSize = sizeof(indices[0]) * indices.size();

    // 1. Create Vertex Buffer (Device Local)
    VkBuffer vertexBuffer;
    VkDeviceMemory vertexMemory;
    createBuffer(vertexBufferSize,
        VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR,
        VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        vertexBuffer, vertexMemory, (void*)vertices.data());

    // 2. Create Index Buffer (Device Local)
    VkBuffer indexBuffer;
    VkDeviceMemory indexMemory;
    createBuffer(indexBufferSize,
        VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR,
        VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        indexBuffer, indexMemory, (void*)indices.data());

    // 3. Define Geometry for BLAS
    VkAccelerationStructureGeometryKHR asGeometry{};
    asGeometry.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR;
    asGeometry.geometryType = VK_GEOMETRY_TYPE_TRIANGLES_KHR;
    asGeometry.flags = VK_GEOMETRY_OPAQUE_BIT_KHR;
    asGeometry.geometry.triangles.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_TRIANGLES_DATA_KHR;
    asGeometry.geometry.triangles.vertexFormat = VK_FORMAT_R32G32B32_SFLOAT;
    asGeometry.geometry.triangles.vertexData.deviceAddress = getBufferDeviceAddress(vertexBuffer);
    asGeometry.geometry.triangles.vertexStride = sizeof(Vertex);
    asGeometry.geometry.triangles.indexType = VK_INDEX_TYPE_UINT16;
    asGeometry.geometry.triangles.indexData.deviceAddress = getBufferDeviceAddress(indexBuffer);

    // 4. Get Build Size Info
    VkAccelerationStructureBuildGeometryInfoKHR buildInfo{};
    buildInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR;
    buildInfo.type = VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR;
    buildInfo.flags = VK_BUILD_ACCELERATION_STRUCTURE_PREFER_FAST_TRACE_BIT_KHR;
    buildInfo.geometryCount = 1;
    buildInfo.pGeometries = &asGeometry;
    buildInfo.dstAccelerationStructure = blas.handle;

    uint32_t primitiveCount = static_cast<uint32_t>(indices.size() / 3); // Triangle count

    VkAccelerationStructureBuildSizesInfoKHR sizeInfo{};
    sizeInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR;

    vkGetAccelerationStructureBuildSizesKHRA(
        vkInst.device,
        VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR,
        &buildInfo,
        &primitiveCount,  // FIXED: Correct pointer usage
        &sizeInfo);

    // 5. Create Acceleration Structure Buffer
    VkBuffer asBuffer;
    VkDeviceMemory asMemory;
    createBuffer(sizeInfo.accelerationStructureSize,
        VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
        VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        asBuffer, asMemory);

    // 6. Create Acceleration Structure
    VkAccelerationStructureCreateInfoKHR createInfo{};
    createInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR;
    createInfo.buffer = asBuffer;
    createInfo.size = sizeInfo.accelerationStructureSize;
    createInfo.type = VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR;
    createInfo.offset = 0;

    if (vkCreateAccelerationStructureKHRA(vkInst.device, &createInfo, nullptr, &blas.handle) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create BLAS!");
    }

    
    // Create a small scratch buffer used during build of the bottom level acceleration structure
    VkBuffer scratchBuffer;
    VkDeviceMemory scratchMemory;
    VkDeviceSize scratchBufferSize = sizeInfo.buildScratchSize;
    createBuffer(scratchBufferSize,
        VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
        VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        scratchBuffer, scratchMemory);


    VkBufferDeviceAddressInfoKHR bufferDeviceAddressInfo{};
    bufferDeviceAddressInfo.sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO;
    bufferDeviceAddressInfo.buffer = scratchBuffer;
    uint64_t scratchBufferAddress = vkGetBufferDeviceAddressKHRA(vkInst.device, &bufferDeviceAddressInfo);

    buildInfo.dstAccelerationStructure = blas.handle;
    buildInfo.scratchData.deviceAddress = scratchBufferAddress;

    // 7. Get Device Address of the BLAS
    VkAccelerationStructureDeviceAddressInfoKHR accelerationDeviceAddressInfo{};
    accelerationDeviceAddressInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_DEVICE_ADDRESS_INFO_KHR;
    accelerationDeviceAddressInfo.accelerationStructure = blas.handle;
    blas.deviceAddress = vkGetAccelerationStructureDeviceAddressKHRA(vkInst.device, &accelerationDeviceAddressInfo);

    // 8. Build Acceleration Structure
    VkCommandBuffer commandBuffer = beginSingleTimeCommands();

    VkAccelerationStructureBuildRangeInfoKHR rangeInfo{};
    rangeInfo.primitiveCount = primitiveCount;
    rangeInfo.primitiveOffset = 0;
    rangeInfo.firstVertex = 0;
    rangeInfo.transformOffset = 0;

    const VkAccelerationStructureBuildRangeInfoKHR* pRangeInfo = &rangeInfo;

    vkCmdBuildAccelerationStructuresKHRA(
        commandBuffer,
        1, // Number of structures
        &buildInfo,
        &pRangeInfo);

    endSingleTimeCommands(commandBuffer);

    // 9. Clean-up Temporary Buffers
    vkDestroyBuffer(vkInst.device, vertexBuffer, nullptr);
    vkFreeMemory(vkInst.device, vertexMemory, nullptr);
    vkDestroyBuffer(vkInst.device, indexBuffer, nullptr);
    vkFreeMemory(vkInst.device, indexMemory, nullptr);
}




void createTopLevelAS() {
    // Step 1: Allocate buffer for instance data
    VkTransformMatrixKHR identityMatrix = { {
        1.0f, 0.0f, 0.0f, 0.0f,
        0.0f, 1.0f, 0.0f, 0.0f,
        0.0f, 0.0f, 1.0f, 0.0f
    } };

    VkAccelerationStructureInstanceKHR instance{};
    instance.transform = identityMatrix;  // Transformation (identity)
    instance.instanceCustomIndex = 0;
    instance.mask = 0xFF;
    instance.instanceShaderBindingTableRecordOffset = 0;
    instance.flags = VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR;
    instance.accelerationStructureReference = blas.deviceAddress;

    // Step 2: Create instance buffer
    VkBuffer instanceBuffer;
    VkDeviceMemory instanceMemory;
    VkDeviceSize instanceBufferSize = sizeof(VkAccelerationStructureInstanceKHR);

    createBuffer(instanceBufferSize,
        VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
        VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR,
        VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        instanceBuffer, instanceMemory, &instance);

    // Step 3: Get the device address of the instance buffer
    VkBufferDeviceAddressInfo bufferAddressInfo{};
    bufferAddressInfo.sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO;
    bufferAddressInfo.buffer = instanceBuffer;
    VkDeviceAddress instanceBufferAddress = vkGetBufferDeviceAddress(vkInst.device, &bufferAddressInfo);

    // Step 4: Create TLAS Build Information
    VkAccelerationStructureGeometryKHR geometry{};
    geometry.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR;
    geometry.geometryType = VK_GEOMETRY_TYPE_INSTANCES_KHR;
    geometry.geometry.instances.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_INSTANCES_DATA_KHR;
    geometry.geometry.instances.arrayOfPointers = VK_FALSE;
    geometry.geometry.instances.data.deviceAddress = instanceBufferAddress;

    VkAccelerationStructureBuildRangeInfoKHR buildRangeInfo{};
    buildRangeInfo.primitiveCount = 1;
    buildRangeInfo.primitiveOffset = 0;
    buildRangeInfo.firstVertex = 0;
    buildRangeInfo.transformOffset = 0;

    VkAccelerationStructureBuildGeometryInfoKHR buildInfo{};
    buildInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR;
    buildInfo.type = VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR;
    buildInfo.flags = VK_BUILD_ACCELERATION_STRUCTURE_PREFER_FAST_TRACE_BIT_KHR;
    buildInfo.mode = VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR;
    buildInfo.srcAccelerationStructure = VK_NULL_HANDLE;
    buildInfo.dstAccelerationStructure = tlas.handle;
    buildInfo.geometryCount = 1;
    buildInfo.pGeometries = &geometry;

    // Get build sizes
    VkAccelerationStructureBuildSizesInfoKHR sizeInfo{};
    sizeInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR;
    vkGetAccelerationStructureBuildSizesKHRA(vkInst.device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR,
        &buildInfo, &buildRangeInfo.primitiveCount, &sizeInfo);

    // Step 5: Allocate buffer for TLAS
    VkBuffer tlasBuffer;
    VkDeviceMemory tlasMemory;
    createBuffer(sizeInfo.accelerationStructureSize,
        VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
        VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, tlasBuffer, tlasMemory);

    // Create TLAS
    VkAccelerationStructureCreateInfoKHR createInfo{};
    createInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR;
    createInfo.type = VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR;
    createInfo.size = sizeInfo.accelerationStructureSize;
    createInfo.buffer = tlasBuffer;
    createInfo.offset = 0;  

    if (vkCreateAccelerationStructureKHRA(vkInst.device, &createInfo, nullptr, &tlas.handle) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create TLAS!");
    }

    buildInfo.dstAccelerationStructure = tlas.handle;

    // Step 6: Allocate scratch buffer
    VkBuffer scratchBuffer;
    VkDeviceMemory scratchMemory;
    createBuffer(sizeInfo.buildScratchSize,
        VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
        VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, scratchBuffer, scratchMemory);

    // Get scratch buffer address
    bufferAddressInfo.buffer = scratchBuffer;
    buildInfo.scratchData.deviceAddress = vkGetBufferDeviceAddress(vkInst.device, &bufferAddressInfo);

    // Step 7: Record and Execute TLAS Build Commands
    VkCommandBuffer commandBuffer = beginSingleTimeCommands();

    const VkAccelerationStructureBuildRangeInfoKHR* pBuildRangeInfos[] = { &buildRangeInfo };
    vkCmdBuildAccelerationStructuresKHRA(commandBuffer, 1, &buildInfo, pBuildRangeInfos);

    endSingleTimeCommands(commandBuffer);

    // Cleanup scratch buffer
    vkDestroyBuffer(vkInst.device, scratchBuffer, nullptr);
    vkFreeMemory(vkInst.device, scratchMemory, nullptr);
}



/**
 * Loads a SPIR-V binary from disk and wraps it in a VkShaderModule.
 *
 * @param filename  Path of the compiled SPIR-V file.
 * @return          The created shader module; the caller destroys it with
 *                  vkDestroyShaderModule.
 * @throws std::runtime_error if the file cannot be opened or read, if its
 *         size is not a non-zero multiple of 4 bytes, or if Vulkan rejects
 *         the module.
 */
VkShaderModule createShaderModule(const std::string& filename) {
    // Open at the end so tellg() reports the file size
    std::ifstream file(filename, std::ios::binary | std::ios::ate);
    if (!file.is_open()) {
        throw std::runtime_error("Failed to open shader file: " + filename);
    }

    // SPIR-V is a stream of 32-bit words: codeSize has to be a multiple of 4.
    // tellg() returns -1 on failure, which the <= 0 test also catches.
    const std::streamoff fileSize = file.tellg();
    const std::streamoff wordSize = static_cast<std::streamoff>(sizeof(uint32_t));
    if (fileSize <= 0 || (fileSize % wordSize) != 0) {
        throw std::runtime_error("Invalid SPIR-V shader file (size must be a non-zero multiple of 4 bytes): " + filename);
    }

    // Read straight into uint32_t storage so pCode is correctly aligned
    std::vector<uint32_t> code(static_cast<size_t>(fileSize / wordSize));
    file.seekg(0, std::ios::beg);
    file.read(reinterpret_cast<char*>(code.data()), static_cast<std::streamsize>(fileSize));
    if (!file || file.gcount() != static_cast<std::streamsize>(fileSize)) {
        throw std::runtime_error("Failed to read shader file: " + filename);
    }
    file.close();

    // Create a Vulkan shader module
    VkShaderModuleCreateInfo createInfo = {};
    createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
    createInfo.codeSize = code.size() * sizeof(uint32_t); // size in bytes
    createInfo.pCode = code.data();

    VkShaderModule shaderModule;
    if (vkCreateShaderModule(vkInst.device, &createInfo, nullptr, &shaderModule) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create shader module from file: " + filename);
    }

    return shaderModule;
}

void createImage(uint32_t width, uint32_t height, VkFormat format, VkImageTiling tiling,
    VkImageUsageFlags usage, VkMemoryPropertyFlags properties,
    VkImage& image, VkDeviceMemory& imageMemory, uint32_t initialLayout = VK_IMAGE_LAYOUT_UNDEFINED) {

    VkImageCreateInfo imageInfo{};
    imageInfo.sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
    imageInfo.imageType = VK_IMAGE_TYPE_2D;
    imageInfo.extent.width = width;
    imageInfo.extent.height = height;
    imageInfo.extent.depth = 1;
    imageInfo.mipLevels = 1;
    imageInfo.arrayLayers = 1;
    imageInfo.format = format;
    imageInfo.tiling = tiling;
    imageInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED;
    imageInfo.usage = usage;
    imageInfo.samples = VK_SAMPLE_COUNT_1_BIT;
    imageInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;

    if (vkCreateImage(vkInst.device, &imageInfo, nullptr, &image) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create image!");
    }

    VkMemoryRequirements memRequirements;
    vkGetImageMemoryRequirements(vkInst.device, image, &memRequirements);

    VkMemoryAllocateInfo allocInfo{};
    allocInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
    allocInfo.allocationSize = memRequirements.size;
    allocInfo.memoryTypeIndex = findMemoryType(memRequirements.memoryTypeBits, properties);

    if (vkAllocateMemory(vkInst.device, &allocInfo, nullptr, &imageMemory) != VK_SUCCESS) {
        throw std::runtime_error("Failed to allocate image memory!");
    }

    vkBindImageMemory(vkInst.device, image, imageMemory, 0);
}

/**
 * Creates a 2D view over the first mip level and first array layer of `image`.
 *
 * @param image        Image to view.
 * @param format       Format the view interprets the image with.
 * @param aspectFlags  Image aspect(s) exposed through the view.
 * @return             The new image view.
 * @throws std::runtime_error if vkCreateImageView fails.
 */
VkImageView createImageView(VkImage image, VkFormat format, VkImageAspectFlags aspectFlags) {
    // Single mip level, single array layer
    VkImageSubresourceRange viewedRange{};
    viewedRange.aspectMask     = aspectFlags;
    viewedRange.baseMipLevel   = 0;
    viewedRange.levelCount     = 1;
    viewedRange.baseArrayLayer = 0;
    viewedRange.layerCount     = 1;

    VkImageViewCreateInfo createInfo{};
    createInfo.sType            = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO;
    createInfo.image            = image;
    createInfo.viewType         = VK_IMAGE_VIEW_TYPE_2D;
    createInfo.format           = format;
    createInfo.subresourceRange = viewedRange;

    VkImageView view;
    const VkResult status = vkCreateImageView(vkInst.device, &createInfo, nullptr, &view);
    if (status != VK_SUCCESS) {
        throw std::runtime_error("Failed to create image view!");
    }

    return view;
}




void transitionImageLayout(VkCommandBuffer commandBuffer, VkImage image, VkFormat format, VkImageLayout oldLayout, VkImageLayout newLayout) {
    
    bool oneTime = false;
    if (commandBuffer == NULL)
    {
        commandBuffer = beginSingleTimeCommands();
        oneTime = true;
    }

    VkImageMemoryBarrier barrier{};
    barrier.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
    barrier.oldLayout = oldLayout;
    barrier.newLayout = newLayout;
    barrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
    barrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
    barrier.image = image;

    barrier.subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
    barrier.subresourceRange.baseMipLevel = 0;
    barrier.subresourceRange.levelCount = 1;
    barrier.subresourceRange.baseArrayLayer = 0;
    barrier.subresourceRange.layerCount = 1;

    VkPipelineStageFlags sourceStage;
    VkPipelineStageFlags destinationStage;

    //if (oldLayout == VK_IMAGE_LAYOUT_UNDEFINED && newLayout == VK_IMAGE_LAYOUT_GENERAL) {
        barrier.srcAccessMask = 0;
        barrier.dstAccessMask = VK_ACCESS_SHADER_WRITE_BIT;

        sourceStage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT;
        destinationStage = VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR;
   // }
   // else {
   //     throw std::invalid_argument("Unsupported layout transition!");
   // }

    vkCmdPipelineBarrier(commandBuffer,
        sourceStage, destinationStage,
        0,
        0, nullptr,
        0, nullptr,
        1, &barrier);

    if (oneTime)
    {
        endSingleTimeCommands(commandBuffer);
    }
}


void createVulkanDescriptors()
{
    createImage(swapchain.extent.width, swapchain.extent.height,
        swapchain.imageFormat,
        VK_IMAGE_TILING_OPTIMAL,
        VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT,
        VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        descriptors.storageImage, descriptors.storageImageMemory);

    descriptors.storageImageView = createImageView(descriptors.storageImage, swapchain.imageFormat, VK_IMAGE_ASPECT_COLOR_BIT);

    // ----
    std::vector<VkDescriptorPoolSize> poolSizes = {
    { VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, 1 },
    { VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,              1 }
    };

    VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
    descriptorPoolCreateInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
    descriptorPoolCreateInfo.maxSets = 1;
    descriptorPoolCreateInfo.poolSizeCount = static_cast<uint32_t>(poolSizes.size());
    descriptorPoolCreateInfo.pPoolSizes = poolSizes.data();

    vkCreateDescriptorPool(vkInst.device, &descriptorPoolCreateInfo, nullptr, &descriptors.descriptorPool);

    VkDescriptorSetLayoutBinding accelerationStructureLayoutBinding{};
    accelerationStructureLayoutBinding.binding = 0;
    accelerationStructureLayoutBinding.descriptorType = VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR;
    accelerationStructureLayoutBinding.descriptorCount = 1;
    accelerationStructureLayoutBinding.stageFlags = VK_SHADER_STAGE_RAYGEN_BIT_KHR;

    VkDescriptorSetLayoutBinding outputImageLayoutBinding{};
    outputImageLayoutBinding.binding = 1;
    outputImageLayoutBinding.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
    outputImageLayoutBinding.descriptorCount = 1;
    outputImageLayoutBinding.stageFlags = VK_SHADER_STAGE_RAYGEN_BIT_KHR;

    std::vector<VkDescriptorSetLayoutBinding> bindings({
        accelerationStructureLayoutBinding,
        outputImageLayoutBinding,
    });

    VkDescriptorSetLayoutCreateInfo descriptorSetlayoutCI{};
    descriptorSetlayoutCI.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetlayoutCI.bindingCount = static_cast<uint32_t>(bindings.size());
    descriptorSetlayoutCI.pBindings = bindings.data();
    vkCreateDescriptorSetLayout(vkInst.device, &descriptorSetlayoutCI, nullptr, &descriptors.descriptorSetLayout);

    VkDescriptorSetAllocateInfo descriptorSetAllocateInfo{};
    descriptorSetAllocateInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
    descriptorSetAllocateInfo.descriptorPool = descriptors.descriptorPool;
    descriptorSetAllocateInfo.pSetLayouts = &descriptors.descriptorSetLayout;
    descriptorSetAllocateInfo.descriptorSetCount = 1;

    vkAllocateDescriptorSets(vkInst.device, &descriptorSetAllocateInfo, &descriptors.descriptorSet);

    VkWriteDescriptorSetAccelerationStructureKHR descriptorAccelerationStructureInfo{};
    descriptorAccelerationStructureInfo.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR;
    descriptorAccelerationStructureInfo.accelerationStructureCount = 1;
    descriptorAccelerationStructureInfo.pAccelerationStructures = &tlas.handle;

    VkWriteDescriptorSet accelerationStructureWrite{};
    accelerationStructureWrite.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
    // The specialized acceleration structure descriptor has to be chained
    accelerationStructureWrite.pNext = &descriptorAccelerationStructureInfo;
    accelerationStructureWrite.dstSet = descriptors.descriptorSet;
    accelerationStructureWrite.dstBinding = 0;
    accelerationStructureWrite.descriptorCount = 1;
    accelerationStructureWrite.descriptorType = VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR;

    VkDescriptorImageInfo storageImageDescriptor{};
    storageImageDescriptor.imageView = descriptors.storageImageView;
    storageImageDescriptor.imageLayout = VK_IMAGE_LAYOUT_GENERAL;

    VkWriteDescriptorSet storageImageStructureWrite{};
    storageImageStructureWrite.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
    storageImageStructureWrite.dstSet = descriptors.descriptorSet;
    storageImageStructureWrite.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
    storageImageStructureWrite.dstBinding = 1;
    storageImageStructureWrite.pImageInfo = &storageImageDescriptor;;
    storageImageStructureWrite.descriptorCount = 1;

 
    std::vector<VkWriteDescriptorSet> writeDescriptorSets = {
        accelerationStructureWrite,
        storageImageStructureWrite
    };
    vkUpdateDescriptorSets(vkInst.device, static_cast<uint32_t>(writeDescriptorSets.size()), writeDescriptorSets.data(), 0, VK_NULL_HANDLE);


}

void createRayTracingPipeline() {
    // Step: Load shader modules
    VkShaderModule raygenShaderModule = createShaderModule("raygen.rgen.spv");
    VkShaderModule missShaderModule = createShaderModule("miss.rmiss.spv");
    VkShaderModule hitShaderModule = createShaderModule("closesthit.rchit.spv");

    // Step : Create pipeline layout (use descriptor set layout)
    VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo{};
    pipelineLayoutCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
    pipelineLayoutCreateInfo.setLayoutCount = 1;
    pipelineLayoutCreateInfo.pSetLayouts = &descriptors.descriptorSetLayout;

    if (vkCreatePipelineLayout(vkInst.device, &pipelineLayoutCreateInfo, nullptr, &rtPipeline.layout) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create pipeline layout!");
    }

    // Step: Create shader stage information
    VkPipelineShaderStageCreateInfo raygenShaderStageInfo{};
    raygenShaderStageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    raygenShaderStageInfo.stage = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
    raygenShaderStageInfo.module = raygenShaderModule;
    raygenShaderStageInfo.pName = "main";

    VkPipelineShaderStageCreateInfo missShaderStageInfo{};
    missShaderStageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    missShaderStageInfo.stage = VK_SHADER_STAGE_MISS_BIT_KHR;
    missShaderStageInfo.module = missShaderModule;
    missShaderStageInfo.pName = "main";

    VkPipelineShaderStageCreateInfo hitShaderStageInfo{};
    hitShaderStageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    hitShaderStageInfo.stage = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
    hitShaderStageInfo.module = hitShaderModule;
    hitShaderStageInfo.pName = "main";

    std::array<VkPipelineShaderStageCreateInfo, 3> shaderStages = { raygenShaderStageInfo, missShaderStageInfo, hitShaderStageInfo };

    // Step 5: Define shader groups
    VkRayTracingShaderGroupCreateInfoKHR raygenGroup{};
    raygenGroup.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR;
    raygenGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
    raygenGroup.generalShader = 0;  // Raygen shader index in pStages array
    raygenGroup.closestHitShader = VK_SHADER_UNUSED_KHR;
    raygenGroup.anyHitShader = VK_SHADER_UNUSED_KHR;
    raygenGroup.intersectionShader = VK_SHADER_UNUSED_KHR;

    VkRayTracingShaderGroupCreateInfoKHR missGroup{};
    missGroup.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR;
    missGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
    missGroup.generalShader = 1;  // Miss shader index in pStages array
    missGroup.closestHitShader = VK_SHADER_UNUSED_KHR;
    missGroup.anyHitShader = VK_SHADER_UNUSED_KHR;
    missGroup.intersectionShader = VK_SHADER_UNUSED_KHR;

    VkRayTracingShaderGroupCreateInfoKHR hitGroup{};
    hitGroup.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR;
    hitGroup.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
    hitGroup.generalShader = VK_SHADER_UNUSED_KHR;  // Hit group does not use generalShader
    hitGroup.closestHitShader = 2;  // Closest hit shader index in pStages array
    hitGroup.anyHitShader = VK_SHADER_UNUSED_KHR;
    hitGroup.intersectionShader = VK_SHADER_UNUSED_KHR;

    std::array<VkRayTracingShaderGroupCreateInfoKHR, 3> shaderGroups = { raygenGroup, missGroup, hitGroup };

    // Step 6: Create the ray tracing pipeline
    VkRayTracingPipelineCreateInfoKHR rayTracingPipelineCreateInfo{};
    rayTracingPipelineCreateInfo.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR;
    rayTracingPipelineCreateInfo.stageCount = static_cast<uint32_t>(shaderStages.size());
    rayTracingPipelineCreateInfo.pStages = shaderStages.data();
    rayTracingPipelineCreateInfo.groupCount = static_cast<uint32_t>(shaderGroups.size());
    rayTracingPipelineCreateInfo.pGroups = shaderGroups.data();
    rayTracingPipelineCreateInfo.maxPipelineRayRecursionDepth = 3;
    rayTracingPipelineCreateInfo.layout = rtPipeline.layout; 

    if (vkCreateRayTracingPipelinesKHRA(vkInst.device, VK_NULL_HANDLE, VK_NULL_HANDLE, 1, &rayTracingPipelineCreateInfo, nullptr, &rtPipeline.pipeline) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create ray tracing pipeline!");
    }

    // Cleanup
    vkDestroyShaderModule(vkInst.device, raygenShaderModule, nullptr);
    vkDestroyShaderModule(vkInst.device, missShaderModule, nullptr);
    vkDestroyShaderModule(vkInst.device, hitShaderModule, nullptr);
}



// Rounds `value` up to the next multiple of `alignment`.
// Relies on bit masking, so `alignment` is expected to be a power of two.
uint32_t alignedSize(uint32_t value, uint32_t alignment)
{
    const uint32_t lowBits = alignment - 1;
    const uint32_t bumped = value + lowBits;
    return bumped & ~lowBits;
}


/**
 * Builds the shader binding table (SBT) for the ray tracing pipeline and
 * fills in `sbt.buffer`, `sbt.memory` and the three strided regions.
 *
 * Table layout (one record per group, each `handleSizeAligned` bytes apart):
 *   record 0 = ray generation, record 1 = miss, record 2 = hit group.
 *
 * The driver returns the group handles tightly packed (`handleSize` bytes
 * each), while the regions address the table with a stride of
 * `handleSizeAligned`. Each handle is therefore copied to the start of its own
 * aligned record instead of copying the packed block verbatim.
 *
 * @throws std::runtime_error if the handles cannot be queried or the SBT
 *         memory cannot be mapped.
 */
void createShaderBindingTable() {
    const uint32_t handleSize        = vkInst.rayTracingPipelineProperties.shaderGroupHandleSize;

    // Every record starts on a boundary that satisfies both the handle and the
    // region base-address alignment, since each region begins at a record.
    const uint32_t handleSizeAligned = alignedSize(handleSize,
        std::max(vkInst.rayTracingPipelineProperties.shaderGroupHandleAlignment,
            vkInst.rayTracingPipelineProperties.shaderGroupBaseAlignment));

    const uint32_t groupCount = 3; // Raygen, miss, and hit groups
    const uint32_t sbtSize = groupCount * handleSizeAligned;

    // Retrieve the shader group handles (packed: groupCount * handleSize bytes)
    const size_t packedHandlesSize = static_cast<size_t>(groupCount) * handleSize;
    std::vector<uint8_t> shaderHandleStorage(packedHandlesSize);
    VkResult result = vkGetRayTracingShaderGroupHandlesKHRA(vkInst.device, rtPipeline.pipeline, 0, groupCount,
        packedHandlesSize, shaderHandleStorage.data());
    if (result != VK_SUCCESS) {
        throw std::runtime_error("Failed to get ray tracing shader group handles.");
    }

    // Spread the packed handles out to the aligned record stride (padding stays zero)
    std::vector<uint8_t> sbtContents(sbtSize, 0);
    for (uint32_t group = 0; group < groupCount; ++group) {
        memcpy(sbtContents.data() + static_cast<size_t>(group) * handleSizeAligned,
               shaderHandleStorage.data() + static_cast<size_t>(group) * handleSize,
               handleSize);
    }

    // Create buffer for the SBT
    createBuffer(sbtSize,
        VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
        VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, sbt.buffer, sbt.memory);

    void* data = nullptr;
    if (vkMapMemory(vkInst.device, sbt.memory, 0, sbtSize, 0, &data) != VK_SUCCESS) {
        throw std::runtime_error("Failed to map shader binding table memory.");
    }
    memcpy(data, sbtContents.data(), sbtContents.size());
    vkUnmapMemory(vkInst.device, sbt.memory);

    // Set up SBT regions
    VkBufferDeviceAddressInfo bufferAddressInfo{};
    bufferAddressInfo.sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO;
    bufferAddressInfo.buffer = sbt.buffer;
    VkDeviceAddress sbtAddress = vkGetBufferDeviceAddressKHRA(vkInst.device, &bufferAddressInfo);

    sbt.raygenRegion.deviceAddress = sbtAddress;
    sbt.raygenRegion.stride = handleSizeAligned;
    sbt.raygenRegion.size = handleSizeAligned;

    sbt.missRegion.deviceAddress = sbtAddress + handleSizeAligned;
    sbt.missRegion.stride = handleSizeAligned;
    sbt.missRegion.size = handleSizeAligned;

    sbt.hitRegion.deviceAddress = sbtAddress + 2 * handleSizeAligned;
    sbt.hitRegion.stride = handleSizeAligned;
    sbt.hitRegion.size = handleSizeAligned;
}


void createCommandBuffers() {
    
    /*
    uint32_t queueFamilyIndex = 0;

    VkCommandPoolCreateInfo poolInfo{};
    poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
    poolInfo.queueFamilyIndex = queueFamilyIndex;
    poolInfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;  // Allow command buffers to be individually reset

    if (vkCreateCommandPool(vkInst.device, &poolInfo, nullptr, &commandBuffers.commandPool) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create command pool!");
    }


    // Allocate command buffers
    commandBuffers.commandBuffers.resize( swapchain.images.size() );
    */
    VkCommandBufferAllocateInfo allocInfo{};
    allocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
    allocInfo.commandPool = commandBuffers.commandPool;
    allocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
    allocInfo.commandBufferCount = static_cast<uint32_t>(commandBuffers.commandBuffers.size());

    if (vkAllocateCommandBuffers(vkInst.device, &allocInfo, commandBuffers.commandBuffers.data()) != VK_SUCCESS) {
        throw std::runtime_error("Failed to allocate command buffers.");
    }

    // Transition the image to VK_IMAGE_LAYOUT_GENERAL before using it
    transitionImageLayout(NULL, descriptors.storageImage, swapchain.imageFormat, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_GENERAL);

    // Begin command buffer recording
    VkCommandBufferBeginInfo beginInfo{};
    beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
    beginInfo.flags = VK_COMMAND_BUFFER_USAGE_SIMULTANEOUS_USE_BIT;


    for (int32_t i = 0; i < commandBuffers.commandBuffers.size(); ++i)
    {
        VkCommandBuffer cmdBuffer = commandBuffers.commandBuffers[i];

        if (vkBeginCommandBuffer(cmdBuffer, &beginInfo) != VK_SUCCESS) {
            throw std::runtime_error("Failed to begin recording command buffer.");
        }

        // Bind ray tracing pipeline
        vkCmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, rtPipeline.pipeline);

        // Bind descriptor set
        vkCmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR,
            rtPipeline.layout, 0, 1, &descriptors.descriptorSet, 0, nullptr);

        VkStridedDeviceAddressRegionKHR callableShaderSbtEntry{};

        // Dispatch rays using the Shader Binding Table (SBT)
        vkCmdTraceRaysKHRA(cmdBuffer,
            &sbt.raygenRegion,
            &sbt.missRegion,
            &sbt.hitRegion,
            &callableShaderSbtEntry,
            swapchain.extent.width,
            swapchain.extent.height, 1);


        transitionImageLayout(cmdBuffer, swapchain.images[i], swapchain.imageFormat, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL);

        transitionImageLayout(cmdBuffer, descriptors.storageImage, swapchain.imageFormat, VK_IMAGE_LAYOUT_GENERAL, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL);

        VkImageCopy copyRegion{};
        copyRegion.srcSubresource = { VK_IMAGE_ASPECT_COLOR_BIT, 0, 0, 1 };
        copyRegion.srcOffset = { 0, 0, 0 };
        copyRegion.dstSubresource = { VK_IMAGE_ASPECT_COLOR_BIT, 0, 0, 1 };
        copyRegion.dstOffset = { 0, 0, 0 };
        copyRegion.extent = { swapchain.extent.width, swapchain.extent.height, 1 };
        vkCmdCopyImage(cmdBuffer, descriptors.storageImage, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, swapchain.images[i], VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, &copyRegion);

        transitionImageLayout(cmdBuffer, swapchain.images[i], swapchain.imageFormat, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_PRESENT_SRC_KHR);

        transitionImageLayout(cmdBuffer, descriptors.storageImage, swapchain.imageFormat, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL);

        if (vkEndCommandBuffer(cmdBuffer) != VK_SUCCESS) {
            throw std::runtime_error("Failed to record command buffer.");
        }
    }
}


// Passed to vkAcquireNextImageKHR in drawFrame(); the queue submission waits on it.
VkSemaphore imageAvailableSemaphore;
// Signalled by the queue submission in drawFrame(); vkQueuePresentKHR waits on it.
VkSemaphore renderFinishedSemaphore;

void createSemaphores() {
    VkSemaphoreCreateInfo semaphoreInfo{};
    semaphoreInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO;

    if (vkCreateSemaphore(vkInst.device, &semaphoreInfo, nullptr, &imageAvailableSemaphore) != VK_SUCCESS ||
        vkCreateSemaphore(vkInst.device, &semaphoreInfo, nullptr, &renderFinishedSemaphore) != VK_SUCCESS) {
        throw std::runtime_error("Failed to create semaphores!");
    }
}

void drawFrame() {
   
    //transitionImageLayout(swapchain.images[0], swapchain.imageFormat, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_PRESENT_SRC_KHR);
    //transitionImageLayout(swapchain.images[1], swapchain.imageFormat, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_PRESENT_SRC_KHR);
    //transitionImageLayout(swapchain.images[2], swapchain.imageFormat, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_PRESENT_SRC_KHR);

    // Acquire an image from the swapchain
    uint32_t imageIndex;
    VkResult result = vkAcquireNextImageKHR(vkInst.device, swapchain.swapchain, UINT64_MAX,
        imageAvailableSemaphore, VK_NULL_HANDLE, &imageIndex);

    // Submit ray tracing command buffer
    VkSubmitInfo submitInfo{};
    submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;

    VkSemaphore waitSemaphores[] = { imageAvailableSemaphore };
    VkPipelineStageFlags waitStages[] = { VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT };
    submitInfo.waitSemaphoreCount = 1;
    submitInfo.pWaitSemaphores = waitSemaphores;
    submitInfo.pWaitDstStageMask = waitStages;

    submitInfo.commandBufferCount = 1;
    submitInfo.pCommandBuffers = &commandBuffers.commandBuffers[imageIndex];

    VkSemaphore signalSemaphores[] = { renderFinishedSemaphore };
    submitInfo.signalSemaphoreCount = 1;
    submitInfo.pSignalSemaphores = signalSemaphores;

    if (vkQueueSubmit(vkInst.graphicsQueue, 1, &submitInfo, VK_NULL_HANDLE) != VK_SUCCESS) {
        throw std::runtime_error("Failed to submit ray tracing command buffer.");
    }

    // Present the image to the swapchain
    VkPresentInfoKHR presentInfo{};
    presentInfo.sType = VK_STRUCTURE_TYPE_PRESENT_INFO_KHR;
    presentInfo.waitSemaphoreCount = 1;
    presentInfo.pWaitSemaphores = signalSemaphores;
    presentInfo.swapchainCount = 1;
    presentInfo.pSwapchains = &swapchain.swapchain;
    presentInfo.pImageIndices = &imageIndex;

    result = vkQueuePresentKHR(vkInst.graphicsQueue, &presentInfo);

    vkQueueWaitIdle(vkInst.graphicsQueue);
}


void cleanup() {
    // Wait for the device to be idle before cleanup
    vkDeviceWaitIdle(vkInst.device);

    // Destroy Shader Binding Table
    if (sbt.buffer != VK_NULL_HANDLE) {
        vkDestroyBuffer(vkInst.device, sbt.buffer, nullptr);
    }
    if (sbt.memory != VK_NULL_HANDLE) {
        vkFreeMemory(vkInst.device, sbt.memory, nullptr);
    }

    // Destroy Ray Tracing Pipeline
    if (rtPipeline.pipeline != VK_NULL_HANDLE) {
        vkDestroyPipeline(vkInst.device, rtPipeline.pipeline, nullptr);
    }
    if (rtPipeline.layout != VK_NULL_HANDLE) {
        vkDestroyPipelineLayout(vkInst.device, rtPipeline.layout, nullptr);
    }
    if (descriptors.descriptorSetLayout != VK_NULL_HANDLE) {
        vkDestroyDescriptorSetLayout(vkInst.device, descriptors.descriptorSetLayout, nullptr);
    }
    if (descriptors.descriptorPool != VK_NULL_HANDLE) {
        vkDestroyDescriptorPool(vkInst.device, descriptors.descriptorPool, nullptr);
    }

    // Destroy Acceleration Structures (BLAS & TLAS)
    if (blas.handle != VK_NULL_HANDLE) {
        vkDestroyAccelerationStructureKHRA(vkInst.device, blas.handle, nullptr);
    }
    if (blas.buffer != VK_NULL_HANDLE) {
        vkDestroyBuffer(vkInst.device, blas.buffer, nullptr);
    }
    if (blas.memory != VK_NULL_HANDLE) {
        vkFreeMemory(vkInst.device, blas.memory, nullptr);
    }

    if (tlas.handle != VK_NULL_HANDLE) {
        vkDestroyAccelerationStructureKHRA(vkInst.device, tlas.handle, nullptr);
    }
    if (tlas.buffer != VK_NULL_HANDLE) {
        vkDestroyBuffer(vkInst.device, tlas.buffer, nullptr);
    }
    if (tlas.memory != VK_NULL_HANDLE) {
        vkFreeMemory(vkInst.device, tlas.memory, nullptr);
    }

    // Destroy Command Buffers & Pool
    if (!commandBuffers.commandBuffers.empty()) {
        vkFreeCommandBuffers(vkInst.device, commandBuffers.commandPool,
                             static_cast<uint32_t>(commandBuffers.commandBuffers.size()), commandBuffers.commandBuffers.data());
    }
    if (commandBuffers.commandPool != VK_NULL_HANDLE) {
        vkDestroyCommandPool(vkInst.device, commandBuffers.commandPool, nullptr);
    }

    // Destroy Swapchain
    for (auto imageView : swapchain.imageViews) {
        vkDestroyImageView(vkInst.device, imageView, nullptr);
    }
    if (swapchain.swapchain != VK_NULL_HANDLE) {
        vkDestroySwapchainKHR(vkInst.device, swapchain.swapchain, nullptr);
    }

    // Destroy Vulkan Device
    if (vkInst.device != VK_NULL_HANDLE) {
        vkDestroyDevice(vkInst.device, nullptr);
    }

    // Destroy Vulkan Instance
    if (vkInst.instance != VK_NULL_HANDLE) {
        vkDestroyInstance(vkInst.instance, nullptr);
    }
}


/**
 * Program entry point: opens a 640x480 GLFW window, builds the Vulkan ray
 * tracing state for a single triangle, then renders until the window closes.
 *
 * The setup and draw functions report failure by throwing, so the whole
 * Vulkan section runs inside a try block; the error is printed and the
 * process exits with -1 instead of terminating on an unhandled exception.
 *
 * @return 0 on a normal exit, -1 on any initialisation or rendering failure.
 */
int main() {
    const uint32_t width = 640, height = 480;

    if (!glfwInit()) {
        std::cerr << "Failed to initialize GLFW!" << std::endl;
        return -1;
    }

    // Vulkan renders into the window; no OpenGL context is wanted
    glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API);
    GLFWwindow* window = glfwCreateWindow(width, height, "Vulkan Ray Tracing", nullptr, nullptr);

    if (!window) {
        std::cerr << "Failed to create GLFW window!" << std::endl;
        glfwTerminate();
        return -1;
    }

    // Test vertex data (position/normal/color)
    const std::vector<Vertex> vertices = {
        { { 0.0f,  0.5f, 0.0f}, {0.0f, 0.0f, -1.0f},  {1.0f, 0.0f, 0.0f}  },
        { {-0.5f, -0.5f, 0.0f}, {0.0f, 0.0f, -1.0f},  {0.0f, 1.0f, 0.0f}  },
        { { 0.5f, -0.5f, 0.0f}, {0.0f, 0.0f, -1.0f},  {0.0f, 0.0f, 1.0f}  }
    };

    // Index data
    const std::vector<uint16_t> indices = { 0, 1, 2 };

    try {
        createVulkanInstance(window);
        createSwapchain(window);
        createBottomLevelAS(vertices, indices);
        createTopLevelAS();
        createVulkanDescriptors();
        createRayTracingPipeline();
        createShaderBindingTable();
        createCommandBuffers();

        createSemaphores();

        while (!glfwWindowShouldClose(window)) {
            glfwPollEvents();
            drawFrame();
        }

        cleanup();
    }
    catch (const std::exception& e) {
        // Vulkan state may be only partially built here, so only the window
        // system is shut down before exiting.
        std::cerr << "Fatal error: " << e.what() << std::endl;
        glfwDestroyWindow(window);
        glfwTerminate();
        return -1;
    }

    glfwDestroyWindow(window);
    glfwTerminate();
    return 0;
}


//--------------------------------------------------------------//
//--------------------------------------------------------------//

