// g++ -std=c++11 frequency.cpp -o frequency -I ~/include -L ~/lib -lsdsl -ldivsufsort -ldivsufsort64 -O3 -DNDEBUG

#include <stdio.h>
#include <stdlib.h>
#include "sys/times.h"
#include <omp.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include <sdsl/int_vector.hpp>
#include <sdsl/bit_vectors.hpp>
#include <sdsl/util.hpp>
#include <sdsl/rank_support.hpp>
#include <sdsl/select_support.hpp>
#include <sdsl/suffix_arrays.hpp>

extern "C" {
        #include "huffman.h"
}

/**
 * Opens "<path>.B2.bin.huffman" and passes it to HuffmanDecodeToBytes.
 *
 * @param path     Base path; the suffix ".B2.bin.huffman" is appended to it.
 * @param bInFile  Out-parameter forwarded to HuffmanDecodeToBytes; on return
 *                 *bInFile must be non-null.
 * @return The Huffman tree pointer produced by HuffmanDecodeToBytes.
 *
 * Prints a message on std::cout and calls exit(1) when the file cannot be
 * opened or when either the tree or the bit file is missing after the call.
 */
huffman_node_t *readCompressedHuffB2(const std::string path, bit_file_t **bInFile){
    // The decoder writes through bInFile, so reject a null out-parameter
    // before it is handed over.
    if (bInFile == nullptr){
        std::cout<<" ht is NULL o bb es NULL\n";
        exit(1);
    }

    const std::string b2Path = path + ".B2.bin.huffman";

    FILE *inFile = fopen(b2Path.c_str(), "rb");
    if (inFile == nullptr){
        std::cout<<" cannot read B2 huffman file\n";
        exit(1);
    }

    // inFile is deliberately left open: it is handed to the decoder and this
    // function cannot tell whether the resulting bit file keeps using it.
    huffman_node_t *ht = nullptr;
    HuffmanDecodeToBytes(inFile, bInFile, &ht);

    // Check the pointee: bInFile itself is the address of the caller's
    // variable and says nothing about whether a bit file was produced.
    if(ht == nullptr || *bInFile == nullptr){
        std::cout<<" ht is NULL o bb es NULL\n";
        exit(1);
    }
    return ht;
}

/**
 * Loads the three serialized sdsl structures that live next to `path`.
 *
 * @param path    Base path; the suffixes below are appended to it.
 * @param x_wm    Receives the contents of "<path>.X.bin-wm_int.sdsl".
 * @param b1_rrr  Receives the contents of "<path>.B1-rrr-64.sdsl".
 * @param y_wm    Receives the contents of "<path>.Y.bin.huff.bin-wm_int.sdsl".
 *
 * Prints the offending file name on std::cerr and calls exit(1) if any of
 * the loads reports failure, instead of continuing with empty structures.
 */
void readCompressed(const std::string path, sdsl::wm_int<sdsl::rrr_vector<15>> &x_wm,
    sdsl::rrr_vector<63> &b1_rrr,
    sdsl::wm_int<sdsl::rrr_vector<15>> &y_wm)
{
    // Path to sequences
    const std::string xPath = path + ".X.bin-wm_int.sdsl";
    const std::string b1Path = path + ".B1-rrr-64.sdsl";
    const std::string yPath = path + ".Y.bin.huff.bin-wm_int.sdsl";

    // load_from_file reports success as a bool; abort on the first failure.
    const auto requireLoaded = [](bool loaded, const std::string &file) {
        if (!loaded) {
            std::cerr << " cannot read " << file << "\n";
            exit(1);
        }
    };

    // Read compressed files
    requireLoaded(load_from_file(x_wm, xPath), xPath);
    requireLoaded(load_from_file(b1_rrr, b1Path), b1Path);
    requireLoaded(load_from_file(y_wm, yPath), yPath);
}


/**
 * Collects the neighbors of `current_node`.
 *
 * For every occurrence of `current_node` in `x_wm`, the partition containing
 * that position is located through `b1_rank` / `b1_select`, and the values
 * stored at the other positions of the partition are examined:
 *  - if the partition number is below `ylen - 1`, the range
 *    [yRAM[p], yRAM[p + 1]) is passed to HuffmanDecodePartition and a position
 *    is reported when its decoded bytes share at least one set bit with the
 *    bytes of the current position;
 *  - otherwise every other position of the partition is reported.
 *
 * @param x_wm          Sequence of node identifiers.
 * @param b1_rank       Rank support over the partition-boundary bit vector.
 * @param b1_select     Select support over the same bit vector.
 * @param ht            Huffman tree forwarded to HuffmanDecodePartition.
 * @param bInFile       Bit file forwarded to HuffmanDecodePartition.
 * @param yRAM          Per-partition offsets handed to the decoder.
 * @param ylen          Number of entries in `yRAM`.
 * @param current_node  Node whose neighbors are requested.
 * @return The reported node identifiers, in discovery order; the same
 *         identifier may appear more than once.
 */
std::vector<uint32_t> getNodeNeighbors(sdsl::wm_int<sdsl::rrr_vector<15>> &x_wm,
    sdsl::rrr_vector<63>::rank_1_type &b1_rank, sdsl::rrr_vector<63>::select_1_type &b1_select,
    huffman_node_t *ht, bit_file_t *bInFile, sdsl::wm_int<sdsl::rrr_vector<15>> &yRAM,
    uint32_t ylen,
    uint32_t current_node)
{
    std::vector<uint32_t> neighs;
    const uint32_t howManyX = x_wm.rank(x_wm.size(), current_node);

    for (uint32_t xCount = 1; xCount <= howManyX; ++xCount)
    {
        // Position of this occurrence and the bounds of its partition.
        const uint64_t xIndex = x_wm.select(xCount, current_node);
        const uint64_t partitionNumber = b1_rank(xIndex + 1) - 1;
        const uint64_t partitionIndex = b1_select(partitionNumber + 1);
        const uint64_t nextPartitionIndex = b1_select(partitionNumber + 2);
        const uint32_t psize = static_cast<uint32_t>(nextPartitionIndex - partitionIndex);

        // Only partitions below ylen - 1 have a decodable byte range; the
        // ylen > 0 guard keeps `ylen - 1` from wrapping around.
        const bool hasMasks = (ylen > 0) && (partitionNumber < ylen - 1);

        if (!hasMasks)
        {
            // No per-node bytes: all nodes of the partition are adjacent.
            for (uint64_t xI = partitionIndex; xI < nextPartitionIndex; ++xI)
            {
                if (xIndex != xI)
                    neighs.emplace_back(x_wm[xI]);
            }
            continue;
        }

        // An empty partition has nothing to compare and would divide by zero.
        if (psize == 0)
            continue;

        const uint32_t current_Y = yRAM[partitionNumber];
        const uint32_t nextp_Y = yRAM[partitionNumber + 1];

        int numb2 = 0;
        std::vector<unsigned char> b2RAM(nextp_Y - current_Y);
        unsigned char *ptr = b2RAM.data();

        // ht and bInFile are shared by every caller, so decoding is serialized.
        // NOTE(review): the decoder's return value is still ignored, as before;
        // its meaning is not visible from this file.
        #pragma omp critical
        {
            static_cast<void>(HuffmanDecodePartition(&ptr, current_Y, nextp_Y, &numb2, ht, bInFile));
        }

        // The decoded bytes are split evenly among the partition's positions.
        const uint32_t bytesPerNode = numb2 / psize;
        const uint64_t currentByteIndex = bytesPerNode * (xIndex - partitionIndex);

        for (uint64_t xI = partitionIndex; xI < nextPartitionIndex; ++xI)
        {
            if (xI == xIndex)
                continue;

            const uint64_t candidateByteIndex = bytesPerNode * (xI - partitionIndex);
            for (uint32_t bytesChecked = 0; bytesChecked < bytesPerNode; ++bytesChecked)
            {
                const uint8_t maskByteOfCurrent = b2RAM[currentByteIndex + bytesChecked];
                const uint8_t maskBytePossibleNeighbor = b2RAM[candidateByteIndex + bytesChecked];

                // One shared bit is enough; stop scanning this candidate.
                if (maskByteOfCurrent & maskBytePossibleNeighbor)
                {
                    neighs.emplace_back(x_wm[xI]);
                    break;
                }
            }
        }
    }

    return neighs;
}

/**
 * Entry point: loads the compressed structures found at argv[1], reads a list
 * of node identifiers from the binary file argv[2] (a 32-bit count followed
 * by that many 32-bit identifiers), runs getNodeNeighbors for each one and
 * reports the elapsed time and the number of distinct neighbors on std::cerr.
 *
 * @return 0 on success, -1 on a usage error or when the query file cannot be
 *         opened or read completely.
 */
int main(int argc, char const *argv[])
{
    if(3 > argc)
    {
        std::cerr << "Modo de uso: " << argv[0] << " RUTA_BASE QueryFile.bin " << std::endl;
        return -1;
    }

    const std::string path(argv[1]);

    // Variables to read compressed sequences
    sdsl::wm_int<sdsl::rrr_vector<15>> x_wm;
    sdsl::rrr_vector<63> b1_rrr;
    sdsl::wm_int<sdsl::rrr_vector<15>> y_wm;

    bit_file_t *bInFile = NULL;
    huffman_node_t *ht = readCompressedHuffB2(path, &bInFile);
    readCompressed(path, x_wm, b1_rrr, y_wm);
    sdsl::rrr_vector<63>::rank_1_type b1_rank(&b1_rrr);
    sdsl::rrr_vector<63>::select_1_type b1_select(&b1_rrr);
    const uint32_t ylen = y_wm.size();

    // The query file is binary, so open it in binary mode and verify each read.
    FILE *list_fp = fopen(argv[2], "rb");
    if (list_fp == nullptr)
    {
        std::cerr << "cannot open query file " << argv[2] << "\n";
        return -1;
    }

    uint32_t queries = 0;
    if (fread(&queries, sizeof(uint32_t), 1, list_fp) != 1)
    {
        std::cerr << "cannot read query count from " << argv[2] << "\n";
        fclose(list_fp);
        return -1;
    }

    std::vector<uint32_t> qry(queries);
    if (queries > 0 && fread(qry.data(), sizeof(uint32_t), queries, list_fp) != queries)
    {
        std::cerr << "query file " << argv[2] << " is shorter than its header states\n";
        fclose(list_fp);
        return -1;
    }
    fclose(list_fp);
    std::cerr<<"Processing "<<queries<<" queries\n";

    uint64_t recovered = 0;
    const std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now();
    for(uint32_t i=0;i<queries;i++) {
        const std::vector<uint32_t> ne = getNodeNeighbors(x_wm, b1_rank, b1_select, ht, bInFile, y_wm, ylen, qry[i]);
        // The returned list may repeat a node; count each neighbor once.
        const std::unordered_set<uint32_t> neighs(ne.begin(), ne.end());
        recovered += neighs.size();
    }
    const std::chrono::high_resolution_clock::time_point stop_time = std::chrono::high_resolution_clock::now();

    const auto duration = std::chrono::duration_cast<std::chrono::milliseconds> (stop_time - start_time).count();

    std::cerr <<"total time iter : " << duration << " [ms]" << " \n";
    std::cerr <<"total queries "<<queries << " \n";
    std::cerr <<"recovered edges "<<recovered<<" time per link "<<1.0*duration/recovered<< std::endl;
    std::cerr <<"time per query "<<1.0*duration/queries << " \n";

    return 0;
}
