#include "ProfilerD3D11.h"
#include "DXHelper.h"

using namespace std;
using Microsoft::WRL::ComPtr;

// Builds a GPU frame-time profiler and creates its queries right away (see Init()).
//
//   frameLatency - number of per-frame query sets kept in the ring. A value of 0 is
//                  raised to 1: with no query sets Begin()/End() would index an empty
//                  container and End() would compute a modulo by zero.
//   device       - device the queries are created on
//   context      - context the queries are issued on and read back from
//
// Query creation failures are passed to ThrowIfFailed inside Init().
ProfilerD3D11::ProfilerD3D11( unsigned int frameLatency, ComPtr<ID3D11Device>& device, ComPtr<ID3D11DeviceContext>& context ) :
    frameLatency( frameLatency > 0u ? frameLatency : 1u ),
    device( device ),
    context( context ) {
    Init();
}

void ProfilerD3D11::Init() {
    // Allocate vector of queries for measurement
    perFrameQueries.resize( frameLatency );

    // Create queries for frames
    for ( unsigned int i = 0; i < frameLatency; i++ ) {
        HRESULT hr;

        // Create disjoint query
        D3D11_QUERY_DESC queryDesc = { D3D11_QUERY_TIMESTAMP_DISJOINT, 0 };
        hr = device->CreateQuery( &queryDesc, perFrameQueries[ i ].disjointQuery.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr );

        // Create timestamp queries
        queryDesc.Query = D3D11_QUERY_TIMESTAMP;
        hr = device->CreateQuery( &queryDesc, perFrameQueries[ i ].timestampQueryBegin.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr );
        hr = device->CreateQuery( &queryDesc, perFrameQueries[ i ].timestampQueryEnd.ReleaseAndGetAddressOf() );
        ThrowIfFailed( hr );
    }

    currentFrame = 0;
}

// Nothing to clean up by hand; the members take care of their own destruction.
ProfilerD3D11::~ProfilerD3D11() = default;

// Starts the measurement for the frame currently being recorded.
// Must be paired with End(), which closes the disjoint query opened here.
void ProfilerD3D11::Begin() {
    auto& recording = perFrameQueries[ currentFrame ];

    // Open the disjoint query first so that it encloses both timestamps
    context->Begin( recording.disjointQuery.Get() );

    // The begin timestamp is issued with End(); no Begin() is called on timestamp queries here
    context->End( recording.timestampQueryBegin.Get() );
}

double ProfilerD3D11::End() {
    // Current frame queries (end)
    PerFrameQueries& currentQueries = perFrameQueries[ currentFrame ];
    context->End( currentQueries.timestampQueryEnd.Get() );
    context->End( currentQueries.disjointQuery.Get() );

    ++currentFrame;
    currentFrame %= frameLatency;

    // Output frame queries (output with latency)
    PerFrameQueries& outputQueries = perFrameQueries[ currentFrame ];

    // Wait for data
    while ( context->GetData( outputQueries.disjointQuery.Get(), NULL, 0, 0 ) == S_FALSE ) {
        Sleep( 1 );
    }

    // Get timestamp disjoint
    D3D11_QUERY_DATA_TIMESTAMP_DISJOINT timestampDisjoint;
    if ( context->GetData( outputQueries.disjointQuery.Get(), &timestampDisjoint, sizeof( timestampDisjoint ), 0 ) != S_OK || timestampDisjoint.Disjoint ) {
        return 0.0;
    }

    // Begin timestamp
    UINT64 timestampBegin;
    if ( context->GetData( outputQueries.timestampQueryBegin.Get(), &timestampBegin, sizeof( UINT64 ), 0 ) != S_OK ) {
        return 0.0;
    }
    // End timestamp
    UINT64 timestampEnd;
    if ( context->GetData( outputQueries.timestampQueryEnd.Get(), &timestampEnd, sizeof( UINT64 ), 0 ) != S_OK ) {
        return 0.0;
    }

    // Compute frame time
    double begin        = static_cast< double > ( timestampBegin );
    double end          = static_cast< double > ( timestampEnd );
    double frequency    = static_cast< double > ( timestampDisjoint.Frequency );

    // Resturn frame time in seconds
    return ( end - begin ) / frequency;
}

// Throws away all per-frame query sets and builds them again from scratch.
// Results of measurements that were still pending are lost; the ring restarts
// at slot 0 because Init() rewinds currentFrame.
void ProfilerD3D11::Reset() {
    // Empty the container first so Init() starts from fresh elements
    perFrameQueries.clear();

    // Recreate the queries and rewind the ring
    Init();
}
