#include "DWTSchemeD3D11Deferred.h"

using Microsoft::WRL::ComPtr;

// Definitions of the static D3D11 objects shared by every DWTSchemeD3D11Deferred instance.

// Deferred context created in InitD3D(); Init() records its draw calls on it.
ComPtr<ID3D11DeviceContext>        DWTSchemeD3D11Deferred::deferredContext;
// Constant buffer handed to InitD3D(); bound to vertex shader slot 0 while recording.
ComPtr<ID3D11Buffer>               DWTSchemeD3D11Deferred::constantBuffer;
// Vertex shader handed to InitD3D(); set on the deferred context while recording.
ComPtr<ID3D11VertexShader>         DWTSchemeD3D11Deferred::vertexShader;
// Sampler state handed to InitD3D(); bound to pixel shader sampler slot 0 while recording.
ComPtr<ID3D11SamplerState>         DWTSchemeD3D11Deferred::samplerState;
// Command list produced by FinishCommandList() in Init() and replayed by Run().
ComPtr<ID3D11CommandList>          DWTSchemeD3D11Deferred::commandList;

// Constructs the scheme by handing every argument unchanged to the DWTSchemeD3D11 base;
// this class adds no per-instance state of its own in the constructor.
DWTSchemeD3D11Deferred::DWTSchemeD3D11Deferred(
    std::wstring name,
    int steps,
    std::wstring shaderFile,
    RENDER_TARGET_FLAG * renderTargetFlag )
    : DWTSchemeD3D11( name, steps, shaderFile, renderTargetFlag )
{
}

void DWTSchemeD3D11Deferred::Init( unsigned int resolution ) {
    Init( resolution, 1 );
}

void DWTSchemeD3D11Deferred::Init( unsigned int resolution, unsigned int iterations ) {
    // Release command list
    commandList.Reset();

    // Initialize base
    DWTSchemeD3D11::Init( resolution );

    // Get device context
    ID3D11DeviceContext * deviceContext = deferredContext.Get();

    // Empty resources
    void * emptyResources[ 4 ] = { nullptr, nullptr, nullptr, nullptr };

    // Render targets
    ID3D11RenderTargetView * rtv[ 4 ];

    // Shader resource views
    ID3D11ShaderResourceView * srv[ 4 ];

    // Constant buffer, vertex shader, input assembler, sampler state
    deviceContext->VSSetConstantBuffers( 0, 1, constantBuffer.GetAddressOf() );
    deviceContext->IASetPrimitiveTopology( D3D11_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP );
    deviceContext->VSSetShader( vertexShader.Get(), nullptr, 0 );
    deviceContext->PSSetSamplers( 0, 1, samplerState.GetAddressOf() );

    // Viewport
    D3D11_VIEWPORT vp = {};
    vp.Width = static_cast< FLOAT >( resolution );
    vp.Height = vp.Width;
    vp.MaxDepth = 1.0f;
    deviceContext->RSSetViewports( 1, &vp );

    // Draw
    while ( iterations-- > 0 ) {
        for ( int i = 0; i < steps; i++ ) {
            for ( size_t j = 0; j < renderTargets[ i ].size(); j++ ) {
                rtv[ j ] = renderTargets[ i ][ j ]->Get();
            }
            for ( int j = 0; j < ( float4 ? 1 : 4 ); j++ ) {
                srv[ j ] = shaderResources[ i ][ j ]->Get();
            }
            deviceContext->OMSetRenderTargets( renderTargets[ i ].size(), rtv, 0 );
            deviceContext->PSSetShaderResources( 0, float4 ? 1 : 4, srv );
            deviceContext->PSSetShader( shaders[ i ].Get(), 0, 0 );
            deviceContext->Draw( 4, 0 );
            deviceContext->PSSetShaderResources( 0, 4, reinterpret_cast< ID3D11ShaderResourceView** >( emptyResources ) );
            deviceContext->OMSetRenderTargets( 4, reinterpret_cast< ID3D11RenderTargetView** >( emptyResources ), 0 );
        }
    }
    deviceContext->FinishCommandList( FALSE, commandList.ReleaseAndGetAddressOf() );
}

void DWTSchemeD3D11Deferred::Run( unsigned int iterations ) {
    // Get device context
    ID3D11DeviceContext * deviceContext = context.Get();

    // Draw
    while ( iterations-- > 0 ) {
        deviceContext->ExecuteCommandList( commandList.Get(), FALSE );
    }
}

// One-time D3D setup shared by all instances.
//
// Forwards device and context to DWTSchemeD3D11::InitD3D, creates the deferred context used
// for recording, and stores the constant buffer, vertex shader and sampler state in the
// static members of this class.
//
// NOTE: the HRESULT returned by CreateDeferredContext is not inspected here.
void DWTSchemeD3D11Deferred::InitD3D(
    ComPtr<ID3D11Device>& device,
    ComPtr<ID3D11DeviceContext>& context,
    ComPtr<ID3D11Buffer>& constantBuffer,
    ComPtr<ID3D11VertexShader>& vertexShader,
    ComPtr<ID3D11SamplerState>& samplerState )
{
    // Base class setup first
    DWTSchemeD3D11::InitD3D( device, context );

    // Any previously held deferred context is released before the new one is stored
    ID3D11Device * rawDevice = device.Get();
    rawDevice->CreateDeferredContext( 0, deferredContext.ReleaseAndGetAddressOf() );

    // The parameters shadow the static members, hence the explicit class qualification
    DWTSchemeD3D11Deferred::constantBuffer = constantBuffer;
    DWTSchemeD3D11Deferred::vertexShader   = vertexShader;
    DWTSchemeD3D11Deferred::samplerState   = samplerState;
}

// Nothing to clean up here: the D3D objects defined in this file are static members and
// are not tied to the lifetime of a single instance.
DWTSchemeD3D11Deferred::~DWTSchemeD3D11Deferred() = default;
