#include "IntentiveCacheCleaning.h"
#include "CUDADeviceScopeGuard.h"
#include "CUDADriverInfo.h"
#include "CUDAEventTimer.h"
#include "CUDAMemoryHandlers.h"
#include "CUDAParallelFor.h"
#include "CUDAStreamsHandler.h"
#include "UtilityFunctions.h"
#include <benchmark/benchmark.h>

using namespace std;
using namespace UtilsCUDA;
using namespace UtilsCUDA::CUDAParallelFor;
using namespace Utils::UtilityFunctions;

// Benchmarks strided access to GPU memory.
//
// The stride (in elements) is taken from state.range(0). For every CUDA device reported by
// CUDADriverInfo, each benchmark iteration launches three kernels over numberOfElements indices:
//   1. strided read        : outData[index]          = inData[index * stride]
//   2. strided write       : outData[index * stride] = inData[index]
//   3. strided read & write: outData[index * stride] = inData[index * stride]
// The elapsed time reported by CUDAEventTimer for each launch is written to the debug console,
// and the cache cleaner is run before the first measurement and after every measurement.
//
// If no CUDA device is reported, the benchmark is skipped with an error message instead of
// returning without ever entering the benchmark loop.
//
// NOTE(review): the benchmark state loop is nested inside the device loop, so on a multi-GPU
// system it is entered once per device on the same benchmark::State; confirm that the Google
// Benchmark version in use tolerates re-entering the loop.
// NOTE(review): the cache cleaning and console output happen inside the benchmark loop, so they
// are presumably part of the real time reported by Google Benchmark; the per-kernel numbers are
// the ones written to the debug console.
void BM_Strided_Global_Memory_Access(benchmark::State& state)
{
  constexpr size_t numberOfCUDAStreams = 1;
  const CUDADriverInfo cudaDriverInfo(cudaDeviceScheduleAuto, true);

  const size_t stride           = state.range(0);     // distance, in elements, between consecutive accesses
  const size_t numberOfElements =    2 * 1024 * 1024; // number of indices handed to every kernel launch
  const size_t dataSize         = 1024 * 1024 * 1024; // size handed to the cache cleaner (units defined by IntentiveCacheCleaning)

  DebugConsole_consoleOutLine("stride size: ", stride);

  // without any device the loop below would never run the benchmark body, so report that explicitly
  if (cudaDriverInfo.getDeviceCount() < 1)
  {
    state.SkipWithError("No CUDA device available for the strided global memory access benchmark.");
    return;
  }

  for (int device = 0; device < cudaDriverInfo.getDeviceCount(); ++device)
  {
    const CUDAStreamsHandler streams(cudaDriverInfo, numberOfCUDAStreams, device);
    CUDAEventTimer gpuTimer(device, streams[0]);

    // set up input and output data
    // both buffers hold numberOfElements * stride elements, so the largest strided index
    // (numberOfElements - 1) * stride stays in bounds (assuming indices run over [0, numberOfElements))
    HostDeviceMemory<uint32_t>  inData(numberOfElements * stride, device);
    HostDeviceMemory<uint32_t> outData(numberOfElements * stride, device);

    // clear up L2 GPU cache
    IntentiveCacheCleaning cacheCleaner(dataSize, streams[0], device);
    cacheCleaner.cleanCache();

    for (auto _ : state)
    {
      // NOTE(review): no explicit timer stop or stream synchronization is visible between startTimer()
      // and getElapsedTimeInMilliSecs(); presumably the getter handles both - confirm against CUDAEventTimer
      gpuTimer.startTimer();
      {
        // choose which GPU to run the GPU kernel on for a multi-GPU system
        CUDADeviceScopeGuard deviceScopeGuard(device);
        // read strided data from GPU memory
        launchCUDAParallelForInStream(numberOfElements, 0, streams[0], [] __device__(size_t index, uint32_t* __restrict source, uint32_t* __restrict destination, size_t elementStride)
        {
          destination[index] = source[index * elementStride];
        }, inData.device(), outData.device(), stride);
      }
      DebugConsole_consoleOutLine("Read strided data from GPU memory: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      // clean the cache again so the next measurement does not start from the previous kernel's cache state
      cacheCleaner.cleanCache();

      gpuTimer.startTimer();
      {
        // choose which GPU to run the GPU kernel on for a multi-GPU system
        CUDADeviceScopeGuard deviceScopeGuard(device);
        // write strided data to GPU memory
        launchCUDAParallelForInStream(numberOfElements, 0, streams[0], [] __device__(size_t index, uint32_t* __restrict source, uint32_t* __restrict destination, size_t elementStride)
        {
          destination[index * elementStride] = source[index];
        }, inData.device(), outData.device(), stride);
      }
      DebugConsole_consoleOutLine("Write strided data from GPU memory: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      cacheCleaner.cleanCache();

      gpuTimer.startTimer();
      {
        // choose which GPU to run the GPU kernel on for a multi-GPU system
        CUDADeviceScopeGuard deviceScopeGuard(device);
        // read & write strided data from/to GPU memory
        launchCUDAParallelForInStream(numberOfElements, 0, streams[0], [] __device__(size_t index, uint32_t* __restrict source, uint32_t* __restrict destination, size_t elementStride)
        {
          destination[index * elementStride] = source[index * elementStride];
        }, inData.device(), outData.device(), stride);
      }
      DebugConsole_consoleOutLine("Read & write strided data from GPU memory: ", gpuTimer.getElapsedTimeInMilliSecs(), " ms.");

      cacheCleaner.cleanCache();
    }
  }
}

// Register the benchmark for the power-of-two strides 1, 2, 4, ..., 128 (multiplier 2 over the
// range [1, 128]), with a fixed count of 8 iterations per stride and real (wall-clock) time reporting.
BENCHMARK(BM_Strided_Global_Memory_Access)
  ->RangeMultiplier(2)
  ->Range(1, 128)
  ->Iterations(8u)
  ->UseRealTime();