/*
 * File:   common.c
 * Author: Pavel Najman <najman.pavel at gmail.com>
 *
 * Created on April 28, 2017, 8:08 AM
 */

#include "common.h"

#include <time.h>
#include <assert.h>
#include <malloc.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>

#ifndef NO_SSE
    #include <xmmintrin.h>
#endif

//#define round(x) ((x)>=0?(long)((x)+0.5):(long)((x)-0.5))

/*
 * Allocates the four temporary buffers (LH, HH, LL, HL) of a TmpMem.
 *
 * band_size_x is set to size_x / 2 and every buffer is sized to hold
 * band_size_x * num_threads floats, requested with 16-byte alignment.
 * The allocator is memalign() when NO_SSE is defined and _mm_malloc()
 * otherwise, so the buffers have to be released by free_tmp_mem().
 *
 * tmp_mem      structure whose band_size_x and buffer pointers are written
 * num_threads  number of band_size_x-long rows reserved in each buffer
 * size_x       width from which the row length (size_x >> 1) is derived
 *
 * A failed allocation is only caught by assert(), i.e. it goes unnoticed
 * when the file is compiled with NDEBUG.
 */
void allocate_tmp_mem(TmpMem * tmp_mem, size_t num_threads, size_t size_x)
{
    tmp_mem->band_size_x = size_x >> 1;

    const size_t total_floats = tmp_mem->band_size_x * num_threads;
    const size_t total_bytes = total_floats * sizeof(float);

#ifdef NO_SSE
    tmp_mem->LH = memalign(16, total_bytes);
    tmp_mem->HH = memalign(16, total_bytes);
    tmp_mem->LL = memalign(16, total_bytes);
    tmp_mem->HL = memalign(16, total_bytes);
#else
    tmp_mem->LH = _mm_malloc(total_bytes, 16);
    tmp_mem->HH = _mm_malloc(total_bytes, 16);
    tmp_mem->LL = _mm_malloc(total_bytes, 16);
    tmp_mem->HL = _mm_malloc(total_bytes, 16);
#endif

    assert(tmp_mem->LH != NULL && tmp_mem->HH != NULL && tmp_mem->LL != NULL &&  tmp_mem->HL != NULL);
}

#ifndef NO_SSE
/*
 * Copies band_size_x floats from src into row y of tmp_mem->LH.
 *
 * Four floats are moved per iteration with aligned SSE load/store
 * intrinsics, so both src and the destination row have to be 16-byte
 * aligned. If band_size_x is not a multiple of 4 the last iteration
 * reads and writes elements beyond band_size_x.
 *
 * tmp_mem  owner of the LH buffer and of band_size_x (the row length)
 * src      source row
 * y        index of the destination row inside the LH buffer
 */
void put_tmp_mem_LH(TmpMem * tmp_mem, float * src, size_t y)
{
    const size_t row_len = tmp_mem->band_size_x;
    float *dst = tmp_mem->LH + y * row_len;

    for (size_t i = 0; i < row_len; i += 4) {
        _mm_store_ps(dst + i, _mm_load_ps(src + i));
    }
}

/*
 * Copies band_size_x floats from src into row y of tmp_mem->HH.
 *
 * Four floats are moved per iteration with aligned SSE load/store
 * intrinsics, so both src and the destination row have to be 16-byte
 * aligned. If band_size_x is not a multiple of 4 the last iteration
 * reads and writes elements beyond band_size_x.
 *
 * tmp_mem  owner of the HH buffer and of band_size_x (the row length)
 * src      source row
 * y        index of the destination row inside the HH buffer
 */
void put_tmp_mem_HH(TmpMem * tmp_mem, float * src, size_t y)
{
    const size_t row_len = tmp_mem->band_size_x;
    float *dst = tmp_mem->HH + y * row_len;

    for (size_t i = 0; i < row_len; i += 4) {
        _mm_store_ps(dst + i, _mm_load_ps(src + i));
    }
}

/*
 * Copies band_size_x floats from src into row y of tmp_mem->LL.
 *
 * Four floats are moved per iteration with aligned SSE load/store
 * intrinsics, so both src and the destination row have to be 16-byte
 * aligned. If band_size_x is not a multiple of 4 the last iteration
 * reads and writes elements beyond band_size_x.
 *
 * tmp_mem  owner of the LL buffer and of band_size_x (the row length)
 * src      source row
 * y        index of the destination row inside the LL buffer
 */
void put_tmp_mem_LL(TmpMem * tmp_mem, float * src, size_t y)
{
    const size_t row_len = tmp_mem->band_size_x;
    float *dst = tmp_mem->LL + y * row_len;

    for (size_t i = 0; i < row_len; i += 4) {
        _mm_store_ps(dst + i, _mm_load_ps(src + i));
    }
}

/*
 * Copies band_size_x floats from src into row y of tmp_mem->HL.
 *
 * Four floats are moved per iteration with aligned SSE load/store
 * intrinsics, so both src and the destination row have to be 16-byte
 * aligned. If band_size_x is not a multiple of 4 the last iteration
 * reads and writes elements beyond band_size_x.
 *
 * tmp_mem  owner of the HL buffer and of band_size_x (the row length)
 * src      source row
 * y        index of the destination row inside the HL buffer
 */
void put_tmp_mem_HL(TmpMem * tmp_mem, float * src, size_t y)
{
    const size_t row_len = tmp_mem->band_size_x;
    float *dst = tmp_mem->HL + y * row_len;

    for (size_t i = 0; i < row_len; i += 4) {
        _mm_store_ps(dst + i, _mm_load_ps(src + i));
    }
}
#endif

/*
 * Scalar row copy: writes band_size_x floats from src into row y of
 * tmp_mem->LH, one element per iteration.
 *
 * NOTE(review): NO_TREE_VECTORIZE is defined outside this file; presumably
 * it keeps the compiler from auto-vectorizing the loop - confirm in common.h.
 * The explicit loop is kept (instead of memcpy) for that reason.
 *
 * tmp_mem  owner of the LH buffer and of band_size_x (the row length)
 * src      source row
 * y        index of the destination row inside the LH buffer
 */
void NO_TREE_VECTORIZE put_tmp_mem_LH_no_SSE(TmpMem * tmp_mem, float * src, size_t y)
{
    const size_t row_len = tmp_mem->band_size_x;
    float *dst = tmp_mem->LH + y * row_len;

    for (size_t i = 0; i < row_len; ++i) {
        dst[i] = src[i];
    }
}

/*
 * Scalar row copy: writes band_size_x floats from src into row y of
 * tmp_mem->HH, one element per iteration.
 *
 * NOTE(review): NO_TREE_VECTORIZE is defined outside this file; presumably
 * it keeps the compiler from auto-vectorizing the loop - confirm in common.h.
 * The explicit loop is kept (instead of memcpy) for that reason.
 *
 * tmp_mem  owner of the HH buffer and of band_size_x (the row length)
 * src      source row
 * y        index of the destination row inside the HH buffer
 */
void NO_TREE_VECTORIZE put_tmp_mem_HH_no_SSE(TmpMem * tmp_mem, float * src, size_t y)
{
    const size_t row_len = tmp_mem->band_size_x;
    float *dst = tmp_mem->HH + y * row_len;

    for (size_t i = 0; i < row_len; ++i) {
        dst[i] = src[i];
    }
}

/*
 * Releases the LH, HH, LL and HL buffers of tmp_mem.
 *
 * The deallocator mirrors the allocator chosen at compile time: free()
 * when NO_SSE is defined, _mm_free() otherwise. The pointers stored in
 * tmp_mem are left unchanged, so they must not be used afterwards.
 */
void free_tmp_mem(TmpMem * tmp_mem)
{
    void * const buffers[] = { tmp_mem->LH, tmp_mem->HH, tmp_mem->LL, tmp_mem->HL };
    const size_t count = sizeof buffers / sizeof buffers[0];

    for (size_t i = 0; i < count; ++i) {
#ifdef NO_SSE
        free(buffers[i]);
#else
        _mm_free(buffers[i]);
#endif
    }
}

/*
 * Fills info with the fixed configuration used by this file:
 * four threads on a single socket.
 */
void init_threading_info(ThreadingInfo * info)
{
    info->num_threads = 4;
    info->num_sockets = 1;
}

/*
 * Sets the geometry of img and allocates its pixel storage.
 *
 * Rows are stride_y = size_x + IMAGE_ROW_PADDING floats apart and the
 * buffer holds stride_y * size_y floats, requested with 16-byte alignment
 * (memalign() under NO_SSE, _mm_malloc() otherwise). The contents of the
 * buffer are not initialized here.
 *
 * img     image whose size_x, size_y, stride_y, size and data are written
 * size_x  image width in pixels
 * size_y  image height in pixels
 *
 * A failed allocation is only caught by assert().
 */
void allocate_image(Image * img, size_t size_x, size_t size_y)
{
    img->size_x   = size_x;
    img->size_y   = size_y;
    img->stride_y = size_x + IMAGE_ROW_PADDING;
    img->size     = img->stride_y * size_y;

    const size_t total_bytes = img->size * sizeof(float);

#ifdef NO_SSE
    img->data = memalign(16, total_bytes);
#else
    img->data = _mm_malloc(total_bytes, 16);
#endif

    assert(img->data != NULL);
}

/*
 * Configures the tiling of img and fills its pixels with pseudo-random data.
 *
 * The tile size is applied through set_tile_size(); afterwards every pixel
 * in the size_x by size_y area is set to rand() / RAND_MAX, visiting the
 * pixels row by row. Elements in the row padding (between size_x and
 * stride_y) are not written.
 *
 * img          image with allocated storage and valid size/stride fields
 * tile_size_x  tile width handed to set_tile_size()
 * tile_size_y  tile height handed to set_tile_size()
 */
void init_image(Image * img, size_t tile_size_x, size_t tile_size_y)
{
    set_tile_size(img, tile_size_x, tile_size_y);

    for (size_t row = 0; row < img->size_y; ++row) {
        const size_t row_base = row * img->stride_y;

        for (size_t col = 0; col < img->size_x; ++col) {
            img->data[row_base + col] = (float)rand() / (float) RAND_MAX;
        }
    }
}

/*
 * Reads img->size_x * img->size_y floats from the text file `filename`
 * into img, row by row.
 *
 * Each value is parsed with the format "%f,\t": a float, optionally
 * followed by a comma and any amount of white space. The storage and the
 * size_x, size_y and stride_y fields of img must already be set up.
 *
 * filename  path of the file to read
 * img       destination image
 *
 * The process is aborted (after a message on stderr) when the file cannot
 * be opened or when a value is missing or malformed.
 */
void load_image(const char * filename, Image * img)
{
    FILE * f = fopen(filename, "r");
    if (f == NULL) {
        fprintf(stderr, "load_image: cannot open '%s'\n", filename);
        abort();
    }

    for (size_t i = 0; i < img->size_y; ++i) {
        for (size_t j = 0; j < img->size_x; ++j) {
            float value;

            /*
             * The read must not be placed inside assert(): with NDEBUG the
             * whole expression disappears and nothing would be read. The
             * result is compared against 1 (one converted item) so that a
             * matching failure, which returns 0, is detected as well as EOF.
             */
            if (fscanf(f, "%f,\t", &value) != 1) {
                fprintf(stderr, "load_image: cannot read value at row %zu, column %zu of '%s'\n",
                        i, j, filename);
                fclose(f);
                abort();
            }

            img->data[i * img->stride_y + j] = value;
        }
    }

    fclose(f);
}

/*
 * Sets the tile dimensions of img and derives the tile counts from them.
 *
 * The tile size must not exceed the image size and must divide it evenly
 * in both directions; this is checked with assert() only. A tile size of
 * zero is not handled and leads to a division by zero.
 *
 * img          image whose size_x and size_y are already set
 * tile_size_x  tile width
 * tile_size_y  tile height
 */
void set_tile_size(Image * img, size_t tile_size_x, size_t tile_size_y)
{
    /* Preconditions on the requested tile size. */
    assert(tile_size_x <= img->size_x && "tile width <= image width");
    assert(tile_size_y <= img->size_y && "tile height <= image height");
    assert(img->size_x % tile_size_x == 0 && "image width % tile width == 0");
    assert(img->size_y % tile_size_y == 0 && "image height % tile height == 0");

    /* Horizontal direction. */
    img->tile_size_x = tile_size_x;
    img->tiles_per_width = img->size_x / img->tile_size_x;

    /* Vertical direction. */
    img->tile_size_y = tile_size_y;
    img->tiles_per_height = img->size_y / img->tile_size_y;

    img->num_tiles = img->tiles_per_width * img->tiles_per_height;
}

/*
 * Makes tile describe the tile at tile coordinates (x, y) of img.
 *
 * No pixel data is copied: tile->data points into img->data and
 * tile->stride_y is the stride of the image. x and y are not range-checked.
 *
 * img   image with tile size configured
 * tile  descriptor to fill
 * x     tile column
 * y     tile row
 */
void get_tile(const Image * img, Tile * tile, size_t x, size_t y)
{
    tile->size_x   = img->tile_size_x;
    tile->size_y   = img->tile_size_y;
    tile->stride_y = img->stride_y;

    const size_t row_offset = y * tile->size_y * tile->stride_y;
    const size_t col_offset = x * tile->size_x;

    tile->data = img->data + row_offset + col_offset;
}

/*
 * Releases the pixel storage of img with the deallocator matching the
 * compile-time allocator: free() under NO_SSE, _mm_free() otherwise.
 * img->data is left unchanged and must not be used afterwards.
 */
void free_image(Image * img)
{
    void * const pixels = img->data;

#ifdef NO_SSE
    free(pixels);
#else
    _mm_free(pixels);
#endif
}

/*
 * Computes the range of tile indices [start_index, end_index) assigned to
 * the worker with index tid.
 *
 * The tiles of img are split into info->num_sockets chunks of
 * ceil(num_tiles / num_sockets) tiles each; chunk->num_tiles holds that
 * nominal chunk length. Both ends of the range are clamped to
 * img->num_tiles, so start_index <= end_index always holds and workers
 * beyond the last tile receive an empty range. (Without clamping the start,
 * e.g. 5 tiles on 4 sockets gave tid 3 the range start 6, end 5.)
 *
 * chunk  result
 * img    image with num_tiles set
 * info   provides num_sockets, which must be non-zero
 * tid    index of the chunk to compute
 */
void init_chunk(Chunk * chunk, const Image * img, const BandsThreadingInfo * info, size_t tid)
{
    chunk->num_tiles = (img->num_tiles + info->num_sockets - 1) / info->num_sockets;

    chunk->start_index = tid * chunk->num_tiles;
    if (chunk->start_index > img->num_tiles) {
        chunk->start_index = img->num_tiles;
    }

    chunk->end_index = (tid + 1) * chunk->num_tiles;
    if (chunk->end_index > img->num_tiles) {
        chunk->end_index = img->num_tiles;
    }
}

/*
 * Sets the geometry of bands from img and allocates storage for the four
 * bands LL, HL, LH and HH.
 *
 * The band dimensions are half the image dimensions and the per-tile band
 * dimensions are half the tile dimensions. Rows are
 * stride_y = size_x + BANDS_ROW_PADDING floats apart. One band occupies
 * size_y * stride_y * sizeof(float) + LOCA_BAND_PADDING bytes; all four are
 * carved, in the order LL, HL, LH, HH, out of a single 16-byte aligned
 * allocation whose start is bands->LL.
 *
 * bands  structure to fill
 * img    image with size and tile size already set
 *
 * A failed allocation is only caught by assert().
 */
void allocate_bands(Bands * bands, const Image * img)
{
    bands->size_x = img->size_x >> 1;
    bands->size_y = img->size_y >> 1;
    bands->band_size_x = img->tile_size_x >> 1;
    bands->band_size_y = img->tile_size_y >> 1;

    bands->stride_y = bands->size_x + BANDS_ROW_PADDING;

    const size_t byte_size = bands->size_y * bands->stride_y * sizeof(float) + LOCA_BAND_PADDING;
    const size_t total_bytes = byte_size * 4;

#ifdef NO_SSE
    float * mem = memalign(16, total_bytes);
#else
    float * mem = _mm_malloc(total_bytes, 16);
#endif

    assert(mem != NULL);

    /* Band k starts (k * byte_size) / sizeof(float) floats into the block. */
    bands->LL = mem;
    bands->HL = mem + 1 * byte_size / sizeof(float);
    bands->LH = mem + 2 * byte_size / sizeof(float);
    bands->HH = mem + 3 * byte_size / sizeof(float);
}

/*
 * Zeroes the data of all four bands with a single memset().
 *
 * The bands are expected to be laid out as allocate_bands() creates them:
 * one contiguous block starting at bands->LL with HH as the last band.
 * Consecutive bands there are spaced by the band size plus
 * LOCA_BAND_PADDING, so clearing just 4 * size_y * stride_y floats from LL
 * stops short of the end of HH whenever that padding is non-zero. The
 * length is therefore taken from the actual distance between LL and HH
 * plus one band; with zero padding this is exactly the old length.
 */
void clear_bands(Bands * bands)
{
    const size_t band_bytes = bands->size_y * bands->stride_y * sizeof(float);
    const size_t hh_offset = (size_t) ((const char *) bands->HH - (const char *) bands->LL);

    memset(bands->LL, 0, hh_offset + band_bytes);
}

/*
 * Makes tile_bands describe the part of each band that belongs to the tile
 * at tile coordinates (x, y).
 *
 * No data is copied: the four pointers refer into the buffers of bands,
 * all at the same element offset, and stride_y is the stride of bands.
 * x and y are not range-checked.
 *
 * bands       source bands with band_size_x/band_size_y set
 * tile_bands  descriptor to fill
 * x           tile column
 * y           tile row
 */
void get_tile_bands(const Bands * bands, TileBands * tile_bands, size_t x, size_t y)
{
    const size_t offset = y * bands->band_size_y * bands->stride_y + x * bands->band_size_x;

    tile_bands->size_x   = bands->band_size_x;
    tile_bands->size_y   = bands->band_size_y;
    tile_bands->stride_y = bands->stride_y;

    tile_bands->LL = bands->LL + offset;
    tile_bands->HL = bands->HL + offset;
    tile_bands->LH = bands->LH + offset;
    tile_bands->HH = bands->HH + offset;
}

/*
 * Releases the storage of bands.
 *
 * Only bands->LL is passed to the deallocator (free() under NO_SSE,
 * _mm_free() otherwise); allocate_bands() places all four bands in one
 * allocation that starts at LL. The pointers in bands are left unchanged.
 */
void free_bands(Bands * bands)
{
    void * const block = bands->LL;

#ifdef NO_SSE
    free(block);
#else
    _mm_free(block);
#endif
}

/*
 * Allocates the band_start_y and band_end_y arrays of info.
 *
 * Both arrays have num_threads entries and share one 16-byte aligned
 * allocation: band_start_y is the first half, band_end_y the second.
 * The entries are not initialized here.
 *
 * info         structure whose two array pointers are written
 * num_threads  number of entries per array
 *
 * A failed allocation is only caught by assert().
 */
void allocate_bands_threading_info(BandsThreadingInfo * info, size_t num_threads)
{
    const size_t total_bytes = 2 * num_threads * sizeof(size_t);

#ifdef NO_SSE
    size_t * block = memalign(16, total_bytes);
#else
    size_t * block = _mm_malloc(total_bytes, 16);
#endif

    assert(block != NULL);

    info->band_start_y = block;
    info->band_end_y   = block + num_threads;
}

/*
 * Stores the socket/thread counts in info and splits the band_size_y rows
 * of a tile band between num_threads threads.
 *
 * Thread tid receives the rows [band_start_y[tid], band_end_y[tid]) with
 * the boundaries at floor(tid * band_size_y / num_threads), computed in
 * double precision. The end of the last thread is set to band_size_y
 * directly rather than computed.
 *
 * info         structure with band_start_y/band_end_y arrays of at least
 *              num_threads entries
 * bands        provides band_size_y
 * num_sockets  value stored in info->num_sockets
 * num_threads  number of threads to split the rows between
 */
void init_bands_threading_info(BandsThreadingInfo * info, const Bands * bands, size_t num_sockets, size_t num_threads)
{
    info->num_sockets = num_sockets;
    info->num_threads = num_threads;

    const double rows_per_thread = ((double) bands->band_size_y / (double) num_threads);

    for (size_t tid = 0; tid < num_threads; ++tid) {
        info->band_start_y[tid] = (size_t) floor((double)tid * rows_per_thread);

        if (tid + 1 == num_threads) {
            /* Last thread: take everything up to the end of the band. */
            info->band_end_y[tid] = bands->band_size_y;
        } else {
            info->band_end_y[tid] = (size_t) floor((double)(tid + 1) * rows_per_thread);
        }
    }
}

/*
 * Releases the arrays of info.
 *
 * Only band_start_y is passed to the deallocator (free() under NO_SSE,
 * _mm_free() otherwise); allocate_bands_threading_info() places band_end_y
 * in the same allocation. The pointers in info are left unchanged.
 */
void free_bands_threading_info(BandsThreadingInfo * info)
{
    void * const block = info->band_start_y;

#ifdef NO_SSE
    free(block);
#else
    _mm_free(block);
#endif
}

/*
 * Returns the current CLOCK_MONOTONIC_RAW time in nanoseconds.
 *
 * The process is aborted if the clock cannot be read.
 */
long long gettimer()
{
    struct timespec now;

    if (clock_gettime(CLOCK_MONOTONIC_RAW, &now) == -1) {
        abort();
    }

    const long long ns_per_second = 1000000000LL;
    return now.tv_sec * ns_per_second + now.tv_nsec;
}

/*
 * Allocates a 32 MiB scratch buffer, executes clflush on every
 * CACHE_LINE_SIZE-th byte of it, issues an sfence (skipped when __MIC__ is
 * defined) and releases the buffer again.
 *
 * The function is compiled without optimization (optimize("O0")).
 * If the buffer cannot be allocated the function asserts, or silently
 * returns when assertions are disabled.
 *
 * NOTE(review): clflush acts on the cache line containing the given
 * address, and the buffer is never written before being flushed - confirm
 * that this evicts the data the callers intend to evict.
 */
void __attribute__((optimize("O0"))) flush_cache()
{
    const size_t allocation_size = 32*1024*1024;
#ifdef NO_SSE
    char *p = (char *) memalign(16, allocation_size);
#else
    char *p = (char *) _mm_malloc(allocation_size, 16);
#endif

    assert(p != NULL);

    /* Still guard the release build, where the assert above is compiled out. */
    if (p == NULL)
        return;

    const char *cp = (const char *)p;

    for (size_t i = 0; i < allocation_size; i += CACHE_LINE_SIZE) {
        __asm volatile("clflush (%0)\n\t"
                        :
                        : "r"(&cp[i])
                        : "memory");
    }

#ifndef __MIC__
    __asm volatile("sfence\n\t"
                 :
                 :
                 : "memory");
#endif

    /*
     * The buffer is released on every target. Previously the release sat
     * inside the __MIC__ guard, which leaked the whole buffer on each call
     * when __MIC__ was defined.
     */
#ifdef NO_SSE
    free(p);
#else
    _mm_free(p);
#endif
}

/*
 * qsort()-style comparator for long long timestamps.
 *
 * p1, p2  pointers to the two long long values to compare
 *
 * Returns -1, 0 or 1 when *p1 is less than, equal to or greater than *p2.
 * The values are compared directly instead of subtracted, because the
 * difference of two long long values can overflow, which is undefined
 * behavior and yields the wrong sign in practice.
 */
int compare_times(const void *p1, const void * p2)
{
    const long long t1 = *(const long long *) p1;
    const long long t2 = *(const long long *) p2;

    return (t1 > t2) - (t1 < t2);
}

