#include "CUDADeviceScopeGuard.h"
#include "CUDADriverInfo.h"
#include "CUDAEventTimer.h"
#include "CUDAStreamsHandler.h"
#include "CUDAUtilityDeviceFunctions.h"
#include "CPUParallelism/CPUParallelismNCP.h"
#include "UtilityFunctions.h"
#include <driver_types.h>
#include <benchmark/benchmark.h>
#include <memory>

using namespace std;
using namespace UtilsCUDA;
using namespace Utils::UtilityFunctions;
using namespace Utils::CPUParallelism;

/**
 * Benchmarks concurrent versus serial CUDA memory management calls for every device index
 * in [0, cudaDriverInfo.getDeviceCount()).
 *
 * A fixed 2 GiB budget is split into state.range(0) equally sized chunks. Each benchmark iteration
 * exercises three groups of operations, every one of them both through parallelFor() and through a
 * plain serial loop:
 *  - cudaHostRegister() / cudaHostUnregister() on heap buffers obtained with new[],
 *  - cudaMallocHost()   / cudaFreeHost(),
 *  - cudaMalloc()       / cudaFree().
 * The elapsed time of every phase is printed through DebugConsole_consoleOutLine().
 *
 * @param state Google Benchmark state; range(0) is the number of allocations the budget is split into.
 *              Values outside [1, 2^31] are reported through SkipWithError() and nothing is measured.
 */
void BM_Concurrent_Memory_Allocation(benchmark::State& state)
{
  constexpr size_t numberOfCUDAStreams  = 1;
  constexpr size_t totalAllocationSpace = 1ull << 31; // 2 GiB budget shared by all allocations (host and device phases alike)

  // reject argument values that would divide by zero or produce zero-sized allocations
  const auto requestedAllocations = state.range(0);
  if (requestedAllocations <= 0 || static_cast<size_t>(requestedAllocations) > totalAllocationSpace)
  {
    state.SkipWithError("Number of allocations must be in the range [1, 2^31].");
    return;
  }

  const CUDADriverInfo cudaDriverInfo(cudaDeviceScheduleAuto, true);

  const size_t numberOfAllocations = static_cast<size_t>(requestedAllocations);
  const size_t allocationSize      = totalAllocationSpace / numberOfAllocations;

  // pointer table reused by all three phases below (heap, cudaMallocHost() and cudaMalloc() pointers in turn)
  const auto arrays = make_unique<uint8_t*[]>(numberOfAllocations);
  DebugConsole_consoleOutLine("Number of allocations: ", numberOfAllocations);
  DebugConsole_consoleOutLine("Allocation size: ",       allocationSize);

  // NOTE(review): the benchmark state loop below is entered once per device; confirm that Google Benchmark
  // tolerates re-entering 'for (auto _ : state)' on multi-GPU systems.
  for (int device = 0; device < cudaDriverInfo.getDeviceCount(); ++device)
  {
    const CUDAStreamsHandler streams(cudaDriverInfo, device, numberOfCUDAStreams);
    CUDAEventTimer gpuTimer(device, streams[0]);

    for (auto _ : state)
    {
      // Host registrations/de-registrations below:

      // create the heap allocations to be registered (left uninitialized on purpose, outside of any timed region)
      for (size_t i = 0; i < numberOfAllocations; ++i)
      {
        arrays[i] = new uint8_t[allocationSize];
      }

      // register host memory in parallel
      gpuTimer.startTimer();
      parallelFor(0, numberOfAllocations, [&](size_t i)
      {
        CUDAError_checkCUDAError(cudaHostRegister(arrays[i], allocationSize, cudaHostRegisterDefault));
      });
      DebugConsole_consoleOutLine("Host concurrent registration: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // un-register host memory in parallel
      gpuTimer.startTimer();
      parallelFor(0, numberOfAllocations, [&](size_t i)
      {
        CUDAError_checkCUDAError(cudaHostUnregister(arrays[i]));
      });
      DebugConsole_consoleOutLine("Host concurrent un-registration: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // register host memory serially
      gpuTimer.startTimer();
      for (size_t i = 0; i < numberOfAllocations; ++i)
      {
        CUDAError_checkCUDAError(cudaHostRegister(arrays[i], allocationSize, cudaHostRegisterDefault));
      }
      DebugConsole_consoleOutLine("Host serial registration: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // un-register host memory serially
      gpuTimer.startTimer();
      for (size_t i = 0; i < numberOfAllocations; ++i)
      {
        CUDAError_checkCUDAError(cudaHostUnregister(arrays[i]));
      }
      DebugConsole_consoleOutLine("Host serial un-registration: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // delete the heap allocations used for registration (every buffer was un-registered above)
      for (size_t i = 0; i < numberOfAllocations; ++i)
      {
        delete[] arrays[i];
      }



      // Host allocations/de-allocations below:

      // allocate host memory in parallel
      // (cudaHostAllocDefault is the flag family that belongs to cudaMallocHost()/cudaHostAlloc())
      gpuTimer.startTimer();
      parallelFor(0, numberOfAllocations, [&](size_t i)
      {
        CUDAError_checkCUDAError(cudaMallocHost(arrays.get() + i, allocationSize, cudaHostAllocDefault));
      });
      DebugConsole_consoleOutLine("Host concurrent allocation: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // de-allocate host memory in parallel
      gpuTimer.startTimer();
      parallelFor(0, numberOfAllocations, [&](size_t i)
      {
        CUDAError_checkCUDAError(cudaFreeHost(arrays[i]));
      });
      DebugConsole_consoleOutLine("Host concurrent de-allocation: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // allocate host memory serially
      gpuTimer.startTimer();
      for (size_t i = 0; i < numberOfAllocations; ++i)
      {
        CUDAError_checkCUDAError(cudaMallocHost(arrays.get() + i, allocationSize, cudaHostAllocDefault));
      }
      DebugConsole_consoleOutLine("Host serial allocation: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // de-allocate host memory serially
      gpuTimer.startTimer();
      for (size_t i = 0; i < numberOfAllocations; ++i)
      {
        CUDAError_checkCUDAError(cudaFreeHost(arrays[i]));
      }
      DebugConsole_consoleOutLine("Host serial de-allocation: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");



      // Device allocations/de-allocations below:

      // allocate device memory in parallel
      gpuTimer.startTimer();
      parallelFor(0, numberOfAllocations, [&](size_t i)
      {
        // presumably selects 'device' as the current CUDA device of the calling thread for this scope,
        // so that the allocation targets it on a multi-GPU system (no kernel is launched here)
        CUDADeviceScopeGuard deviceScopeGuard(device);
        CUDAError_checkCUDAError(cudaMalloc(arrays.get() + i, allocationSize));
      });
      DebugConsole_consoleOutLine("Device concurrent allocation: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // de-allocate device memory in parallel
      gpuTimer.startTimer();
      parallelFor(0, numberOfAllocations, [&](size_t i)
      {
        CUDAError_checkCUDAError(cudaFree(arrays[i]));
      });
      DebugConsole_consoleOutLine("Device concurrent de-allocation: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // allocate device memory serially
      gpuTimer.startTimer();
      for (size_t i = 0; i < numberOfAllocations; ++i)
      {
        // presumably selects 'device' as the current CUDA device for this scope (see the parallel variant above)
        CUDADeviceScopeGuard deviceScopeGuard(device);
        CUDAError_checkCUDAError(cudaMalloc(arrays.get() + i, allocationSize));
      }
      DebugConsole_consoleOutLine("Device serial allocation: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // de-allocate device memory serially
      gpuTimer.startTimer();
      for (size_t i = 0; i < numberOfAllocations; ++i)
      {
        CUDAError_checkCUDAError(cudaFree(arrays[i]));
      }
      DebugConsole_consoleOutLine("Device serial de-allocation: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");
    }
  }
}

// Register the benchmark for every power-of-two allocation count from 1 up to 8192,
// with a fixed 8 iterations per argument and wall-clock (real) time reporting.
BENCHMARK(BM_Concurrent_Memory_Allocation)
    ->Arg(1u)
    ->Arg(2u)
    ->Arg(4u)
    ->Arg(8u)
    ->Arg(16u)
    ->Arg(32u)
    ->Arg(64u)
    ->Arg(128u)
    ->Arg(256u)
    ->Arg(512u)
    ->Arg(1024u)
    ->Arg(2048u)
    ->Arg(4096u)
    ->Arg(8192u)
    ->Iterations(8u)
    ->UseRealTime();