#pragma once
#include "DWTScheme.h"

#include <memory>
#include <string>
#include <vector>
#include <wrl.h>

#include <d3d11.h>
#include <DirectXMath.h>
#include <D3Dcompiler.h>
#include <DirectXMath.h>

// Forward declaration so the class below can name the flag type before its
// definition at the end of this header.
// NOTE(review): ISO C++ only allows an opaque declaration of an unscoped enum
// when its underlying type is fixed (e.g. `enum E : int;`); this form relies on
// a compiler extension (MSVC accepts it) - confirm before porting.
enum RENDER_TARGET_FLAG;

// DWTScheme specialisation that works on Direct3D 11 objects (pixel shaders,
// 2D textures, shader resource views and render target views).
// Only declarations live here; every non-inline member is defined elsewhere.
class DWTSchemeD3D11 :
    public DWTScheme {
public:

    // name             - scheme name
    // steps            - number of steps of the scheme
    // shaderFile       - shader file used by the scheme
    // renderTargetFlag - optional; defaults to nullptr
    //                    NOTE(review): presumably points to per-step flags - confirm against the definition
    DWTSchemeD3D11( std::wstring name, int steps, std::wstring shaderFile, RENDER_TARGET_FLAG * renderTargetFlag = nullptr );
    virtual ~DWTSchemeD3D11() override;

    // Prepare scheme for running with specified resolution
    virtual void  Init( unsigned int resolution ) override;

    // Run scheme
    virtual void Run( unsigned int iterations ) override;

    // Initialize device
    // Both arguments are taken by non-const reference.
    // NOTE(review): whether they are inputs, outputs or both is not visible in this header.
    static void InitD3D( Microsoft::WRL::ComPtr<ID3D11Device>& device, Microsoft::WRL::ComPtr<ID3D11DeviceContext>& context );

    // Accessors
    // IsFloat4 returns the bool `float4` member converted to unsigned int (0 or 1).
    inline unsigned int IsFloat4()              const { return float4; }
    // GetResolution forwards to the shared `resources` object.
    inline unsigned int GetResolution()         const { return resources.GetTextureRes(); }

    // Views handed out as raw interface pointers, returned by value.
    // NOTE(review): reference-count ownership of the returned pointers is not visible here.
    std::vector<ID3D11RenderTargetView*> GetInputRTVs();
    std::vector<ID3D11ShaderResourceView*> GetInputSRVs();
    std::vector<ID3D11ShaderResourceView*> GetOutputSRVs();

protected:

    // Resources struct
    // Holds two sets of textures ("swap0" / "swap1"), each with a matching
    // shader resource view and render target view, in two layouts:
    // one texture per channel (_R/_G/_B/_A) and one combined texture (_RGBA).
    struct D3DResources {
        // Swap chain - textures and shader resource views (4x float texture)
        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap0Tex_R;
        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap0Tex_G;
        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap0Tex_B;
        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap0Tex_A;

        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap0SRV_R;
        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap0SRV_G;
        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap0SRV_B;
        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap0SRV_A;

        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap0RTV_R;
        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap0RTV_G;
        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap0RTV_B;
        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap0RTV_A;

        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap1Tex_R;
        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap1Tex_G;
        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap1Tex_B;
        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap1Tex_A;

        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap1SRV_R;
        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap1SRV_G;
        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap1SRV_B;
        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap1SRV_A;

        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap1RTV_R;
        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap1RTV_G;
        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap1RTV_B;
        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap1RTV_A;

        // Swap chain - textures and shader resource views (1x float4 texture)
        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap0Tex_RGBA;
        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap0SRV_RGBA;
        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap0RTV_RGBA;

        Microsoft::WRL::ComPtr<ID3D11Texture2D>            swap1Tex_RGBA;
        Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>   swap1SRV_RGBA;
        Microsoft::WRL::ComPtr<ID3D11RenderTargetView>     swap1RTV_RGBA;

        // (Re)creates the textures for the given resolution.
        // NOTE(review): `float4` / `tex4` presumably select which of the two layouts
        // above gets created and `fill` whether initial data is uploaded - confirm in the definition.
        void InitTextures( unsigned int resolution, bool float4, bool tex4, bool fill );

        // Releases the textures; defined elsewhere.
        void ReleaseTextures();

        // Accessors
        // The two Initialized*() accessors return the bool members converted to unsigned int.
        inline unsigned int GetTextureRes()         const { return textureResolution; }
        inline unsigned int GetResourcesVersion()   const { return resourcesVersion; }
        inline unsigned int InitializedFloat4()     const { return float4; }
        inline unsigned int InitializedTex4()       const { return tex4; }

    protected:
        // Helper that sets up one texture together with its render target view and
        // shader resource view (all three are passed by reference).
        // `data` is optional (nullptr by default); `typeSize` defaults to sizeof( float )
        // and `componentsCount` to 4.
        // NOTE(review): the last two presumably describe the layout of `data` - confirm in the definition.
        void InitTexture( unsigned int resolution, DXGI_FORMAT format, D3D11_USAGE usage, UINT bindFlags,
            Microsoft::WRL::ComPtr<ID3D11Texture2D>& texture, Microsoft::WRL::ComPtr<ID3D11RenderTargetView>& renderTargetView,
            Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>& shaderResourceView, const void * data = nullptr, int typeSize = sizeof( float ), int componentsCount = 4 );

        // Texture parameters
        unsigned int    textureResolution = 0;
        bool            float4 = false;
        bool            tex4 = false;

        // Version of resources
        unsigned int    resourcesVersion = 0;
    };

    // Returns pointer to specified smart pointer (ComPtr)
    // ComPtr overloads unary operator&, so `&comPtr` does not produce a
    // ComPtr<Type>*; std::addressof bypasses the overload and yields the
    // address of the ComPtr object itself.
    template<typename Type>
    inline Microsoft::WRL::ComPtr<Type>* PtrToComPtr( Microsoft::WRL::ComPtr<Type>& comPtr ) {
        return std::addressof( comPtr );
    }

    // Scheme properties
    // Default-initialised so the flag is never read indeterminate; a constructor
    // initialiser still takes precedence.
    bool                float4 = false;

    // NOTE(review): the identifier is misspelled ("Mehod"); kept as-is because
    // code outside this header may refer to it.
    std::string         stepMehodSuffix = "Step";

    // Shaders and resources
    // The containers below store pointers to ComPtr objects, not copies of them.
    // NOTE(review): presumably they point at members of `resources` - confirm in the .cpp.
    std::vector<Microsoft::WRL::ComPtr<ID3D11PixelShader>>                          shaders;
    std::vector<Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>*>                  outputShaderResources;
    std::vector<Microsoft::WRL::ComPtr<ID3D11Texture2D>*>                           outputTextures;
    std::vector<std::vector<Microsoft::WRL::ComPtr<ID3D11ShaderResourceView>*>>     shaderResources;
    std::vector<std::vector<Microsoft::WRL::ComPtr<ID3D11RenderTargetView>*>>       renderTargets;

    std::vector<RENDER_TARGET_FLAG> renderTargetFlags;

    // Builds the list of render targets for `step` according to `renderTargetFlag`.
    // `r`, `g`, `b`, `a` are in/out references.
    // NOTE(review): their exact meaning is not visible in this header - see the definition.
    std::vector<Microsoft::WRL::ComPtr<ID3D11RenderTargetView>*>
        MultipleRenderTargets( int step, RENDER_TARGET_FLAG renderTargetFlag, int &r, int &g, int &b, int &a );

    // Required Direct3D resources
    // Static: one instance shared by every DWTSchemeD3D11 object.
    static D3DResources resources;

    // Direct3D device and device context
    // Static: shared by every DWTSchemeD3D11 object.
    static Microsoft::WRL::ComPtr<ID3D11Device>               device;
    static Microsoft::WRL::ComPtr<ID3D11DeviceContext>        context;

    // Current resources version
    // Starts at 1 whereas D3DResources::resourcesVersion starts at 0.
    // NOTE(review): presumably compared against resources.GetResourcesVersion() to
    // detect stale bindings - confirm in the .cpp.
    unsigned int    resourcesVersion = 1;
};

// Render target flags
// Bit mask with one bit per colour channel; combine values with operator| below.
enum RENDER_TARGET_FLAG {
    R = 1 << 0,
    G = 1 << 1,
    B = 1 << 2,
    A = 1 << 3,
};

// Bitwise OR of two flag values, returned as RENDER_TARGET_FLAG instead of the
// int that the built-in operator would produce.
// constexpr so combinations can be used in constant expressions.
inline constexpr RENDER_TARGET_FLAG operator | ( RENDER_TARGET_FLAG a, RENDER_TARGET_FLAG b ) noexcept {
    return static_cast< RENDER_TARGET_FLAG >( static_cast< int >( a ) | static_cast< int >( b ) );
}