/*
 * Copyright (c) 2015 Ruwen Hahn <palana@stunned.de>
 *                    John R. Bradley <jrb@turrettech.com>
 *                    Hugh Bailey "Jim" <obs.jim@gmail.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

/* Matrix the vertex position is multiplied by in VSDefault. */
uniform float4x4 ViewProj;
/* Texture fetched by load_at_image(). */
uniform texture2d image;

/* Texture fetched by load_at_prev(); presumably the frame preceding `image`
 * (supplied by the host application -- confirm against the caller). */
uniform texture2d previous_image;
/* Factor the interpolated UV is multiplied by in pixel_uv() to obtain integer
 * texel coordinates; presumably the frame size in texels. */
uniform float2 dimensions;
/* Passed as the `field` argument to the texel_at_* helpers by every pixel
 * shader in this file. */
uniform int field_order;
/* Read by the *_2x helpers: flips the field index, or, in texel_at_blend_2x,
 * selects which texture supplies the second row. */
uniform bool frame2;

/* Sampler state with linear filtering and clamped addressing on both axes.
 * Nothing in the visible part of this file samples through it; the helpers
 * here fetch texels with Load(). */
sampler_state textureSampler {
	Filter    = Linear;
	AddressU  = Clamp;
	AddressV  = Clamp;
};

/* Vertex layout used as both input and output of VSDefault and as the input
 * of every pixel shader in this file. */
struct VertData {
	float4 pos : POSITION;  /* vertex position */
	float2 uv  : TEXCOORD0; /* texture coordinate, converted to a texel by pixel_uv() */
};

/* Builds the integer coordinate handed to the texture Load() calls: the
 * texel position shifted by (x, y), with 0 as the third component. */
int3 select(int2 texel, int x, int y)
{
	int2 shifted = texel + int2(x, y);
	return int3(shifted, 0);
}

/* Fetches the texel at `texel` + (x, y) from previous_image. */
float4 load_at_prev(int2 texel, int x, int y)
{
	int3 coord = select(texel, x, y);
	return previous_image.Load(coord);
}

/* Fetches the texel at `texel` + (x, y) from image. */
float4 load_at_image(int2 texel, int x, int y)
{
	int3 coord = select(texel, x, y);
	return image.Load(coord);
}

/* Fetches the texel at `texel` + (x, y), choosing the source by `field`:
 * 0 reads from image, any other value reads from previous_image. */
float4 load_at(int2 texel, int x, int y, int field)
{
	if (field != 0)
		return load_at_prev(texel, x, y);

	return load_at_image(texel, x, y);
}

/*
 * YADIF_UPDATE(c, level): for channel `c`, when the candidate `score` is
 * lower than the current `spatial_score`, adopt it and recompute
 * `spatial_pred` for that channel as the average of the texel at
 * (+level, -1) and the texel at (-level, +1).
 *
 * NOTE: the expansion leaves the `if` block OPEN, so further statements can
 * be nested inside it; every use must be followed by a closing `}`.
 * It expects `score`, `spatial_score`, `spatial_pred`, `texel` and `field`
 * to be in scope at the point of expansion.
 */
#define YADIF_UPDATE(c, level) 	\
	if(score.c < spatial_score.c) \
	{ \
		spatial_score.c = score.c; \
		spatial_pred.c = (load_at(texel, level, -1, field) + load_at(texel, -level, 1, field)).c / 2; \

/*
 * YADIF_CHECK_ONE(level, c): in its own scope, scores the horizontal offset
 * `level` by summing the absolute differences between three texels on row +1
 * (x = level - 1, level, level + 1) and three texels on row -1
 * (x = -level - 1, -level, -level + 1), then applies YADIF_UPDATE to channel
 * `c` only. The `}` that follows YADIF_UPDATE closes the `if` block that
 * macro leaves open; the final `}` closes this macro's scope.
 */
#define YADIF_CHECK_ONE(level, c) \
{ \
	float4 score = abs(load_at(texel, -1 + level, 1, field) - load_at(texel, -1 - level, -1, field)) + \
	               abs(load_at(texel, level,      1, field) - load_at(texel, -level,     -1, field)) + \
	               abs(load_at(texel, 1 + level,  1, field) - load_at(texel, 1 - level,  -1, field)); \
	YADIF_UPDATE(c, level) } \
}

/*
 * YADIF_CHECK(level): computes the same three-tap score as YADIF_CHECK_ONE
 * once for all four channels. Then, for each of r, g, b and a: if `level`
 * improves that channel's score (YADIF_UPDATE), the offset `level * 2` is
 * tried as well via YADIF_CHECK_ONE, nested inside the still-open `if`; the
 * `}` ending each of those lines closes it. The `score` declared inside
 * YADIF_CHECK_ONE shadows the one declared here.
 */
#define YADIF_CHECK(level) \
{ \
	float4 score = abs(load_at(texel, -1 + level, 1, field) - load_at(texel, -1 - level, -1, field)) + \
	               abs(load_at(texel, level,      1, field) - load_at(texel, -level,     -1, field)) + \
	               abs(load_at(texel, 1 + level,  1, field) - load_at(texel, 1 - level,  -1, field)); \
	YADIF_UPDATE(r, level) YADIF_CHECK_ONE(level * 2, r) } \
	YADIF_UPDATE(g, level) YADIF_CHECK_ONE(level * 2, g) } \
	YADIF_UPDATE(b, level) YADIF_CHECK_ONE(level * 2, b) } \
	YADIF_UPDATE(a, level) YADIF_CHECK_ONE(level * 2, a) } \
}

/*
 * Computes the output colour for `texel` with the yadif-style combination of
 * spatial and temporal interpolation implemented below.
 *
 * texel: integer texel coordinate.
 * field: compared against the row parity to decide which rows are passed
 *        through; via load_at() it also selects the source texture
 *        (0 -> image, otherwise previous_image).
 * mode0: when true, the temporal averages two rows away (y + 2, y - 2) also
 *        take part in the bound on `diff`; when false only c, d and e do.
 *
 * Returns the fetched texel unchanged when the row parity equals `field`,
 * otherwise the spatial prediction clamped to d +/- diff per channel.
 */
float4 texel_at_yadif(int2 texel, int field, bool mode0)
{
	/* Rows whose parity matches `field` are passed through untouched. */
	if((texel.y % 2) == field)
		return load_at(texel, 0, 0, field);

	/* Average of the same position in previous_image and image. */
	#define YADIF_AVG(x_off, y_off) ((load_at_prev(texel, x_off, y_off) + load_at_image(texel, x_off, y_off))/2)
	/* c: row y + 1, e: row y - 1 (both from the texture selected by `field`);
	 * d: temporal average at this texel. */
	float4 c = load_at(texel, 0, 1, field),
	       d = YADIF_AVG(0, 0),
	       e = load_at(texel, 0, -1, field);

	/* Per-channel temporal differences; `diff` is the largest of the three. */
	float4 temporal_diff0 = (abs(load_at_prev(texel,  0, 0)      -     load_at_image(texel, 0, 0)))      / 2,
	       temporal_diff1 = (abs(load_at_prev(texel,  0, 1) - c) + abs(load_at_prev(texel,  0, -1) - e)) / 2,
	       temporal_diff2 = (abs(load_at_image(texel, 0, 1) - c) + abs(load_at_image(texel, 0, -1) - e)) / 2,
	       diff = max(temporal_diff0, max(temporal_diff1, temporal_diff2));

	/* Initial spatial guess: the average of c and e, scored by the three-tap
	 * difference at offset 0 minus 1.
	 * NOTE(review): the `- 1` looks like a bias carried over from an integer
	 * implementation -- confirm it is intended for float colour values. */
	float4 spatial_pred = (c + e) / 2,
	       spatial_score = abs(load_at(texel, -1, 1, field) - load_at(texel, -1, -1, field)) +
	                       abs(c - e) +
	                       abs(load_at(texel, 1,  1, field) - load_at(texel, 1,  -1, field)) - 1;

	/* Try horizontal offsets -1 and +1 (and -2 / +2 where the first step
	 * improved the score); each may lower spatial_score and replace
	 * spatial_pred per channel. */
	YADIF_CHECK(-1)
	YADIF_CHECK(1)

	/* Possibly enlarge `diff` using the differences between d and its
	 * vertical neighbours. */
	if (mode0) {
		/* b, f: temporal averages at rows y + 2 and y - 2. */
		float4 b = YADIF_AVG(0, 2),
		       f = YADIF_AVG(0, -2);

		float4 max_ = max(d - e, max(d - c, min(b - c, f - e))),
		       min_ = min(d - e, min(d - c, max(b - c, f - e)));

		diff = max(diff, max(min_, -max_));
	} else {
		diff = max(diff, max(min(d - e, d - c), -max(d - e, d - c)));
	}

/* YADIF_SPATIAL(c): limits channel `c` of spatial_pred to the range
 * [d - diff, d + diff]. */
#define YADIF_SPATIAL(c) \
{ \
	if(spatial_pred.c > d.c + diff.c) \
		spatial_pred.c = d.c + diff.c; \
	else if(spatial_pred.c < d.c - diff.c) \
		spatial_pred.c = d.c - diff.c; \
}

	YADIF_SPATIAL(r)
	YADIF_SPATIAL(g)
	YADIF_SPATIAL(b)
	YADIF_SPATIAL(a)

	return spatial_pred;
}

/* Variant of texel_at_yadif() that uses `field` as given while frame2 is
 * false and the opposite field (1 - field) while frame2 is true. */
float4 texel_at_yadif_2x(int2 texel, int field, bool mode0)
{
	int active_field = field;
	if (frame2)
		active_field = 1 - field;

	return texel_at_yadif(texel, active_field, mode0);
}

/* Snaps y to the even row of its two-row pair (integer divide by 2, then
 * multiply by 2) and returns the texel of `image` `field` rows after it. */
float4 texel_at_discard(int2 texel, int field)
{
	int2 pair_start = int2(texel.x, (texel.y / 2) * 2);
	return load_at_image(pair_start, 0, field);
}

/* Variant of texel_at_discard() that uses the opposite field (1 - field)
 * while frame2 is false and `field` as given while frame2 is true. */
float4 texel_at_discard_2x(int2 texel, int field)
{
	int active_field = 1 - field;
	if (frame2)
		active_field = field;

	return texel_at_discard(texel, active_field);
}

/* Averages the texel of `image` at row y with the one at row y + 1.
 * `field` is accepted but not used. */
float4 texel_at_blend(int2 texel, int field)
{
	float4 this_row = load_at_image(texel, 0, 0);
	float4 next_row = load_at_image(texel, 0, 1);

	return (this_row + next_row) / 2;
}

/* Averages the texel of `image` at row y with the texel at row y + 1, taking
 * that second row from previous_image while frame2 is false and from image
 * while frame2 is true. `field` is accepted but not used. */
float4 texel_at_blend_2x(int2 texel, int field)
{
	float4 this_row = load_at_image(texel, 0, 0);
	float4 next_row;

	if (frame2)
		next_row = load_at_image(texel, 0, 1);
	else
		next_row = load_at_prev(texel, 0, 1);

	return (this_row + next_row) / 2;
}

/* Rows whose parity equals `field` are returned as-is from `image`; every
 * other row becomes the average of the rows at y - 1 and y + 1. */
float4 texel_at_linear(int2 texel, int field)
{
	if ((texel.y % 2) != field) {
		float4 row_before = load_at_image(texel, 0, -1);
		float4 row_after  = load_at_image(texel, 0, 1);
		return (row_before + row_after) / 2;
	}

	return load_at_image(texel, 0, 0);
}

/* Variant of texel_at_linear() that uses the opposite field (1 - field)
 * while frame2 is false and `field` as given while frame2 is true. */
float4 texel_at_linear_2x(int2 texel, int field)
{
	int active_field = 1 - field;
	if (frame2)
		active_field = field;

	return texel_at_linear(texel, active_field);
}

/* Averages the yadif result (with mode0 enabled) and the discard result for
 * the same texel and field. */
float4 texel_at_yadif_discard(int2 texel, int field)
{
	float4 yadif_px   = texel_at_yadif(texel, field, true);
	float4 discard_px = texel_at_discard(texel, field);

	return (yadif_px + discard_px) / 2;
}

/* Variant of texel_at_yadif_discard() that uses `field` as given while
 * frame2 is false and the opposite field (1 - field) while frame2 is true.
 *
 * NOTE(review): this flip matches texel_at_yadif_2x but is the opposite of
 * the one in texel_at_discard_2x -- confirm that is intended for the
 * discard half of the average. */
float4 texel_at_yadif_discard_2x(int2 texel, int field)
{
	int active_field = field;
	if (frame2)
		active_field = 1 - field;

	return texel_at_yadif_discard(texel, active_field);
}

/* Converts a UV coordinate to an integer texel coordinate by scaling it with
 * `dimensions` and converting the result to int2. */
int2 pixel_uv(float2 uv)
{
	float2 scaled = uv * dimensions;
	return int2(scaled);
}

/* Pixel shader: texel_at_yadif() with mode0 enabled. */
float4 PSYadifMode0RGBA(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_yadif(texel, field_order, true);
}

/* Pixel shader: texel_at_yadif_2x() with mode0 enabled. */
float4 PSYadifMode0RGBA_2x(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_yadif_2x(texel, field_order, true);
}

/* Pixel shader: texel_at_yadif() with mode0 disabled. */
float4 PSYadifMode2RGBA(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_yadif(texel, field_order, false);
}

/* Pixel shader: texel_at_yadif_2x() with mode0 disabled. */
float4 PSYadifMode2RGBA_2x(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_yadif_2x(texel, field_order, false);
}

/* Pixel shader: texel_at_yadif_discard() for the interpolated texel. */
float4 PSYadifDiscardRGBA(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_yadif_discard(texel, field_order);
}

/* Pixel shader: texel_at_yadif_discard_2x() for the interpolated texel. */
float4 PSYadifDiscardRGBA_2x(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_yadif_discard_2x(texel, field_order);
}

/* Pixel shader: texel_at_linear() for the interpolated texel. */
float4 PSLinearRGBA(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_linear(texel, field_order);
}

/* Pixel shader: texel_at_linear_2x() for the interpolated texel. */
float4 PSLinearRGBA_2x(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_linear_2x(texel, field_order);
}

/* Pixel shader: texel_at_discard() for the interpolated texel. */
float4 PSDiscardRGBA(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_discard(texel, field_order);
}

/* Pixel shader: texel_at_discard_2x() for the interpolated texel. */
float4 PSDiscardRGBA_2x(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_discard_2x(texel, field_order);
}

/* Pixel shader: texel_at_blend() for the interpolated texel. */
float4 PSBlendRGBA(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_blend(texel, field_order);
}

/* Pixel shader: texel_at_blend_2x() for the interpolated texel. */
float4 PSBlendRGBA_2x(VertData v_in) : TARGET
{
	int2 texel = pixel_uv(v_in.uv);
	return texel_at_blend_2x(texel, field_order);
}

/* Vertex shader: multiplies the vertex position (with w forced to 1.0) by
 * ViewProj and passes the texture coordinate through unchanged. */
VertData VSDefault(VertData v_in)
{
	VertData result;
	result.uv  = v_in.uv;
	result.pos = mul(float4(v_in.pos.xyz, 1.0), ViewProj);
	return result;
}

/*
 * TECHNIQUE(rgba_ps): declares the technique `Draw` with a single pass that
 * pairs the VSDefault vertex shader with the pixel shader named by `rgba_ps`.
 * Nothing in the visible part of this file expands it; presumably the files
 * that include this one do -- confirm against the includers.
 */
#define TECHNIQUE(rgba_ps) \
technique Draw \
{ \
	pass \
	{ \
		vertex_shader = VSDefault(v_in); \
		pixel_shader  = rgba_ps(v_in); \
	} \
}