#include "BenchmarkD3D11Schemes.h"
#include "DWTSchemeD3D11Deferred.h"
#include "DXHelper.h"
#include "Application.h"

#include <iostream>
#include <fstream>
#include <string>
#include <time.h>

using Microsoft::WRL::ComPtr;
using namespace std;

// Builds the benchmark: prints adapter info, asks the user for the largest resolution to test
// (limited by the adapter's dedicated VRAM), creates the D3D objects shared with the schemes,
// initializes all schemes, samplers and the profiler, and selects the first scheme.
// Without a swap chain (no window) the benchmark starts immediately.
// Throws (via ThrowIfFailed) if no resolution fits into VRAM or standard input ends before a
// valid selection is made.
BenchmarkD3D11Schemes::BenchmarkD3D11Schemes( AppFrameD3D11 * applicationFrame, ComPtr<ID3D11Device>& device, ComPtr<ID3D11DeviceContext>& context, ComPtr<IDXGISwapChain>& swapChain ) :
    applicationFrame( applicationFrame ),
    device( device ),
    context( context ),
    swapChain( swapChain ) {
    // NOTE: inside this body the constructor parameters shadow the same-named members; at this
    // point both hold the same pointers, so the parameters are used directly below.

    // Debug print
    const DXGI_ADAPTER_DESC adapterDesc = applicationFrame->GetAdapterDesc();
    const unsigned __int64 memory = adapterDesc.DedicatedVideoMemory / 1024 / 1024;
    wcout << "Direct3D 11" << endl;
    wcout << "Adapter: " << adapterDesc.Description << endl;
    wcout << "Dedicated Video Memory: " << memory << " MB" << endl;

    // Set max resolution
    {
        wcout << "Available resolutions:" << endl;

        // Count the leading resolutions whose estimated footprint (2 * side^2 * 16 bytes, in MB)
        // fits into the allowed share of VRAM. The estimate is computed in 64 bits because the
        // product could overflow 32-bit arithmetic for large sides.
        unsigned int fittingCount = 0;
        for ( unsigned int i = 0; i < resolutionsCount; i++ ) {
            const unsigned __int64 side = resolutions[ i ];
            const unsigned __int64 requiredMB = 2 * side * side * 4 * 4 / 1024 / 1024;
            if ( !( memory * maxVRAMUsage > requiredMB ) ) {
                break;
            }
            fittingCount = i + 1;
            wcout << "\t" << i + 1 << ": " << resolutions[ i ] << endl;
        }

        // Without at least one candidate the prompt below could never be satisfied
        if ( fittingCount == 0 ) {
            ThrowIfFailed( E_OUTOFMEMORY, L"No benchmark resolution fits into dedicated video memory" );
        }
        resolutionsCountVRAMLimited = fittingCount;

        unsigned int maxResolutionIndex = 0;
        while ( maxResolutionIndex < 1 || maxResolutionIndex > resolutionsCountVRAMLimited ) {
            wcout << "Select max resolution [1.." << resolutionsCountVRAMLimited << "]: ";
            if ( !( cin >> maxResolutionIndex ) ) {
                // End of input (or a broken stream) can never produce a valid answer
                if ( cin.eof() || cin.bad() ) {
                    ThrowIfFailed( E_FAIL, L"Max resolution selection aborted: no more input" );
                }
                // Non-numeric input: reset the stream, drop the rest of the line and ask again
                cin.clear();
                string discarded;
                getline( cin, discarded );
                maxResolutionIndex = 0;
            }
        }
        resolutionsCountVRAMLimited = maxResolutionIndex;
        resolutionIndex = resolutionsCountVRAMLimited - 1;
    }

    // Set schemes count
    schemesCount = immediateSchemes.count;

    // Initialize Direct3D for benchmark
    InitD3D();

    // Debug print
    wcout << "DWT schemes for benchmarking:" << endl;

    // Initialize Direct3D for schemes
    DWTSchemeD3D11Deferred::InitD3D( device, context, constantBuffer, vsMain, samplerState );

    // Initialize immediate schemes
    InitializeContainter<DWTSchemeD3D11>( immediateSchemes, schemes53File, schemes97File, schemes137File );

    // Initialize time samplers (one row per selectable resolution)
    immediateTimeSamplers.resize( resolutionsCountVRAMLimited );

    // Initialize profiler
    profiler.reset( new ProfilerD3D11( 5, device, context ) );

    // Debug print
    wcout << "Benchmark ... initialized" << endl << endl;

    // Setup initial scheme
    SetupScheme();

    // No swap chain (window) -> start benchmark immediately
    if ( swapChain == nullptr ) {
        Start();
    }
}

void BenchmarkD3D11Schemes::InitD3D() {
    HRESULT hr = S_OK;
    // Create a render target view
    if ( swapChain != nullptr ) {
        ComPtr<ID3D11Texture2D> backBuffer;

        hr = swapChain->GetBuffer( 0, __uuidof( ID3D11Texture2D ), reinterpret_cast< void** >( backBuffer.GetAddressOf() ) );
        ThrowIfFailed( hr, L"IDXGISwapChain::GetBuffer error" );

        hr = device->CreateRenderTargetView( backBuffer.Get(), nullptr, renderTargetView.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"ID3D11Device::CreateRenderTargetView error" );
    }

    // Compile shaders
    {
        ComPtr<ID3DBlob> pVSBlob;
        ComPtr<ID3DBlob> pPSBlob;

        CompileShaderFromFile( shaderFile.c_str(), "VS", "vs_5_0", &pVSBlob );
        hr = device->CreateVertexShader( pVSBlob->GetBufferPointer(), pVSBlob->GetBufferSize(), nullptr, vsMain.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"ID3D11Device::CreateVertexShader error" );

        CompileShaderFromFile( shaderFile.c_str(), "CopyFloat4", "ps_5_0", &pPSBlob );
        hr = device->CreatePixelShader( pPSBlob->GetBufferPointer(), pPSBlob->GetBufferSize(), nullptr, psCopyFloat4.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"ID3D11Device::CreatePixelShader error" );

        CompileShaderFromFile( shaderFile.c_str(), "Copy4Tex", "ps_5_0", &pPSBlob );
        hr = device->CreatePixelShader( pPSBlob->GetBufferPointer(), pPSBlob->GetBufferSize(), nullptr, psCopy4Tex.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"ID3D11Device::CreatePixelShader error" );

        CompileShaderFromFile( shaderFile.c_str(), "RandomFloat4", "ps_5_0", &pPSBlob );
        hr = device->CreatePixelShader( pPSBlob->GetBufferPointer(), pPSBlob->GetBufferSize(), nullptr, psRandomFloat4.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"ID3D11Device::CreatePixelShader error" );

        CompileShaderFromFile( shaderFile.c_str(), "Random4Tex", "ps_5_0", &pPSBlob );
        hr = device->CreatePixelShader( pPSBlob->GetBufferPointer(), pPSBlob->GetBufferSize(), nullptr, psRandom4Tex.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"ID3D11Device::CreatePixelShader error" );
    }

    // Describe and create the constant buffer
    {
        D3D11_BUFFER_DESC bd = {};
        bd.ByteWidth = sizeof( ConstantBuffer );
        bd.Usage = D3D11_USAGE_DEFAULT;
        bd.BindFlags = D3D11_BIND_CONSTANT_BUFFER;

        hr = device->CreateBuffer( &bd, nullptr, constantBuffer.GetAddressOf() );
        ThrowIfFailed( hr, L"ID3D11Device::CreateBuffer error" );
    }

    // Describe and create the sampler
    {
        D3D11_SAMPLER_DESC samplerDesc = {};
        samplerDesc.Filter = D3D11_FILTER_MIN_MAG_MIP_POINT;
        samplerDesc.AddressU = D3D11_TEXTURE_ADDRESS_MIRROR;
        samplerDesc.AddressV = D3D11_TEXTURE_ADDRESS_MIRROR;
        samplerDesc.AddressW = D3D11_TEXTURE_ADDRESS_MIRROR;
        samplerDesc.ComparisonFunc = D3D11_COMPARISON_ALWAYS;

        hr = device->CreateSamplerState( &samplerDesc, samplerState.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"ID3D11Device::CreateSamplerState error" );
    }
}

// Selects the previous scheme, wrapping from the first to the last one, and re-initializes it.
// Has no effect while an automated benchmark is running.
void BenchmarkD3D11Schemes::PreviousScheme() {
    if ( !benchmarking ) {
        // Step back; any index that leaves the valid range (e.g. wrap-around below zero)
        // snaps to the last scheme
        --schemeIndex;
        schemeIndex = ( schemeIndex < schemesCount ) ? schemeIndex : schemesCount - 1;
        SetupScheme();
    }
}

// Selects the previous resolution, wrapping from the smallest to the largest allowed one,
// and re-initializes the current scheme. Has no effect during an automated benchmark.
void BenchmarkD3D11Schemes::PreviousResolution() {
    if ( !benchmarking ) {
        // Step back; an index outside [0, resolutionsCountVRAMLimited) snaps to the last entry
        --resolutionIndex;
        resolutionIndex = ( resolutionIndex < resolutionsCountVRAMLimited ) ? resolutionIndex : resolutionsCountVRAMLimited - 1;
        SetupScheme();
    }
}

// Selects the next scheme, wrapping from the last back to the first one, and re-initializes it.
// Has no effect while an automated benchmark is running.
void BenchmarkD3D11Schemes::NextScheme() {
    if ( !benchmarking ) {
        ++schemeIndex;
        // Running past the end starts over at the first scheme
        schemeIndex = ( schemeIndex < schemesCount ) ? schemeIndex : 0;
        SetupScheme();
    }
}

// Selects the next resolution, wrapping from the largest allowed back to the smallest one,
// and re-initializes the current scheme. Has no effect during an automated benchmark.
void BenchmarkD3D11Schemes::NextResolution() {
    if ( !benchmarking ) {
        ++resolutionIndex;
        // Running past the VRAM-limited range starts over at the first resolution
        resolutionIndex = ( resolutionIndex < resolutionsCountVRAMLimited ) ? resolutionIndex : 0;
        SetupScheme();
    }
}

// Returns, through the two out-pointers, the currently selected scheme and the time sampler
// that collects measurements for the (resolutionIndex, schemeIndex) pair.
void BenchmarkD3D11Schemes::GetCurrentScheme( DWTSchemeD3D11** scheme, Sampler** sampler ) {
    auto* const selectedScheme = immediateSchemes[ schemeIndex ].get();
    *scheme = static_cast< DWTSchemeD3D11* >( selectedScheme );

    Sampler& selectedSampler = immediateTimeSamplers[ resolutionIndex ][ schemeIndex ];
    *sampler = &selectedSampler;
}

// Starts the automated benchmark from the first scheme at the first resolution.
// Calling it while a benchmark is already running does nothing.
void BenchmarkD3D11Schemes::Start() {
    if ( !benchmarking ) {
        benchmarking = true;
        // Rewind to the very first measurement cell
        schemeIndex = 0;
        resolutionIndex = 0;
        iterationIndex = 0;
        SetupScheme();
    }
}

void BenchmarkD3D11Schemes::SetupScheme() {
    // Select curent scheme
    DWTSchemeD3D11* currentScheme = nullptr;
    Sampler* currentSampler = nullptr;
    GetCurrentScheme( &currentScheme, &currentSampler );

    // Setup current scheme
    currentScheme->Init( resolutions[ resolutionIndex ] );

    // Debug print
    {
        wcout << currentScheme->GetName() << " (" << schemeIndex + 1 << "/" << schemesCount << ") @ " << resolutions[ resolutionIndex ] << " (" << resolutionIndex + 1 << "/" << resolutionsCountVRAMLimited << ")" << endl;
    }
}

void BenchmarkD3D11Schemes::Run() {
    // Select curent scheme
    DWTSchemeD3D11* currentScheme = nullptr;
    Sampler* currentSampler = nullptr;
    GetCurrentScheme( &currentScheme, &currentSampler );

    // Empty resources
    void * emptyResources[ 4 ] = { nullptr, nullptr, nullptr, nullptr };

    // Setup (before run)
    {
        // Update constant buffer
        ConstantBuffer data;
        data.textureRes = DirectX::XMFLOAT4( static_cast< float >( resolutions[ resolutionIndex ] ), 1.0f / resolutions[ resolutionIndex ], 2.0f / resolutions[ resolutionIndex ], 3.0f / resolutions[ resolutionIndex ] );
        context->UpdateSubresource( constantBuffer.Get(), 0, 0, &data, 0, 0 );

        // Set primitive topology
        context->IASetPrimitiveTopology( D3D11_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP );

        // Set vertex shader
        context->VSSetShader( vsMain.Get(), nullptr, 0 );
        context->VSSetConstantBuffers( 0, 1, constantBuffer.GetAddressOf() );

        // Viewport
        D3D11_VIEWPORT vp = {};
        vp.Width = static_cast< FLOAT >( currentScheme->GetResolution() );
        vp.Height = vp.Width;
        vp.MaxDepth = 1.0f;
        context->RSSetViewports( 1, &vp );

        // Sampler
        context->PSSetSamplers( 0, 1, samplerState.GetAddressOf() );

        // Constant buffer
        context->PSSetConstantBuffers( 0, 1, constantBuffer.GetAddressOf() );

        // Fill source data
        {
            if ( currentScheme->IsFloat4() ) {
                context->OMSetRenderTargets( 1, currentScheme->GetInputRTVs().data(), nullptr );
                context->PSSetShader( psRandomFloat4.Get(), nullptr, 0 );
            }
            else {
                context->OMSetRenderTargets( 4, currentScheme->GetInputRTVs().data(), nullptr );
                context->PSSetShader( psRandom4Tex.Get(), nullptr, 0 );
            }
            context->Draw( 4, 0 );
            context->OMSetRenderTargets( 4, reinterpret_cast< ID3D11RenderTargetView** >( emptyResources ), nullptr );
        }
    }

    double seconds = 0.0;

    // Benchmark
    {
        // Begin profiling
        profiler.get()->Begin();
        // Run scheme
        currentScheme->Run( 1 );
        // End profiling
        seconds = profiler.get()->End();
    }

    // Benchmark controller
    if ( benchmarking ) {
        // Increment iteration index
        ++iterationIndex;
        // Debug print
        {
            float percentage = ( static_cast< float >( ( resolutionIndex * schemesCount ) + schemeIndex ) / static_cast< float >( schemesCount * resolutionsCountVRAMLimited ) ) +
                ( static_cast< float >( currentSampler->NumSamples() ) / static_cast< float >( requiedTimeSamples ) / static_cast< float >( schemesCount * resolutionsCountVRAMLimited ) );
            wcout << "[" << static_cast< int >( ( runIndex + 1 ) ) << "/" << requiedRuns << " " << static_cast< int >( percentage * 100 ) << "%]          \r" << flush;
        }
        // Check elapsed time
        if ( seconds > 0.0 && iterationIndex > ignoreIterations ) {
            // Save elapsed time (with latency)
            currentSampler->Sample( seconds );
            if ( currentSampler->NumSamples() >= requiedTimeSamples * ( runIndex + 1 ) ) {
                // First iteration (zero index)
                iterationIndex = 0;
                // Next scheme
                ++schemeIndex;
                if ( schemeIndex >= schemesCount ) {
                    // First scheme for next resolution (zero index)
                    schemeIndex = 0;
                    // Next resolution
                    ++resolutionIndex;
                    if ( resolutionIndex >= resolutionsCountVRAMLimited ) {
                        // First resolution (zero index)
                        resolutionIndex = 0;
                        ++runIndex;
                        if ( runIndex >= requiedRuns ) {
                            // Done
                            PrintResults();
                            // Close frame and application
                            applicationFrame->Close();
                        }
                    }
                }
                SetupScheme();
                return;
            }
        }
    }
    else {
        // Debug print
        {
            wcout << "[" << seconds * 1000 << "ms]          \r" << flush;
        }
    }

    if ( swapChain != nullptr ) {
        // Setup (after run)
        {
            // Set primitive topology
            context->IASetPrimitiveTopology( D3D11_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP );

            // Set vertex shader
            context->VSSetShader( vsMain.Get(), nullptr, 0 );
            context->VSSetConstantBuffers( 0, 1, constantBuffer.GetAddressOf() );

            // Viewport
            D3D11_VIEWPORT vp = {};
            vp.Width = static_cast< FLOAT >( currentScheme->GetResolution() );
            vp.Height = vp.Width;
            vp.MaxDepth = 1.0f;
            context->RSSetViewports( 1, &vp );

            // Sampler
            context->PSSetSamplers( 0, 1, samplerState.GetAddressOf() );
        }

        // Draw preview output
        {
            if ( currentScheme->IsFloat4() ) {
                context->PSSetShaderResources( 0, 1, currentScheme->GetOutputSRVs().data() );
                context->PSSetShader( psCopyFloat4.Get(), nullptr, 0 );
            }
            else {
                context->PSSetShaderResources( 0, 4, currentScheme->GetOutputSRVs().data() );
                context->PSSetShader( psCopy4Tex.Get(), nullptr, 0 );
            }
            context->OMSetRenderTargets( 1, renderTargetView.GetAddressOf(), nullptr );
            context->Draw( 4, 0 );

            context->PSSetShaderResources( 0, 4, reinterpret_cast< ID3D11ShaderResourceView** >( emptyResources ) );
            context->OMSetRenderTargets( 4, reinterpret_cast< ID3D11RenderTargetView** >( emptyResources ), nullptr );
        }

        // Present
        swapChain->Present( 0, 0 );
    }
    else {
        context->Flush();
    }
}

void BenchmarkD3D11Schemes::PrintResults() {
    DWORD dw;

    // Get current path
    WCHAR cCurrentPath[ FILENAME_MAX ];
    dw = GetCurrentDirectory( sizeof( cCurrentPath ), cCurrentPath );
    if ( dw == 0 ) {
        throw;
    }

    // Get adapter description
    DXGI_ADAPTER_DESC adapter = applicationFrame->GetAdapterDesc();

    // Get time string
    time_t rawTime;
    time( &rawTime );
    char timeStringBuffer[ 80 ];
    tm localTime;
    localtime_s( &localTime, &rawTime );
    strftime( timeStringBuffer, 80, "%F %H-%M-%S", &localTime );

    string timeString;
    timeString.append( timeStringBuffer );

    wstring wTimeString;
    wTimeString.assign( timeString.begin(), timeString.end() );

    // Prepare directory name/path
    std::wstring dir;
    dir.append( cCurrentPath );
    dir.append( L"\\" );
    dir.append( adapter.Description );
    dir.append( L" " );
    dir.append( wTimeString );

    // Create directory
    dw = CreateDirectory( dir.c_str(), nullptr );
    if ( dw == 0 ) {
        throw;
    }

    // GBPS plot
    {
        // Create files
        ofstream outputFile53GBPSMedian( dir + L"\\CDF53gb.dat", ios::out );
        ofstream outputFile97GBPSMedian( dir + L"\\CDF97gb.dat", ios::out );
        ofstream outputFile137GBPSMedian( dir + L"\\CDF137gb.dat", ios::out );

        // Divide schemes
        int cdf53 = 0;
        int cdf97 = 22;
        int dd137 = 56;

        // Write headers
        outputFile53GBPSMedian << "resolution";
        outputFile97GBPSMedian << "resolution";
        outputFile137GBPSMedian << "resolution";

        // CDF 5/3 header
        for ( int i = cdf53; i < cdf97; i++ ) {
            wstring name = immediateSchemes[ i ]->GetName();
            outputFile53GBPSMedian << "\t" << std::string( name.begin(), name.end() );
        }
        // CDF 9/7 header
        for ( int i = cdf97; i < dd137; i++ ) {
            wstring name = immediateSchemes[ i ]->GetName();
            outputFile97GBPSMedian << "\t" << std::string( name.begin(), name.end() );
        }
        // DD 13/7 header
        for ( int i = dd137; i < immediateSchemes.count; i++ ) {
            wstring name = immediateSchemes[ i ]->GetName();
            outputFile137GBPSMedian << "\t" << std::string( name.begin(), name.end() );
        }

        outputFile53GBPSMedian << endl;
        outputFile97GBPSMedian << endl;
        outputFile137GBPSMedian << endl;

        for ( unsigned int i = 0; i < resolutionsCountVRAMLimited; i++ ) {
            // # (resolution)
            unsigned int resolution = resolutions[ i ] * resolutions[ i ] * 4;
            outputFile53GBPSMedian << resolution;
            outputFile97GBPSMedian << resolution;
            outputFile137GBPSMedian << resolution;

            // CDF 5/3 data
            for ( int j = cdf53; j < cdf97; j++ ) {
                outputFile53GBPSMedian << "\t" << GBPSMedian( immediateTimeSamplers[ i ][ j ].GetResults(), resolution );
            }
            // CDF 9/7 data
            for ( int j = cdf97; j < dd137; j++ ) {
                outputFile97GBPSMedian << "\t" << GBPSMedian( immediateTimeSamplers[ i ][ j ].GetResults(), resolution );
            }
            // DD 13/7 data
            for ( int j = dd137; j < immediateSchemes.count; j++ ) {
                outputFile137GBPSMedian << "\t" << GBPSMedian( immediateTimeSamplers[ i ][ j ].GetResults(), resolution );
            }

            // Line ends
            outputFile53GBPSMedian << std::endl;
            outputFile97GBPSMedian << std::endl;
            outputFile137GBPSMedian << std::endl;
        }
        outputFile53GBPSMedian.close();
        outputFile97GBPSMedian.close();
        outputFile137GBPSMedian.close();
    }

    // Full log
    {
        wofstream outputFileRAW( dir + L"\\raw.dat", ios::out );

        // Write headers
        outputFileRAW << "Method\tSteps\tInputResolutionSquareRoot[px]\tInputResolution[px]\tInputSize[B]\tMin[ms]\tMax[ms]\tAverage[ms]\tMedian[ms]\tSamplesCount\tSamples[ms]..." << endl;

        // For each scheme
        for ( unsigned int j = 0; j < schemesCount; j++ ) {
            // For each resolution
            for ( unsigned int i = 0; i < resolutionsCountVRAMLimited; i++ ) {
                unsigned int resolution = resolutions[ i ] * resolutions[ i ];
                unsigned int bytes = resolution * 4;
                Sampler::Results results = immediateTimeSamplers[ i ][ j ].GetResults();

                outputFileRAW << immediateSchemes[ j ]->GetName() << "\t";
                outputFileRAW << immediateSchemes[ j ]->GetSteps() << "\t";
                outputFileRAW << resolutions[ i ] << "\t";
                outputFileRAW << resolution << "\t";
                outputFileRAW << bytes << "\t";
                outputFileRAW << results.min * 1000 << "\t";
                outputFileRAW << results.max * 1000 << "\t";
                outputFileRAW << results.average * 1000 << "\t";
                outputFileRAW << results.median * 1000 << "\t";
                outputFileRAW << results.measurements.size();

                for ( unsigned int k = 0; k < results.measurements.size(); k++ ) {
                    outputFileRAW << "\t" << results.measurements[ k ] * 1000;
                }

                outputFileRAW << endl;
            }
        }
        outputFileRAW.close();
    }
}