#include "Renderer.hpp"

#include <string.h>

#include "Window.hpp"
#include "FrameBuffer.hpp"
#include "Rasterizer.hpp"
#include "VertexBuffer.hpp"
#include "IndexBuffer.hpp"
#include "PS_2_0Shader.hpp"
#include "VS_2_0Shader.hpp"
#include "PixelPipeline.hpp"
#include "Vertex.hpp"
#include "Error.hpp"
#include "Matrix.hpp"
#include "Texture.hpp"
#include "Viewport.hpp"
#include "Clipper.hpp"

namespace swShader
{
	// Creates a renderer that draws into renderTarget.
	// Calls Context::init(), allocates the clipper, viewport and rasterizer,
	// sets a 90 degree field of view with near/far clip distances 1 and 1000,
	// and marks all derived state (matrices, shaders, clip planes) stale so the
	// first update() rebuilds it. Throws Error if renderTarget is null.
	Renderer::Renderer(RenderTarget *renderTarget)
	{
		if(!renderTarget) throw Error("Could not create renderer");

		Context::init();

		clipper = new Clipper();
		viewport = new Viewport(renderTarget);
		rasterizer = new Rasterizer(renderTarget);

		// A zero size never matches a real viewport, so update() builds P on first use
		W = 0;
		H = 0;

		tanFOV = tan(rad(90 / 2));
		nearClip = 1;
		farClip = 1000;

		M = 1;
		V = 1;
		// Default base matrix swaps the y and z axes
		B = Matrix(1, 0, 0, 0,
		           0, 0, 1, 0,
		           0, 1, 0, 0,
		           0, 0, 0, 1);
		P = 0;

		updateModelMatrix = true;
		updateViewMatrix = true;
		updateBaseMatrix = true;
		updateProjectionMatrix = true;

		updatePixelShader = true;
		updateVertexShader = true;

		updateClipper = true;

		pixelShaderFile = 0;
		pixelShaderCache = new FIFOCache<State, PixelShader>(64);
		pixelShader = 0;

		vertexShaderFile = 0;
		vertexShaderCache = new FIFOCache<char*, VertexShader>(64);
		vertexShader = 0;   // setVertexShaderConstant*() test this to decide whether to build one

		// Empty post-transform vertex cache
		top = 0;
		for(int i = 0; i < 16; i++) index[i] = -1;
	}

	// Frees the helper objects created by the constructor, releases the
	// samplers' textures, and frees the shader file names and shader caches.
	// Every owned pointer is reset to null after it is freed.
	Renderer::~Renderer()
	{
		delete clipper;
		delete viewport;
		delete rasterizer;

		clipper = 0;
		viewport = 0;
		rasterizer = 0;

		releaseTextures();

		delete[] pixelShaderFile;
		delete pixelShaderCache;
		pixelShaderFile = 0;
		pixelShaderCache = 0;

		delete[] vertexShaderFile;
		delete vertexShaderCache;
		vertexShaderFile = 0;
		vertexShaderCache = 0;
	}

	void Renderer::drawPrimitive(const VertexBuffer *VB, const IndexBuffer *IB)
	{
		if(!clipper || !rasterizer || !viewport) throw INTERNAL_ERROR;
		if(!VB || !IB) throw Error("No primitive specified");

		update(FVF);

		vertexShader->setPositionStream(VB->position[0]);
		vertexShader->setTexCoordStream(VB->texCoord[0]);

		setVertexShaderConstantF(0, PBVM[0]);
		setVertexShaderConstantF(1, PBVM[1]);
		setVertexShaderConstantF(2, PBVM[2]);
		setVertexShaderConstantF(3, PBVM[3]);

		FVF = vertexShader->getOutputFormat() | FVF_RHW;

		const float L = viewport->getLeft();
		const float T = viewport->getTop() + H;

		for(int i = 0; i < IB->numFaces; i++)
		{
			const IndexBuffer::Face &face = IB->face[i];

			const XVertex *V1 = 0; for(int i = 0; i < 16; i++) if(face[0] == index[i]) V1 = &vertexCache[i];
			XVertex v1;
			if(V1)
			{
				v1 = *V1;
			}
			else
			{
				top = (top + 1) & 16;
				vertexShader->setOutputVertex(&vertexCache[top]);
				vertexShader->process(face[0]);
				v1 = vertexCache[top];
			}

			const XVertex *V2 = 0; for(int i = 0; i < 16; i++) if(face[1] == index[i]) V2 = &vertexCache[i];
			XVertex v2;
			if(V2)
			{
				v2 = *V2;
			}
			else
			{
				top = (top + 1) & 16;
				vertexShader->setOutputVertex(&vertexCache[top]);
				vertexShader->process(face[1]);
				v2 = vertexCache[top];
			}

			const XVertex *V3 = 0; for(int i = 0; i < 16; i++) if(face[2] == index[i]) V3 = &vertexCache[i];
			XVertex v3;
			if(V3)
			{
				v3 = *V3;
			}
			else
			{
				top = (top + 1) & 16;
				vertexShader->setOutputVertex(&vertexCache[top]);
				vertexShader->process(face[2]);
				v3 = vertexCache[top];
			}
		
			XVertex **V = clipper->clipTriangle(v1, v2, v3, FVF);
			int n = clipper->getNumVertices();
			if(n < 3) continue;

			for(int i = 0; i < n; i++)
			{
				XVertex &v = *V[i];

				const float RHW = 1.0f / v.w;
					
				v.x = L + W * v.x * RHW;
				v.y = T - H * v.y * RHW;
				v.z = v.z * RHW;
				v.w = RHW;

				for(int t = 0; t < FVF.textureCount(); t++)
				{
					v.T[t].u *= RHW;
					v.T[t].v *= RHW;
				}
			}

			renderPolygon(V, n, FVF);
		}

		for(int i = 0; i < 16; i++) index[i] = -1;
	}

	// Describes the current pixel pipeline configuration as a State:
	//  - the pixel shader file name;
	//  - the fixed-function settings packed into one 32-bit word;
	//  - the status of all 16 samplers.
	// update() uses the result as the pixel shader cache key. It lives in a
	// function-local static that every call overwrites.
	const State &Renderer::status() const
	{
		static State state;

		state.setShaderFile(pixelShaderFile);

		// First bit of each field in the packed word; colorDepth starts at bit 0
		const int depthCompareShift = BITS(COLOR_LAST);
		const int alphaCompareShift = depthCompareShift + BITS(DEPTH_LAST);
		const int depthWriteShift   = alphaCompareShift + BITS(ALPHA_LAST);
		const int alphaTestShift    = depthWriteShift + 1;
		const int alphaBlendShift   = alphaTestShift + 1;
		const int shadingShift      = alphaBlendShift + 1;
		const int specularShift     = shadingShift + BITS(SHADING_LAST);
		const int sourceBlendShift  = specularShift + 1;
		const int destBlendShift    = sourceBlendShift + BITS(BLEND_LAST);
		const int totalBits         = destBlendShift + BITS(BLEND_LAST);

		// Test whether status fits in 32-bit
		META_ASSERT(totalBits <= 32);

		state.setPipelineState(colorDepth                             |
		                       depthCompareMode  << depthCompareShift |
		                       alphaCompareMode  << alphaCompareShift |
		                       depthWriteEnable  << depthWriteShift   |
		                       alphaTestEnable   << alphaTestShift    |
		                       alphaBlendEnable  << alphaBlendShift   |
		                       shadingMode       << shadingShift      |
		                       specularEnable    << specularShift     |
		                       sourceBlendFactor << sourceBlendShift  |
		                       destBlendFactor   << destBlendShift);

		for(unsigned int unit = 0; unit < 16; unit++)
		{
			state.setSamplerState(unit, sampler[unit].status());
		}

		return state;
	}

	// Brings derived state up to date before drawing. FVF is stored as the
	// current vertex format, then the following steps run in order:
	//  - P is rebuilt when the FOV or clip distances changed or the viewport was
	//    resized. The cached products PB = P*B, PBV = PB*V and PBVM = PBV*M are
	//    recomputed from the first stale factor onward.
	//  - The enabled user clip planes are re-sent to the clipper.
	//  - The pixel shader matching status() and the vertex shader for
	//    vertexShaderFile are fetched from their caches or created.
	// Throws Error when a vertex shader must be built but no vertex shader file
	// is set (the fixed-function vertex pipeline is not implemented).
	void Renderer::update(FVFFlags FVF)
	{
		this->FVF = FVF;

		if(updateProjectionMatrix ||
		   W != viewport->getWidth() ||
		   H != viewport->getHeight())
		{
			W = viewport->getWidth();
			H = viewport->getHeight();

			// D is the focal length in pixels. P maps the visible frustum to
			// x/w, y/w in [0, 1] and depths N..F to z/w in [0, 1].
			const float D = 0.5f * W / tanFOV;
			const float F = farClip;
			const float N = nearClip;
			const float Q = F / (F - N);

			P = Matrix(D / W,  0,     0.5f,  0,
			           0,      D / H, 0.5f,  0,
			           0,      0,     Q,    -Q * N,
			           0,      0,     1,     0);

			PB = P * B;
			PBV = PB * V;
			PBVM = PBV * M;

			updateModelMatrix = false;
			updateViewMatrix = false;
			updateBaseMatrix = false;
			updateProjectionMatrix = false;

			updateClipper = true;
		}

		if(updateBaseMatrix)
		{
			PB = P * B;
			PBV = PB * V;
			PBVM = PBV * M;

			updateModelMatrix = false;
			updateViewMatrix = false;
			updateBaseMatrix = false;

			updateClipper = true;
		}

		if(updateViewMatrix)
		{
			PBV = PB * V;
			PBVM = PBV * M;

			updateModelMatrix = false;
			updateViewMatrix = false;

			updateClipper = true;
		}

		if(updateModelMatrix)
		{
			PBVM = PBV * M;

			updateModelMatrix = false;

			updateClipper = true;
		}

		if(updateClipper || updateVertexShader)
		{
			if(!clipper) throw INTERNAL_ERROR;

			const int flags = clipper->getClipFlags();

			// Fixed-function pipeline: clip planes are given in world space
			Matrix T = PBV;

			if(vertexShaderFile)   // Programmable pipeline: clip planes are given in clip space
			{
				// Transform from [-1,1]x[-1,1]x[0,1] to [0,1]x[0,1]x[0,1]
				T = Matrix(0.5, 0,   0, 0.5,
				           0,   0.5, 0, 0.5,
				           0,   0,   1, 0,
				           0,   0,   0, 1);
			}

			static const int planeFlag[6] =
			{
				Clipper::CLIP_PLANE0, Clipper::CLIP_PLANE1, Clipper::CLIP_PLANE2,
				Clipper::CLIP_PLANE3, Clipper::CLIP_PLANE4, Clipper::CLIP_PLANE5
			};

			for(int k = 0; k < 6; k++)
			{
				if(flags & planeFlag[k]) clipper->setClipPlane(k, T * Plane(plane[k]));
			}

			updateClipper = false;
		}

		if(updatePixelShader)
		{
			State state = status();

			pixelShader = pixelShaderCache->query(state);

			if(!pixelShader)   // Create one
			{
				if(pixelShaderFile)
				{
					pixelShader = pixelShaderCache->add(state, new PS_2_0Shader(pixelShaderFile));
				}
				else
				{
					pixelShader = pixelShaderCache->add(state, new PixelPipeline());
				}
			}

			updatePixelShader = false;
		}

		if(updateVertexShader)
		{
			// NOTE(review): the cache key is the char* itself. If FIFOCache compares
			// keys by pointer value, a freed-then-reused address could return a stale
			// shader. Confirm the FIFOCache key comparison.
			vertexShader = vertexShaderCache->query(vertexShaderFile);

			if(!vertexShader)   // Create one
			{
				if(vertexShaderFile)
				{
					vertexShader = vertexShaderCache->add(vertexShaderFile, new VS_2_0Shader(vertexShaderFile));
				}
				else
				{
					throw Error("Fixed-function vertex pipeline unimplemented");
				}
			}

			updateVertexShader = false;
		}
	}

	// Selects the pixel shader file to compile. A null pixelShaderFile selects
	// the fixed-function pixel pipeline. The name is copied; the shader itself
	// is looked up or built on the next update().
	void Renderer::setPixelShader(const char *pixelShaderFile)
	{
		// The copy is released with delete[] (here and in the destructor), so it
		// must come from new[]. strdup() returns malloc'ed memory that may only
		// be released with free(). Allocating before freeing keeps the old name
		// if allocation throws.
		char *copy = 0;

		if(pixelShaderFile)
		{
			const size_t length = strlen(pixelShaderFile);
			copy = new char[length + 1];
			memcpy(copy, pixelShaderFile, length + 1);
		}

		delete[] this->pixelShaderFile;
		this->pixelShaderFile = copy;

		updatePixelShader = true;
	}

	// Selects the vertex shader file to compile. A null vertexShaderFile
	// selects the fixed-function vertex path, for which update() currently
	// throws. The name is copied; the shader is looked up or built on the next
	// update().
	void Renderer::setVertexShader(const char *vertexShaderFile)
	{
		// The copy is released with delete[] (here and in the destructor), so it
		// must come from new[]. strdup() returns malloc'ed memory that may only
		// be released with free(). Allocating before freeing keeps the old name
		// if allocation throws.
		char *copy = 0;

		if(vertexShaderFile)
		{
			const size_t length = strlen(vertexShaderFile);
			copy = new char[length + 1];
			memcpy(copy, vertexShaderFile, length + 1);
		}

		delete[] this->vertexShaderFile;
		this->vertexShaderFile = copy;

		updateVertexShader = true;
	}

	void Renderer::setTextureMap(int stage, Texture *texture)
	{
		if(stage < 0 || stage >= 16) throw Error("Texture stage index out of [0, 15] range: %d", stage);

		sampler[stage].setTextureMap(texture);
	}

	void Renderer::setTexCoordIndex(int stage, int texCoordIndex)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setTexCoordIndex(texCoordIndex);
	}

	// Asks each of the 16 samplers, in stage order, to release its texture.
	void Renderer::releaseTextures()
	{
		int unit = 0;

		while(unit < 16)
		{
			sampler[unit].releaseTexture();
			unit++;
		}
	}

	void Renderer::setStageOperation(int stage, Sampler::StageOperation stageOperation)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setStageOperation(stageOperation);
	}

	void Renderer::setFirstArgument(int stage, Sampler::SourceArgument firstArgument)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setFirstArgument(firstArgument);
	}

	void Renderer::setSecondArgument(int stage, Sampler::SourceArgument secondArgument)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setSecondArgument(secondArgument);
	}

	void Renderer::setThirdArgument(int stage, Sampler::SourceArgument thirdArgument)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setThirdArgument(thirdArgument);
	}

	void Renderer::setFirstModifier(int stage, Sampler::ArgumentModifier firstModifier)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setFirstModifier(firstModifier);
	}

	void Renderer::setSecondModifier(int stage, Sampler::ArgumentModifier secondModifier)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setSecondModifier(secondModifier);
	}

	void Renderer::setThirdModifier(int stage, Sampler::ArgumentModifier thirdModifier)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setThirdModifier(thirdModifier);
	}

	void Renderer::setDestinationArgument(int stage, Sampler::DestinationArgument destinationArgument)
	{
		if(stage < 0 || stage >= 8) throw Error("Texture stage index out of [0, 7] range: %d", stage);

		updatePixelShader |= sampler[stage].setDestinationArgument(destinationArgument);
	}

	void Renderer::setTextureFilter(int stage, Sampler::FilterType textureFilter)
	{
		if(stage < 0 || stage >= 16) throw Error("Texture stage index out of [0, 15] range: %d", stage);

		updatePixelShader |= sampler[stage].setTextureFilter(textureFilter);
	}

	void Renderer::setAddressingMode(int stage, Sampler::AddressingMode addressMode)
	{
		if(stage < 0 || stage >= 16) throw Error("Texture stage index out of [0, 15] range: %d", stage);

		updatePixelShader |= sampler[stage].setAddressingMode(addressMode);
	}

	// Sets the depth comparison function. It is part of status(), so a
	// different value requires a different pixel shader.
	void Renderer::setDepthCompare(DepthCompareMode depthCompareMode)
	{
		if(depthCompareMode != this->depthCompareMode)
		{
			this->depthCompareMode = depthCompareMode;
			updatePixelShader = true;
		}
	}

	// Sets the alpha test comparison function. It is part of status(), so a
	// different value requires a different pixel shader.
	// (Previously this compared and assigned depthCompareMode to itself, so the
	// alpha compare mode could never change.)
	void Renderer::setAlphaCompare(AlphaCompareMode alphaCompareMode)
	{
		updatePixelShader |= (this->alphaCompareMode != alphaCompareMode);
		this->alphaCompareMode = alphaCompareMode;
	}

	// Enables or disables depth buffer writes. Marks the pixel shader for
	// rebuild when the setting actually changes.
	void Renderer::setDepthWriteEnable(bool depthWriteEnable)
	{
		const bool changed = depthWriteEnable != this->depthWriteEnable;
		this->depthWriteEnable = depthWriteEnable;
		if(changed) updatePixelShader = true;
	}

	// Enables or disables the alpha test. Marks the pixel shader for rebuild
	// when the setting actually changes.
	void Renderer::setAlphaTestEnable(bool alphaTestEnable)
	{
		const bool changed = alphaTestEnable != this->alphaTestEnable;
		this->alphaTestEnable = alphaTestEnable;
		if(changed) updatePixelShader = true;
	}

	// Selects which triangle winding renderPolygon() discards.
	// NOTE(review): cullMode is not part of status(), so the invalidation below
	// only forces a pixel shader cache lookup. It is kept for behavior parity.
	void Renderer::setCullMode(CullMode cullMode)
	{
		if(cullMode != this->cullMode)
		{
			this->cullMode = cullMode;
			updatePixelShader = true;
		}
	}

	// Sets the shading mode. It is part of status(), so a different value
	// requires a different pixel shader.
	void Renderer::setShadingMode(ShadingMode shadingMode)
	{
		if(shadingMode != this->shadingMode)
		{
			this->shadingMode = shadingMode;
			updatePixelShader = true;
		}
	}

	// Enables or disables specular. Marks the pixel shader for rebuild when
	// the setting actually changes.
	void Renderer::setSpecularEnable(bool specularEnable)
	{
		const bool changed = specularEnable != this->specularEnable;
		this->specularEnable = specularEnable;
		if(changed) updatePixelShader = true;
	}

	// Enables or disables alpha blending. Marks the pixel shader for rebuild
	// when the setting actually changes.
	void Renderer::setAlphaBlendEnable(bool alphaBlendEnable)
	{
		const bool changed = alphaBlendEnable != this->alphaBlendEnable;
		this->alphaBlendEnable = alphaBlendEnable;
		if(changed) updatePixelShader = true;
	}

	// Sets the source blend factor. It is part of status(), so a different
	// value requires a different pixel shader.
	void Renderer::setSourceBlendFactor(BlendFactor sourceBlendFactor)
	{
		if(sourceBlendFactor != this->sourceBlendFactor)
		{
			this->sourceBlendFactor = sourceBlendFactor;
			updatePixelShader = true;
		}
	}

	// Sets the destination blend factor. It is part of status(), so a
	// different value requires a different pixel shader.
	void Renderer::setDestBlendFactor(BlendFactor destBlendFactor)
	{
		if(destBlendFactor != this->destBlendFactor)
		{
			this->destBlendFactor = destBlendFactor;
			updatePixelShader = true;
		}
	}

	// Stores the alpha test reference value. It is not part of status(), so no
	// pixel shader rebuild is requested.
	void Renderer::setAlphaReference(int alphaReference)
	{
		Renderer::alphaReference = alphaReference;
	}

	// Splits the clipped polygon V[0..n-1] into a triangle fan around V[0].
	// Triangles rejected by cullMode (judged by the sign of their screen-space
	// area) are skipped; the rest are passed to the rasterizer, after the
	// current pixel shader's executable is stored in scanline.
	// Throws INTERNAL_ERROR for an unknown cull mode.
	void Renderer::renderPolygon(XVertex **V, int n, FVFFlags FVF)
	{
		scanline = pixelShader->executable();

		for(int k = 2; k < n; k++)
		{
			const XVertex &first = *V[0];
			const XVertex &prev = *V[k - 1];
			const XVertex &next = *V[k];

			const float area = 0.5f * ((first.x - prev.x) * (next.y - first.y) - (first.y - prev.y) * (next.x - first.x));

			bool culled = false;

			switch(cullMode)
			{
			case CULL_NONE:
				break;
			case CULL_CLOCKWISE:
				culled = (area <= 0);
				break;
			case CULL_COUNTERCLOCKWISE:
				culled = (area >= 0);
				break;
			default:
				throw INTERNAL_ERROR;
			}

			if(culled) continue;

			rasterizer->renderTriangle(&first, &prev, &next, FVF);
		}
	}

	// Writes value[0..3] into pixel shader float constant `index`. If no pixel
	// shader exists yet, one is built first through update().
	void Renderer::setPixelShaderConstantF(int index, const float value[4])
	{
		const bool needShader = (pixelShader == 0);

		if(needShader)
		{
			updatePixelShader = true;
			update(FVF);
		}

		pixelShader->setConstant(index, value);
	}

	void Renderer::setVertexShaderConstantF(int index, const float value[4])
	{
		if(!vertexShader)
		{
			updateVertexShader = true;
			update(FVF);
		}

		vertexShader->setFloatConstant(index, value);
	}

	void Renderer::setVertexShaderConstantI(int index, const int value[4])
	{
		if(!vertexShader)
		{
			updateVertexShader = true;
			update(FVF);
		}

		vertexShader->setIntegerConstant(index, value);
	}

	void Renderer::setVertexShaderConstantB(int index, bool boolean)
	{
		if(!vertexShader)
		{
			updateVertexShader = true;
			update(FVF);
		}

		vertexShader->setBooleanConstant(index, boolean);
	}

	// Replaces the model matrix M. update() recomputes PBVM lazily.
	void Renderer::setModelTransform(const Matrix &M)
	{
		updateModelMatrix = true;
		this->M = M;
	}

	// Replaces the view matrix V. update() recomputes PBV and PBVM lazily.
	void Renderer::setViewTransform(const Matrix &V)
	{
		updateViewMatrix = true;
		this->V = V;
	}

	// Replaces the base matrix B. update() recomputes PB, PBV and PBVM lazily.
	void Renderer::setBaseTransform(const Matrix &B)
	{
		updateBaseMatrix = true;
		this->B = B;
	}

	// Sets the field of view (angle converted by rad(); presumably degrees).
	// It is horizontal: update() derives the focal length from the viewport
	// width. The projection matrix is rebuilt on the next update().
	void Renderer::setFOV(float FOV)
	{
		const float halfAngle = FOV / 2;

		updateProjectionMatrix = true;
		tanFOV = tan(rad(halfAngle));
	}

	// Sets the near clip distance. The projection matrix is rebuilt on the
	// next update().
	void Renderer::setNearClip(float nearClip)
	{
		updateProjectionMatrix = true;
		this->nearClip = nearClip;
	}

	// Sets the far clip distance. The projection matrix is rebuilt on the
	// next update().
	void Renderer::setFarClip(float farClip)
	{
		updateProjectionMatrix = true;
		this->farClip = farClip;
	}

	void Renderer::setClipFlags(int flags)
	{
		if(!clipper) throw INTERNAL_ERROR;

		return clipper->setClipFlags(flags);
	}

	// Stores user clip plane `index` (0-5) as four coefficients. It reaches the
	// clipper on the next update(), if its flag is enabled.
	// Throws Error for an index outside 0-5.
	void Renderer::setClipPlane(int index, const float plane[4])
	{
		if(index < 0 || index >= 6) throw Error("User-defined clipping plane index out of range [0, 5]");

		for(int c = 0; c < 4; c++)
		{
			this->plane[index][c] = plane[c];
		}

		updateClipper = true;
	}

	const float *Renderer::getClipPlane(int index)
	{
		if(index < 0 || index >= 6) throw Error("User-defined clipping plane index out of range [0, 5]");

		return plane[index];
	}
}