#include "Astar.h"
#include "CPUParallelism/CPUParallelismNCP.h"
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <queue>
#include <vector>

using namespace std;
using namespace Utils;
using namespace Utils::CPUParallelism;
using namespace Utils::Pathfinding;

namespace // anonymous namespace: gives the helpers below internal linkage so they stay local to this translation unit
{
  const size_t NUMBER_OF_NEIGHBOURS_NO_DIAGONALS   = 4;
  const size_t NUMBER_OF_NEIGHBOURS_WITH_DIAGONALS = 8;

  // step costs added to the queue priority of a move (they are NOT accumulated into totalCosts, see executeInternal())
  const float STRAIGHT_MOVE_COST_G = 10.0f; // top / left / bottom / right moves
  const float DIAGONAL_MOVE_COST_G = 14.0f; // diagonal moves: integer estimate of 10.0f * sqrt(2.0f)

  // predecessor value meaning "no predecessor recorded"; must match what Astar::resetState() writes into paths_
  const size_t NO_PREDECESSOR = std::numeric_limits<size_t>::max();

  // an entry of the open list: one grid pixel together with its queue priority
  struct Node
  {
    size_t index = 0;    // index in the flattened row-major grid (row = index / width, column = index % width)
    float  cost  = 0.0f; // queue priority of this entry; lower values are expanded first

    Node() = default;
    Node(size_t i, float c) : index(i), cost(c) {}
  };

  // std::priority_queue keeps the greatest element (per its comparator) on top; comparing with '>' instead of '<'
  // turns it into a min-heap, so the Node with the lowest priority value is expanded first
  struct NodesCostComparator
  {
    bool operator()(const Node& node1, const Node& node2) const
    {
      return node1.cost > node2.cost;
    }
  };

  // two Nodes are the same pixel when their indices match; the priority is ignored
  inline bool operator==(const Node& node1, const Node& node2)
  {
    return node1.index == node2.index;
  }

  // true when the pixel is not in the first row, i.e. it has a neighbour above it
  inline bool checkTopBoundary(const Node& node, size_t width)
  {
    return node.index / width > 0;
  }

  // true when the pixel is not in the first column, i.e. it has a neighbour to its left
  inline bool checkLeftBoundary(const Node& node, size_t width)
  {
    return node.index % width > 0;
  }

  // true when the pixel is not in the last row, i.e. it has a neighbour below it
  inline bool checkBottomBoundary(const Node& node, size_t width, size_t height)
  {
    return node.index / width + 1 < height;
  }

  // true when the pixel is not in the last column, i.e. it has a neighbour to its right
  inline bool checkRightBoundary(const Node& node, size_t width)
  {
    return node.index % width + 1 < width;
  }

  // Manhattan distance between grid cells (i0, j0) and (i1, j1).
  // It only works as a never-overestimating A* heuristic when every move costs >= 1 and changes a single coordinate.
  // NOTE(review): a diagonal move changes both coordinates at once, so with diagonals enabled this distance can
  // overestimate the remaining cost (a diagonal/Chebyshev distance would be the admissible choice there).
  inline int64_t manhattanDistance(int64_t i0, int64_t j0, int64_t i1, int64_t j1)
  {
    return std::abs(i0 - i1) + std::abs(j0 - j1);
  }

  // Runs the A* search on a width x height grid whose per-pixel data is stored row-major in flat arrays.
  //   weights[i]    : cost added to the running total when the search steps onto pixel i
  //                   (a +infinity weight means pixel i can never be entered)
  //   costsH[i]     : precomputed heuristic (H) value of pixel i
  //   totalCosts[i] : best accumulated weight found so far to reach pixel i; expected to be pre-filled with +infinity
  //   paths[i]      : predecessor of pixel i on that best route; expected to be pre-filled with NO_PREDECESSOR
  //   useDiagonals  : also consider the four diagonal neighbours
  // Returns true when the goal pixel is taken from the open list (a route was found), false when the open list runs dry.
  inline bool executeInternal(size_t width, size_t height, size_t start, size_t goal, const float* __restrict weights, const float* __restrict costsH, float* __restrict totalCosts, size_t* __restrict paths, bool useDiagonals)
  {
    std::priority_queue<Node, std::vector<Node>, NodesCostComparator> nodesToVisit;
    nodesToVisit.emplace(start, 0.0f); // seed the open list with the starting pixel
    const Node goalNode(goal, 0.0f);   // only its index matters for the operator== check below
    totalCosts[start] = 0.0f;

    // Neighbour slots: 0 top, 1 left, 2 bottom, 3 right, 4 top-left, 5 top-right, 6 bottom-right, 7 bottom-left.
    // Every slot is spelled out: brace-initializing a std::array with a single value only sets element 0 and
    // zero-fills the rest, which previously left the left/bottom/right step costs at 0.0f.
    const std::array<float, NUMBER_OF_NEIGHBOURS_WITH_DIAGONALS> neighboursCostsG = { {
      STRAIGHT_MOVE_COST_G, STRAIGHT_MOVE_COST_G, STRAIGHT_MOVE_COST_G, STRAIGHT_MOVE_COST_G,
      DIAGONAL_MOVE_COST_G, DIAGONAL_MOVE_COST_G, DIAGONAL_MOVE_COST_G, DIAGONAL_MOVE_COST_G } };
    std::array<int64_t, NUMBER_OF_NEIGHBOURS_WITH_DIAGONALS> neighbours;
    neighbours.fill(-1); // -1 marks a neighbour that lies outside the grid
    const size_t neighboursToCheck = useDiagonals ? NUMBER_OF_NEIGHBOURS_WITH_DIAGONALS : NUMBER_OF_NEIGHBOURS_NO_DIAGONALS;

    bool solutionFound = false;
    while (!nodesToVisit.empty())
    {
      const Node currentNode = nodesToVisit.top(); // top() only peeks; the entry is removed by pop() below
      if (currentNode == goalNode)
      {
        solutionFound = true;
        break;
      }
      nodesToVisit.pop();

      // the open list may hold outdated duplicates of a pixel; expanding one again is harmless because
      // neighbours are only relaxed when the pixel's current best cost improves them
      const size_t current   = currentNode.index;
      const float  costSoFar = totalCosts[current];

      const bool hasTop    = checkTopBoundary(currentNode, width);
      const bool hasLeft   = checkLeftBoundary(currentNode, width);
      const bool hasBottom = checkBottomBoundary(currentNode, width, height);
      const bool hasRight  = checkRightBoundary(currentNode, width);

      neighbours[0] = hasTop    ? int64_t(current - width) : -1;
      neighbours[1] = hasLeft   ? int64_t(current - 1)     : -1;
      neighbours[2] = hasBottom ? int64_t(current + width) : -1;
      neighbours[3] = hasRight  ? int64_t(current + 1)     : -1;
      if (useDiagonals)
      {
        neighbours[4] = (hasTop    && hasLeft)  ? int64_t(current - width - 1) : -1;
        neighbours[5] = (hasTop    && hasRight) ? int64_t(current - width + 1) : -1;
        neighbours[6] = (hasBottom && hasRight) ? int64_t(current + width + 1) : -1;
        neighbours[7] = (hasBottom && hasLeft)  ? int64_t(current + width - 1) : -1;
      }

      for (size_t i = 0; i < neighboursToCheck; ++i)
      {
        if (neighbours[i] < 0)
        {
          continue; // outside the grid
        }
        const size_t neighbourIndex = size_t(neighbours[i]);

        // accumulated weight of reaching the neighbour through the current pixel
        const float candidateCost = costSoFar + weights[neighbourIndex];
        if (candidateCost < totalCosts[neighbourIndex])
        {
          totalCosts[neighbourIndex] = candidateCost;
          paths[neighbourIndex]      = current;
          // routes with a lower expected total are explored first
          nodesToVisit.emplace(neighbourIndex, candidateCost + neighboursCostsG[i] + costsH[neighbourIndex]);
        }
      }
    }

    return solutionFound;
  }

  // Walks the predecessor chain from goal back towards start and writes markValue into optimalPath for every pixel
  // on it, including goal but excluding start. Nothing is written when goal has no recorded predecessor (it was never
  // reached), which avoids indexing optimalPath with the NO_PREDECESSOR sentinel.
  inline void unravelAndMarkOptimalPathInternal(size_t start, size_t goal, const size_t* __restrict paths, uint8_t* __restrict optimalPath, uint8_t markValue)
  {
    if (goal != start && paths[goal] == NO_PREDECESSOR)
    {
      return;
    }

    for (size_t currentPath = goal; currentPath != start && currentPath != NO_PREDECESSOR; currentPath = paths[currentPath])
    {
      optimalPath[currentPath] = markValue;
    }
  }
}

Astar::Astar(size_t width, size_t height, DiagonalMode diagonalMode) noexcept
  : width_(width)
  , height_(height)
  , arraySize_(width * height) // NOTE(review): width * height wraps silently on size_t overflow
  , diagonalMode_(diagonalMode)
{
  // Plain 'new T[n]' default-initializes the elements (leaves them unset) instead of make_unique<T[]>(n), which
  // would zero every element first. resetState() and setStrategy() fill every element, so callers are presumably
  // expected to invoke them before execute() reads the buffers.
  weights_.reset(new float[arraySize_]);
  costsH_.reset(new float[arraySize_]);
  totalCosts_.reset(new float[arraySize_]);
  paths_.reset(new size_t[arraySize_]);
}

float Astar::calculateCostH(size_t index) const
{
  return float(manhattanDistance(int64_t(index / width_), int64_t(index % width_), int64_t(goal_ / width_), int64_t(goal_ % width_)));
}

void Astar::resetState() const
{
  // mark every pixel as "not reached yet": unbounded accumulated cost and no predecessor
  const float  unreachedCost = numeric_limits<float>::infinity();
  const size_t noPredecessor = numeric_limits<size_t>::max();
  parallelFor(0, arraySize_, [&](size_t index)
  {
    totalCosts_[index] = unreachedCost;
    paths_[index]      = noPredecessor;
  });
}

void Astar::setStartAndGoal(size_t start, size_t goal)
{
  start_ = start;
  goal_  = goal;

  resetState();

  // costsH_ is derived from goal_, yet setStrategy() is the only other visible place that fills it. Refresh it here
  // so that changing the goal after setStrategy() does not leave execute() with a heuristic aimed at the old goal.
  parallelFor(0, arraySize_, [&](size_t i)
  {
    costsH_[i] = calculateCostH(i);
  });
}

void Astar::setStrategy(const FunctionView<float(size_t)>& strategy) const
{
  parallelFor(0, arraySize_, [&](size_t i)
  {
    weights_[i] = strategy(i);
    costsH_[i]  = calculateCostH(i); // precalculate the H cost
  });
}

bool Astar::execute() const
{
  // returns whether a route from start_ to goal_ was found; totalCosts_/paths_ hold the search results afterwards
  const bool useDiagonals = (diagonalMode_ == DiagonalMode::WITH_DIADONALS); // enumerator spelling comes from Astar.h
  return executeInternal(width_, height_, start_, goal_,
                         weights_.get(), costsH_.get(), totalCosts_.get(), paths_.get(),
                         useDiagonals);
}

void Astar::unravelAndMarkOptimalPath(uint8_t* __restrict optimalPath, uint8_t markValue) const
{
  // follow the predecessors recorded by execute() from goal_ back to start_, writing markValue along the way
  // (start_ itself is left untouched)
  const size_t* const predecessors = paths_.get();
  unravelAndMarkOptimalPathInternal(start_, goal_, predecessors, optimalPath, markValue);
}