#include "DWTSchemeD3D11.h"
#include "DXHelper.h"

using namespace std;
using namespace Microsoft::WRL;

// Static state shared by every DWTSchemeD3D11 instance: the pool of ping-pong textures
// (resources) and the device/context pair, which InitD3D assigns.
DWTSchemeD3D11::D3DResources                DWTSchemeD3D11::resources = {};
ComPtr<ID3D11Device>                        DWTSchemeD3D11::device;
ComPtr<ID3D11DeviceContext>                 DWTSchemeD3D11::context;

DWTSchemeD3D11::DWTSchemeD3D11( wstring name, int steps, wstring shaderFile, RENDER_TARGET_FLAG* rtFlags ) :
    DWTScheme( name, steps ),
    float4( rtFlags == nullptr ) {
    // Without per-step render target flags the scheme works on one float4 texture pair.
    // With flags it works on four float1 texture pairs (R, G, B, A), and rtFlags must
    // point to at least `steps` entries.

    // Save render target flags (one per step)
    if ( rtFlags != nullptr && steps > 0 ) {
        renderTargetFlags.assign( rtFlags, rtFlags + steps );
    }

    // One entry per step
    shaderResources.resize( steps );
    renderTargets.resize( steps );
    shaders.resize( steps );

    // Compile one pixel shader per step. The entry point for step i (0-based) is
    // <name><stepMehodSuffix><i + 1>; the prefix does not depend on the step, so build it once.
    string entryPrefix( name.begin(), name.end() );
    entryPrefix.append( stepMehodSuffix );
    for ( int i = 0; i < steps; i++ ) {
        const string entryPoint = entryPrefix + std::to_string( i + 1 );

        // Compile the pixel shader
        ComPtr<ID3DBlob> pPSBlob;
        CompileShaderFromFile( shaderFile.c_str(), entryPoint.c_str(), "ps_5_0", pPSBlob.GetAddressOf() );
        // The compile call's result is not checked here. A failed compile would leave the blob
        // null and it would be dereferenced below, so fail with an error instead.
        if ( pPSBlob.Get() == nullptr ) {
            ThrowIfFailed( E_FAIL, L"CompileShaderFromFile error" );
        }

        // Create the pixel shader
        const HRESULT hr = device->CreatePixelShader( pPSBlob->GetBufferPointer(), pPSBlob->GetBufferSize(), 0, shaders[ i ].GetAddressOf() );
        ThrowIfFailed( hr, L"Device::CreatePixelShader error" );
    }

    // Wire each step's inputs/outputs to the shared ping-pong textures
    if ( float4 ) {
        // 1x float4: even steps read swap0 and write swap1; odd steps do the reverse
        outputTextures.reserve( 1 );
        outputShaderResources.reserve( 1 );

        for ( int i = 0; i < steps; i++ ) {
            const bool odd = ( i % 2 ) != 0;
            shaderResources[ i ].push_back( PtrToComPtr( odd ? resources.swap1SRV_RGBA : resources.swap0SRV_RGBA ) );
            renderTargets[ i ].push_back( PtrToComPtr( odd ? resources.swap0RTV_RGBA : resources.swap1RTV_RGBA ) );
        }
        // The last step wrote swap1 when the step count is odd, swap0 otherwise
        outputShaderResources.push_back( PtrToComPtr( ( steps % 2 ) ? resources.swap1SRV_RGBA : resources.swap0SRV_RGBA ) );
        outputTextures.push_back( PtrToComPtr( ( steps % 2 ) ? resources.swap1Tex_RGBA : resources.swap0Tex_RGBA ) );
    }
    else {
        // 4x float1: each channel flips independently, and only when a step writes it.
        // r/g/b/a count the writes so far; an even count means the data is in swap0.
        outputShaderResources.reserve( 4 );
        outputTextures.reserve( 4 );

        int r = 0, g = 0, b = 0, a = 0;
        for ( int i = 0; i < steps; i++ ) {
            shaderResources[ i ].reserve( 4 );
            shaderResources[ i ].push_back( PtrToComPtr( ( r % 2 ) ? resources.swap1SRV_R : resources.swap0SRV_R ) );
            shaderResources[ i ].push_back( PtrToComPtr( ( g % 2 ) ? resources.swap1SRV_G : resources.swap0SRV_G ) );
            shaderResources[ i ].push_back( PtrToComPtr( ( b % 2 ) ? resources.swap1SRV_B : resources.swap0SRV_B ) );
            shaderResources[ i ].push_back( PtrToComPtr( ( a % 2 ) ? resources.swap1SRV_A : resources.swap0SRV_A ) );
            // Picks the targets for the flagged channels and advances their counters
            renderTargets[ i ] = MultipleRenderTargets( i, renderTargetFlags[ i ], r, g, b, a );
        }
        outputShaderResources.push_back( PtrToComPtr( ( r % 2 ) ? resources.swap1SRV_R : resources.swap0SRV_R ) );
        outputShaderResources.push_back( PtrToComPtr( ( g % 2 ) ? resources.swap1SRV_G : resources.swap0SRV_G ) );
        outputShaderResources.push_back( PtrToComPtr( ( b % 2 ) ? resources.swap1SRV_B : resources.swap0SRV_B ) );
        outputShaderResources.push_back( PtrToComPtr( ( a % 2 ) ? resources.swap1SRV_A : resources.swap0SRV_A ) );
        outputTextures.push_back( PtrToComPtr( ( r % 2 ) ? resources.swap1Tex_R : resources.swap0Tex_R ) );
        outputTextures.push_back( PtrToComPtr( ( g % 2 ) ? resources.swap1Tex_G : resources.swap0Tex_G ) );
        outputTextures.push_back( PtrToComPtr( ( b % 2 ) ? resources.swap1Tex_B : resources.swap0Tex_B ) );
        outputTextures.push_back( PtrToComPtr( ( a % 2 ) ? resources.swap1Tex_A : resources.swap0Tex_A ) );
    }
}

// All owned objects are ComPtrs/vectors that release themselves; nothing extra to do.
DWTSchemeD3D11::~DWTSchemeD3D11() = default;

void DWTSchemeD3D11::Run( unsigned int iterations ) {
    // Get device context
    ID3D11DeviceContext * deviceContext = context.Get();

    // Empty resources
    void * emptyResources[ 4 ] = { nullptr, nullptr, nullptr, nullptr };

    // Render targets
    ID3D11RenderTargetView * rtv[ 4 ];

    // Shader resource views
    ID3D11ShaderResourceView * srv[ 4 ];

    // Draw
    while ( iterations-- > 0 ) {
        for ( int i = 0; i < steps; i++ ) {
            for ( size_t j = 0; j < renderTargets[ i ].size(); j++ ) {
                rtv[ j ] = renderTargets[ i ][ j ]->Get();
            }
            for ( int j = 0; j < ( float4 ? 1 : 4 ); j++ ) {
                srv[ j ] = shaderResources[ i ][ j ]->Get();
            }
            deviceContext->OMSetRenderTargets( renderTargets[ i ].size(), rtv, nullptr );
            deviceContext->PSSetShaderResources( 0, float4 ? 1 : 4, srv );
            deviceContext->PSSetShader( shaders[ i ].Get(), nullptr, 0 );
            deviceContext->Draw( 4, 0 );
            deviceContext->PSSetShaderResources( 0, 4, reinterpret_cast< ID3D11ShaderResourceView** >( emptyResources ) );
            deviceContext->OMSetRenderTargets( 4, reinterpret_cast< ID3D11RenderTargetView** >( emptyResources ), nullptr );
        }
    }
}

void DWTSchemeD3D11::InitD3D( Microsoft::WRL::ComPtr<ID3D11Device>& device, Microsoft::WRL::ComPtr<ID3D11DeviceContext>& context ) {
    // Store shared references to the device and device context used by all schemes and by the
    // texture pool. The parameters shadow the static members, hence the explicit qualification.
    DWTSchemeD3D11::context = context;
    DWTSchemeD3D11::device = device;
}

vector<ID3D11RenderTargetView*> DWTSchemeD3D11::GetInputRTVs() {
    // Render target views of the first-step input textures (the swap0 set); non-owning pointers.
    if ( !IsFloat4() ) {
        // Four single-channel inputs, in R, G, B, A order
        return { resources.swap0RTV_R.Get(), resources.swap0RTV_G.Get(), resources.swap0RTV_B.Get(), resources.swap0RTV_A.Get() };
    }
    // One RGBA input
    return { resources.swap0RTV_RGBA.Get() };
}

vector<ID3D11ShaderResourceView*> DWTSchemeD3D11::GetInputSRVs() {
    // Shader resource views of the first-step input textures (the swap0 set); non-owning pointers.
    if ( !IsFloat4() ) {
        // Four single-channel inputs, in R, G, B, A order
        return { resources.swap0SRV_R.Get(), resources.swap0SRV_G.Get(), resources.swap0SRV_B.Get(), resources.swap0SRV_A.Get() };
    }
    // One RGBA input
    return { resources.swap0SRV_RGBA.Get() };
}

vector<ID3D11ShaderResourceView*> DWTSchemeD3D11::GetOutputSRVs() {
    // Shader resource views holding the result of the last step: one RGBA view, or
    // four per-channel views (R, G, B, A). The pointers are non-owning.
    const size_t outputCount = IsFloat4() ? 1 : 4;

    vector<ID3D11ShaderResourceView*> views;
    views.reserve( outputCount );
    for ( size_t k = 0; k < outputCount; ++k ) {
        views.push_back( outputShaderResources[ k ]->Get() );
    }
    return views;
}

void DWTSchemeD3D11::Init( unsigned int resolution ) {
    // (Re)create the shared texture pool for this scheme's layout: the float4 pair or the four
    // float1 pairs, never both. No initial data is uploaded.
    const bool wantFloat4 = float4;
    const bool wantTex4 = !wantFloat4;
    const bool fillRandom = false;
    resources.InitTextures( resolution, wantFloat4, wantTex4, fillRandom );
}

vector<ComPtr<ID3D11RenderTargetView>*> DWTSchemeD3D11::MultipleRenderTargets( int step, RENDER_TARGET_FLAG renderTargetFlag, int & r, int & g, int & b, int & a ) {
    // Collects the render targets for the channels selected in renderTargetFlag, in R, G, B, A
    // order. Each selected channel's write counter is advanced. `step` is currently unused.

    // Picks the write target for one channel from its write count and advances the count.
    // Even count: write into swap1. Odd count: write into swap0.
    auto nextTarget = []( int & writes, ComPtr<ID3D11RenderTargetView> & swap0RTV, ComPtr<ID3D11RenderTargetView> & swap1RTV ) -> ComPtr<ID3D11RenderTargetView> & {
        const bool odd = ( writes++ % 2 ) != 0;
        return odd ? swap0RTV : swap1RTV;
    };

    vector<ComPtr<ID3D11RenderTargetView>*> targets;
    targets.reserve( 4 );
    if ( renderTargetFlag & R ) {
        targets.push_back( PtrToComPtr( nextTarget( r, resources.swap0RTV_R, resources.swap1RTV_R ) ) );
    }
    if ( renderTargetFlag & G ) {
        targets.push_back( PtrToComPtr( nextTarget( g, resources.swap0RTV_G, resources.swap1RTV_G ) ) );
    }
    if ( renderTargetFlag & B ) {
        targets.push_back( PtrToComPtr( nextTarget( b, resources.swap0RTV_B, resources.swap1RTV_B ) ) );
    }
    if ( renderTargetFlag & A ) {
        targets.push_back( PtrToComPtr( nextTarget( a, resources.swap0RTV_A, resources.swap1RTV_A ) ) );
    }
    return targets;
}

void DWTSchemeD3D11::D3DResources::InitTextures( unsigned int resolution, bool float4, bool tex4, bool fill ) {
    // Release unused resources
    ReleaseTextures();

    // Texture data
    float* texDataR = nullptr;
    float* texDataG = nullptr;
    float* texDataB = nullptr;
    float* texDataA = nullptr;
    float* texDataRGBA = nullptr;

    try {
        if ( tex4 && fill ) {
            texDataR = new float[ resolution * resolution ];
            texDataG = new float[ resolution * resolution ];
            texDataB = new float[ resolution * resolution ];
            texDataA = new float[ resolution * resolution ];
        }
        if ( float4 && fill ) {
            texDataRGBA = new float[ resolution * resolution * 4 ];
        }

        if ( fill ) {
            // Fill textures with random data
            for ( unsigned int i = 0; i < resolution * resolution; i++ ) {
                int row = ( i / resolution ) * resolution * 4;
                float r = rand() / ( float ) RAND_MAX;
                float g = rand() / ( float ) RAND_MAX;
                float b = rand() / ( float ) RAND_MAX;
                float a = rand() / ( float ) RAND_MAX;

                if ( float4 ) {
                    texDataRGBA[ i * 4 + 0 ] = r;
                    texDataRGBA[ i * 4 + 1 ] = g;
                    texDataRGBA[ i * 4 + 2 ] = b;
                    texDataRGBA[ i * 4 + 3 ] = a;
                }

                if ( tex4 ) {
                    texDataR[ i ] = r;
                    texDataG[ i ] = g;
                    texDataB[ i ] = b;
                    texDataA[ i ] = a;
                }
            }
        }

        // Create textures (4x float1, double buffering)
        if ( tex4 ) {
            InitTexture( resolution, DXGI_FORMAT_R32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap0Tex_R, swap0RTV_R, swap0SRV_R, texDataR, sizeof( float ), 1 );
            InitTexture( resolution, DXGI_FORMAT_R32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap0Tex_G, swap0RTV_G, swap0SRV_G, texDataG, sizeof( float ), 1 );
            InitTexture( resolution, DXGI_FORMAT_R32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap0Tex_B, swap0RTV_B, swap0SRV_B, texDataB, sizeof( float ), 1 );
            InitTexture( resolution, DXGI_FORMAT_R32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap0Tex_A, swap0RTV_A, swap0SRV_A, texDataA, sizeof( float ), 1 );
            InitTexture( resolution, DXGI_FORMAT_R32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap1Tex_R, swap1RTV_R, swap1SRV_R );
            InitTexture( resolution, DXGI_FORMAT_R32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap1Tex_G, swap1RTV_G, swap1SRV_G );
            InitTexture( resolution, DXGI_FORMAT_R32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap1Tex_B, swap1RTV_B, swap1SRV_B );
            InitTexture( resolution, DXGI_FORMAT_R32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap1Tex_A, swap1RTV_A, swap1SRV_A );
        }

        // Create textures (1x float4, double buffering)
        if ( float4 ) {
            InitTexture( resolution, DXGI_FORMAT_R32G32B32A32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap0Tex_RGBA, swap0RTV_RGBA, swap0SRV_RGBA, texDataRGBA, sizeof( float ), 4 );
            InitTexture( resolution, DXGI_FORMAT_R32G32B32A32_FLOAT, D3D11_USAGE_DEFAULT, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET, swap1Tex_RGBA, swap1RTV_RGBA, swap1SRV_RGBA );
        }

        // Set resources 
        this->resourcesVersion++;
        this->float4 = float4;
        this->tex4 = tex4;
        this->textureResolution = resolution;
    }
    catch ( ... ) {
        // Release unused resources
        ReleaseTextures();

        // Release texture data
        delete[] texDataR;
        delete[] texDataG;
        delete[] texDataB;
        delete[] texDataA;
        delete[] texDataRGBA;
        throw;
    }
    // Release texture data
    delete[] texDataR;
    delete[] texDataG;
    delete[] texDataB;
    delete[] texDataA;
    delete[] texDataRGBA;
}

void DWTSchemeD3D11::D3DResources::InitTexture( unsigned int resolution, DXGI_FORMAT format, D3D11_USAGE usage, UINT bindFlags, ComPtr<ID3D11Texture2D>& texture, ComPtr<ID3D11RenderTargetView>& renderTargetView, ComPtr<ID3D11ShaderResourceView>& shaderResourceView, const void * data, int typeSize, int componentsCount ) {
    HRESULT hr;

    // Setup the render target texture description
    D3D11_TEXTURE2D_DESC textureDesc = {};
    textureDesc.Width = resolution;
    textureDesc.Height = resolution;
    textureDesc.MipLevels = 1;
    textureDesc.ArraySize = 1;
    textureDesc.Format = format;
    textureDesc.SampleDesc.Count = 1;
    textureDesc.Usage = usage;
    textureDesc.BindFlags = bindFlags;
    textureDesc.CPUAccessFlags = usage & D3D11_USAGE_DYNAMIC ? D3D11_CPU_ACCESS_WRITE | D3D11_CPU_ACCESS_READ : 0;
    textureDesc.MiscFlags = 0;

    // Create texture
    if ( data != nullptr ) {
        D3D11_SUBRESOURCE_DATA texData = {};
        texData.pSysMem = data;
        texData.SysMemPitch = resolution * componentsCount * typeSize;
        texData.SysMemSlicePitch = resolution * resolution * componentsCount * typeSize;
        hr = device->CreateTexture2D( &textureDesc, &texData, texture.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"Device::CreateTexture2D error" );
    }
    else {
        hr = device->CreateTexture2D( &textureDesc, nullptr, texture.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"Device::CreateTexture2D error" );
    }

    if ( bindFlags & D3D11_BIND_RENDER_TARGET ) {
        // Setup the description of the render target view
        D3D11_RENDER_TARGET_VIEW_DESC renderTargetViewDesc = {};
        renderTargetViewDesc.Format = textureDesc.Format;
        renderTargetViewDesc.ViewDimension = D3D11_RTV_DIMENSION_TEXTURE2D;
        renderTargetViewDesc.Texture2D.MipSlice = 0;

        // Create the render target view
        hr = device->CreateRenderTargetView( texture.Get(), &renderTargetViewDesc, renderTargetView.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"Device::CreateRenderTargetView error" );
    }

    if ( bindFlags & D3D11_BIND_SHADER_RESOURCE ) {
        // Setup the description of the shader resource view
        D3D11_SHADER_RESOURCE_VIEW_DESC shaderResourceViewDesc = {};
        shaderResourceViewDesc.Format = textureDesc.Format;
        shaderResourceViewDesc.ViewDimension = D3D11_SRV_DIMENSION_TEXTURE2D;
        shaderResourceViewDesc.Texture2D.MostDetailedMip = 0;
        shaderResourceViewDesc.Texture2D.MipLevels = 1;

        // Create the shader resource view
        hr = device->CreateShaderResourceView( texture.Get(), &shaderResourceViewDesc, shaderResourceView.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr, L"Device::CreateShaderResourceView error" );
    }
}

void DWTSchemeD3D11::D3DResources::ReleaseTextures() {
    // Release
    float4 = false;
    tex4 = false;
    textureResolution = 0;

    ++resourcesVersion;

    swap0Tex_R.Reset();
    swap0Tex_G.Reset();
    swap0Tex_B.Reset();
    swap0Tex_A.Reset();

    swap0SRV_R.Reset();
    swap0SRV_G.Reset();
    swap0SRV_B.Reset();
    swap0SRV_A.Reset();

    swap0RTV_R.Reset();
    swap0RTV_G.Reset();
    swap0RTV_B.Reset();
    swap0RTV_A.Reset();

    swap1Tex_R.Reset();
    swap1Tex_G.Reset();
    swap1Tex_B.Reset();
    swap1Tex_A.Reset();

    swap1SRV_R.Reset();
    swap1SRV_G.Reset();
    swap1SRV_B.Reset();
    swap1SRV_A.Reset();

    swap1RTV_R.Reset();
    swap1RTV_G.Reset();
    swap1RTV_B.Reset();
    swap1RTV_A.Reset();

    swap0Tex_RGBA.Reset();
    swap0SRV_RGBA.Reset();
    swap0RTV_RGBA.Reset();

    swap1Tex_RGBA.Reset();
    swap1SRV_RGBA.Reset();
    swap1RTV_RGBA.Reset();

    context->ClearState();
    context->Flush();
}
