/*
 * lanczos sharper
 * note - this shader is adapted from the GPL bsnes shader, very good stuff
 * there.
 */

// View-projection matrix applied to vertex positions in VSDefault.
uniform float4x4 ViewProj;
// Texture that every Sample()/Load() call in this file reads from.
uniform texture2d image;
// Multiplied into the incoming uv in VSDefault and used for texel-space
// clamping in DrawLanczos; presumably the source size in texels -- confirm
// against the host code that sets it.
uniform float2 base_dimension;
// Multiplied into texel-space positions to get texture coordinates;
// presumably 1.0 / base_dimension -- confirm against the host code.
uniform float2 base_dimension_i;
// Blend factor passed as `a` to AspectUndistortX: 1.0 (the default) leaves
// the X coordinate unchanged, lower values mix in more of the x^5 curve.
uniform float undistort_factor = 1.0;

// Sampler used by every image.Sample() call in this file: both axes use
// Clamp addressing and fetches are linearly filtered. DrawLanczos relies on
// the linear filter to merge two adjacent taps into a single fetch.
sampler_state textureSampler
{
	AddressU  = Clamp;
	AddressV  = Clamp;
	Filter    = Linear;
};

// Vertex shader input: vertex position and its texture coordinate.
struct VertData {
	float4 pos : POSITION;
	float2 uv  : TEXCOORD0;
};

// Vertex shader output: uv already multiplied by base_dimension (see
// VSDefault) plus the position transformed by ViewProj.
struct VertOut {
	float2 uv : TEXCOORD0;
	float4 pos : POSITION;
};

// Pixel shader input: only the interpolated uv written by VSDefault.
struct FragData {
	float2 uv : TEXCOORD0;
};

/*
 * Vertex shader shared by all techniques.
 * Transforms the vertex position by ViewProj (with w forced to 1.0) and
 * scales the incoming uv by base_dimension before handing it to the pixel
 * shader.
 */
VertOut VSDefault(VertData v_in)
{
	float4 local_pos = float4(v_in.pos.xyz, 1.0);

	VertOut result;
	result.pos = mul(local_pos, ViewProj);
	result.uv  = v_in.uv * base_dimension;
	return result;
}

/*
 * Lanczos-3 kernel: sinc(x) * sinc(x / 3), written as
 * 3 * sin(pi*x) * sin(pi*x / 3) / (pi*x)^2.
 *
 * x: signed distance from the sample position to a tap (callers in this
 *    file pass texel units).
 *
 * Returns the unnormalized tap weight. The closed form is 0/0 at x == 0;
 * the kernel's limit there is 1.0, so that value is returned explicitly
 * instead of letting a NaN reach the caller.
 */
float weight(float x)
{
	float x_pi = x * 3.141592654;
	float denom = x_pi * x_pi;
	float numer = 3.0 * sin(x_pi) * sin(x_pi * (1.0 / 3.0));
	// Both sides of ?: may be evaluated; only the selected value is used,
	// so the 0/0 on the discarded side is harmless.
	return (denom == 0.0) ? 1.0 : (numer / denom);
}

/*
 * Evaluates the six kernel weights for one axis and normalizes them so they
 * sum to 1.0.
 *
 * f_neg:  offset passed straight to weight() for the center tap; the other
 *         taps are evaluated at f_neg - 2, f_neg - 1, f_neg + 1, f_neg + 2
 *         and f_neg + 3.
 * tap012: receives normalized weights for taps 0, 1, 2.
 * tap345: receives normalized weights for taps 3, 4, 5.
 */
void weight6(float f_neg, out float3 tap012, out float3 tap345)
{
	// Clamp the center tap to 1.0; this also replaces a NaN result, should
	// weight() produce one for f_neg == 0, with 1.0.
	float center = min(1.0, weight(f_neg));

	tap012 = float3(weight(f_neg - 2.0), weight(f_neg - 1.0), center);
	tap345 = float3(weight(f_neg + 1.0), weight(f_neg + 2.0), weight(f_neg + 3.0));

	// Scale all six taps by the reciprocal of their total.
	float total = tap012.x + tap012.y + tap012.z + tap345.x + tap345.y + tap345.z;
	float scale = 1.0 / total;
	tap012 *= scale;
	tap345 *= scale;
}

/*
 * Blends a quintic curve with the identity:
 *   (1 - a) * x^5 + a * x
 * a == 1.0 returns x unchanged; a == 0.0 returns x^5. A higher power keeps
 * the curve close to linear over a longer stretch around x == 0.
 */
float AspectUndistortX(float x, float a)
{
	float x_pow5 = x * x * x * x * x;
	float curved = (1.0 - a) * x_pow5;
	float straight = a * x;
	return curved + straight;
}

/*
 * Applies AspectUndistortX to a texture coordinate: u is remapped from
 * [0, 1] to [-1, 1], warped using undistort_factor, then mapped back.
 */
float AspectUndistortU(float u)
{
	float centered = (u - 0.5) * 2.0;
	float warped = AspectUndistortX(centered, undistort_factor);
	return warped * 0.5 + 0.5;
}

/*
 * Builds a texture coordinate whose X component has been warped by
 * AspectUndistortU; the Y component is passed through untouched.
 */
float2 undistort_coord(float xpos, float ypos)
{
	float warped_x = AspectUndistortU(xpos);
	return float2(warped_x, ypos);
}

/*
 * Samples `image` through textureSampler at (xpos, ypos) after the X
 * coordinate has been warped by undistort_coord.
 */
float4 undistort_pixel(float xpos, float ypos)
{
	float2 coord = undistort_coord(xpos, ypos);
	return image.Sample(textureSampler, coord);
}

/*
 * Weighted sum of six undistorted samples taken along one row.
 *
 * xpos012 / xpos345:     X coordinates of taps 0-2 and 3-5.
 * ypos:                  Y coordinate shared by all six taps.
 * rowtap012 / rowtap345: weights applied to taps 0-2 and 3-5.
 */
float4 undistort_line(float3 xpos012, float3 xpos345, float ypos, float3 rowtap012,
		float3 rowtap345)
{
	// Accumulate in tap order 0..5.
	float4 acc = undistort_pixel(xpos012.x, ypos) * rowtap012.x;
	acc += undistort_pixel(xpos012.y, ypos) * rowtap012.y;
	acc += undistort_pixel(xpos012.z, ypos) * rowtap012.z;
	acc += undistort_pixel(xpos345.x, ypos) * rowtap345.x;
	acc += undistort_pixel(xpos345.y, ypos) * rowtap345.y;
	acc += undistort_pixel(xpos345.z, ypos) * rowtap345.z;
	return acc;
}

/*
 * Computes the Lanczos-filtered color for one output fragment.
 *
 * f_in.uv arrives already multiplied by base_dimension (see VSDefault). A
 * six-tap footprint is built on each axis: tap 2 sits at
 * floor(pos - 0.5) + 0.5, taps 0-1 lie one and two steps before it and
 * taps 3-5 one to three steps after it.
 *
 * undistort == true:  all 6x6 taps are evaluated through undistort_line(),
 *                     which warps each tap's X coordinate.
 * undistort == false: a 5x5 fetch pattern is used in which taps 2 and 3 of
 *                     each axis are merged into one linearly filtered
 *                     sample.
 *
 * Returns the weighted sum of the fetched texels.
 */
float4 DrawLanczos(FragData f_in, bool undistort)
{
	float2 pos = f_in.uv;
	// Position of tap 2: the half-integer position at or just before pos.
	float2 pos2 = floor(pos - 0.5) + 0.5;
	// Offset from pos back to tap 2; zero or negative on each axis.
	float2 f_neg = pos2 - pos;

	// Normalized horizontal weights.
	float3 rowtap012, rowtap345;
	weight6(f_neg.x, rowtap012, rowtap345);

	// Normalized vertical weights.
	float3 coltap012, coltap345;
	weight6(f_neg.y, coltap012, coltap345);

	// Texture coordinates of the six taps on each axis, stepping by
	// base_dimension_i from tap 2.
	float2 uv2 = pos2 * base_dimension_i;
	float2 uv1 = uv2 - base_dimension_i;
	float2 uv0 = uv1 - base_dimension_i;
	float2 uv3 = uv2 + base_dimension_i;
	float2 uv4 = uv3 + base_dimension_i;
	float2 uv5 = uv4 + base_dimension_i;

	if (undistort) {
		// Full 6x6 evaluation: each row is summed with the horizontal
		// weights, then rows are combined with the vertical weights.
		float3 xpos012 = float3(uv0.x, uv1.x, uv2.x);
		float3 xpos345 = float3(uv3.x, uv4.x, uv5.x);
		return undistort_line(xpos012, xpos345, uv0.y, rowtap012, rowtap345) * coltap012.x +
		       undistort_line(xpos012, xpos345, uv1.y, rowtap012, rowtap345) * coltap012.y +
		       undistort_line(xpos012, xpos345, uv2.y, rowtap012, rowtap345) * coltap012.z +
		       undistort_line(xpos012, xpos345, uv3.y, rowtap012, rowtap345) * coltap345.x +
		       undistort_line(xpos012, xpos345, uv4.y, rowtap012, rowtap345) * coltap345.y +
		       undistort_line(xpos012, xpos345, uv5.y, rowtap012, rowtap345) * coltap345.z;
	}

	// Merge horizontal taps 2 and 3 into one fetch: sampling between them at
	// a position proportional to tap 3's share of the combined weight lets
	// the linear filter blend the two texels; the result is then scaled by
	// the combined weight.
	float u_weight_sum = rowtap012.z + rowtap345.x;
	float u_middle_offset = rowtap345.x * base_dimension_i.x / u_weight_sum;
	float u_middle = uv2.x + u_middle_offset;

	// Same merge for vertical taps 2 and 3.
	float v_weight_sum = coltap012.z + coltap345.x;
	float v_middle_offset = coltap345.x * base_dimension_i.y / v_weight_sum;
	float v_middle = uv2.y + v_middle_offset;

	// Load() takes integer coordinates and no sampler, so the outer taps are
	// clamped by hand: taps 0 and 1 from below at 0.5, taps 4 and 5 from
	// above at base_dimension - 0.5.
	float2 coord_limit = base_dimension - 0.5;
	float2 coord0_f = max(uv0 * base_dimension, 0.5);
	float2 coord1_f = max(uv1 * base_dimension, 0.5);
	float2 coord4_f = min(uv4 * base_dimension, coord_limit);
	float2 coord5_f = min(uv5 * base_dimension, coord_limit);

	int2 coord0 = int2(coord0_f);
	int2 coord1 = int2(coord1_f);
	int2 coord4 = int2(coord4_f);
	int2 coord5 = int2(coord5_f);

	// Row at vertical tap 0: four direct loads plus the merged 2/3 sample.
	float4 row0 = image.Load(int3(coord0, 0)) * rowtap012.x;
	row0 += image.Load(int3(coord1.x, coord0.y, 0)) * rowtap012.y;
	row0 += image.Sample(textureSampler, float2(u_middle, uv0.y)) * u_weight_sum;
	row0 += image.Load(int3(coord4.x, coord0.y, 0)) * rowtap345.y;
	row0 += image.Load(int3(coord5.x, coord0.y, 0)) * rowtap345.z;
	float4 total = row0 * coltap012.x;

	// Row at vertical tap 1.
	float4 row1 = image.Load(int3(coord0.x, coord1.y, 0)) * rowtap012.x;
	row1 += image.Load(int3(coord1.x, coord1.y, 0)) * rowtap012.y;
	row1 += image.Sample(textureSampler, float2(u_middle, uv1.y)) * u_weight_sum;
	row1 += image.Load(int3(coord4.x, coord1.y, 0)) * rowtap345.y;
	row1 += image.Load(int3(coord5.x, coord1.y, 0)) * rowtap345.z;
	total += row1 * coltap012.y;

	// Merged rows 2 and 3: every fetch goes through the sampler at v_middle
	// so the vertical blend comes from the linear filter; the center fetch
	// blends on both axes.
	float4 row23 = image.Sample(textureSampler, float2(uv0.x, v_middle)) * rowtap012.x;
	row23 += image.Sample(textureSampler, float2(uv1.x, v_middle)) * rowtap012.y;
	row23 += image.Sample(textureSampler, float2(u_middle, v_middle)) * u_weight_sum;
	row23 += image.Sample(textureSampler, float2(uv4.x, v_middle)) * rowtap345.y;
	row23 += image.Sample(textureSampler, float2(uv5.x, v_middle)) * rowtap345.z;
	total += row23 * v_weight_sum;

	// Row at vertical tap 4.
	float4 row4 = image.Load(int3(coord0.x, coord4.y, 0)) * rowtap012.x;
	row4 += image.Load(int3(coord1.x, coord4.y, 0)) * rowtap012.y;
	row4 += image.Sample(textureSampler, float2(u_middle, uv4.y)) * u_weight_sum;
	row4 += image.Load(int3(coord4.x, coord4.y, 0)) * rowtap345.y;
	row4 += image.Load(int3(coord5.x, coord4.y, 0)) * rowtap345.z;
	total += row4 * coltap345.y;

	// Row at vertical tap 5.
	float4 row5 = image.Load(int3(coord0.x, coord5.y, 0)) * rowtap012.x;
	row5 += image.Load(int3(coord1.x, coord5.y, 0)) * rowtap012.y;
	row5 += image.Sample(textureSampler, float2(u_middle, uv5.y)) * u_weight_sum;
	row5 += image.Load(int3(coord4.x, coord5.y, 0)) * rowtap345.y;
	row5 += image.Load(int3(coord5, 0)) * rowtap345.z;
	total += row5 * coltap345.z;

	return total;
}

/*
 * Pixel shader entry point: returns the DrawLanczos result unchanged.
 * `undistort` is forwarded as-is; the techniques bind it to a constant.
 */
float4 PSDrawLanczosRGBA(FragData f_in, bool undistort) : TARGET
{
	float4 color = DrawLanczos(f_in, undistort);
	return color;
}

/*
 * Pixel shader entry point: filters without undistortion, then divides the
 * color channels by alpha. When alpha is not greater than zero the color
 * channels are multiplied by 0.0 instead, avoiding a division by zero.
 * Alpha itself is returned unchanged.
 */
float4 PSDrawLanczosRGBADivide(FragData f_in) : TARGET
{
	float4 filtered = DrawLanczos(f_in, false);

	float inv_alpha = 0.0;
	if (filtered.a > 0.0)
		inv_alpha = 1.0 / filtered.a;

	return float4(filtered.rgb * inv_alpha, filtered.a);
}

// Plain Lanczos scaling: undistort is bound to false.
technique Draw
{
	pass
	{
		vertex_shader = VSDefault(v_in);
		pixel_shader  = PSDrawLanczosRGBA(f_in, false);
	}
}

// Lanczos scaling followed by dividing the color channels by alpha
// (see PSDrawLanczosRGBADivide).
technique DrawAlphaDivide
{
	pass
	{
		vertex_shader = VSDefault(v_in);
		pixel_shader  = PSDrawLanczosRGBADivide(f_in);
	}
}

// Lanczos scaling with the X coordinate of every tap warped by
// undistort_factor: undistort is bound to true.
technique DrawUndistort
{
	pass
	{
		vertex_shader = VSDefault(v_in);
		pixel_shader  = PSDrawLanczosRGBA(f_in, true);
	}
}