#include "postgres.h"

#include <math.h>

#include "fmgr.h"
#include "windowapi.h"

/* PostgreSQL "magic block": marks this shared library as a loadable module. */
PG_MODULE_MAGIC;

/* Register movavg as a version-1 (PG_FUNCTION_ARGS) SQL-callable function. */
PG_FUNCTION_INFO_V1(movavg);

Datum movavg(PG_FUNCTION_ARGS);

/*
 * State carried between rows of one partition by the stable-range
 * (incremental) path of movavg.  It lives in the memory returned by
 * WinGetPartitionLocalMemory(); all positions are absolute row numbers
 * within the partition.
 */
typedef struct movavg_data
{
	/* position of the row whose value is removed on the next slide */
	int64		exiting;
	/* position of the row most recently added to the running sum */
	int64		entering;
	/* running sum of the non-NULL values currently in the frame */
	float8		svalue;
	/* number of non-NULL values included in svalue */
	int			count;
} movavg_data;

/*
 * movavg_sum_frame -- sum window argument 0 over partition rows [lo, hi].
 *
 * Positions below 0 are skipped.  The scan stops at the first position
 * that lies past the end of the partition, because every later position is
 * out as well.  NULL values are ignored.  On return *sum holds the total of
 * the non-NULL values and *count how many of them there were.
 *
 * Rows are fetched without moving the WindowObject's mark.  The caller must
 * not ask for rows (after clipping lo to 0) before the current mark
 * position, since the window API rejects such fetches.
 */
static void
movavg_sum_frame(WindowObject winobj, int64 lo, int64 hi,
				 float8 *sum, int *count)
{
	float8		total = 0.0;
	int			n = 0;
	int64		pos;

	for (pos = (lo < 0) ? 0 : lo; pos <= hi; pos++)
	{
		bool		isnull;
		bool		isout;
		Datum		value;

		value = WinGetFuncArgInPartition(winobj, 0, pos,
										 WINDOW_SEEK_HEAD, false,
										 &isnull, &isout);
		if (isout)
			break;				/* ran off the end of the partition */
		if (!isnull)
		{
			total += DatumGetFloat8(value);
			n++;
		}
	}

	*sum = total;
	*count = n;
}

/*
 * movavg -- centered moving average window function.
 *
 * Arguments:
 *	 0: the value to average, read with DatumGetFloat8 (presumably declared
 *		float8 at the SQL level -- confirm against the CREATE FUNCTION).
 *	 1: int4 half-width <range>.  The frame of the current row is partition
 *		rows [cur - range, cur + range], clipped to the partition.
 *
 * NULL values are skipped.  Returns the average of the non-NULL values in
 * the frame, or NULL when there are none.  Raises ERROR when <range> is NULL
 * or negative.
 *
 * If <range> is a stable expression (get_fn_expr_arg_stable), the sum is
 * maintained incrementally in partition-local memory.  Each row subtracts
 * the value leaving the frame and adds the one entering it.  Otherwise the
 * whole frame is re-scanned for every row.
 */
Datum
movavg(PG_FUNCTION_ARGS)
{
	WindowObject	winobj = PG_WINDOW_OBJECT();
	bool			isnull, isout;
	int64			cur = WinGetCurrentPosition(winobj);
	int32			range;
	Datum			value;
	bool			const_range;
	movavg_data	   *sdata;

	/* current row +/- <range> rows are the targets */
	range = DatumGetInt32(WinGetFuncArgCurrent(winobj, 1, &isnull));
	if (isnull || range < 0)
	{
		elog(ERROR, "invalid range");
	}
	/* whether <range> is a Const/stable value for the whole scan */
	const_range = get_fn_expr_arg_stable(fcinfo->flinfo, 1);

	/*
	 * With a variable range, compute the exact average of the frame on every
	 * row.  Only a stable range allows the subtract/add optimization.
	 */
	if (!const_range)
	{
		float8	svalue;
		int		count;

		movavg_sum_frame(winobj, cur - range, cur + range, &svalue, &count);

		/*
		 * Deliberately no WinSetMarkPosition() here.  The range can change
		 * from row to row, so cur - range is not monotonic (and is negative
		 * on early rows).  The mark may never move backward, and a later row
		 * with a larger range could need rows before any mark set now.
		 * Leaving the mark alone keeps the whole partition readable.
		 */

		if (count > 0)
			PG_RETURN_FLOAT8(svalue / (float8) count);
		PG_RETURN_NULL();
	}

	/* optimize for the single row case */
	if (range == 0)
	{
		value = WinGetFuncArgCurrent(winobj, 0, &isnull);
		if (isnull)
			PG_RETURN_NULL();
		PG_RETURN_DATUM(value);
	}

	sdata = (movavg_data *)
		WinGetPartitionLocalMemory(winobj, sizeof(movavg_data));

	/*
	 * Freshly allocated partition memory is zeroed, so entering == 0 means
	 * this is the first row of the partition.  After initialization,
	 * entering = cur + range >= 1 (range > 0 here), so the test cannot fire
	 * again within the same partition.
	 */
	if (sdata->entering == 0)
	{
		movavg_sum_frame(winobj, cur - range, cur + range,
						 &sdata->svalue, &sdata->count);
		sdata->exiting = cur - range;
		sdata->entering = cur + range;
	}
	else
	{
		/*
		 * Slide the frame by one row.  Fetching the exiting row also sets
		 * the mark there (set_mark = true), so rows before it can be
		 * released; every later fetch is at or after that position.
		 */
		value = WinGetFuncArgInPartition(winobj, 0, sdata->exiting,
										 WINDOW_SEEK_HEAD, true,
										 &isnull, &isout);
		if (!isnull && !isout)
		{
			sdata->svalue -= DatumGetFloat8(value);
			sdata->count--;
		}
		sdata->exiting++;

		sdata->entering++;
		value = WinGetFuncArgInPartition(winobj, 0, sdata->entering,
										 WINDOW_SEEK_HEAD, false,
										 &isnull, &isout);
		if (!isnull && !isout)
		{
			sdata->svalue += DatumGetFloat8(value);
			sdata->count++;
		}

		/*
		 * Subtraction cannot undo an infinity or NaN that has left the
		 * frame (inf - inf is NaN).  While the running sum is not finite,
		 * recompute it exactly over the current frame [exiting, entering]
		 * so results agree with the variable-range path.  This costs
		 * O(range) per row, but only while a non-finite value is involved.
		 */
		if (isnan(sdata->svalue) || isinf(sdata->svalue))
			movavg_sum_frame(winobj, sdata->exiting, sdata->entering,
							 &sdata->svalue, &sdata->count);
	}

	if (sdata->count > 0)
		PG_RETURN_FLOAT8(sdata->svalue / (float8) sdata->count);
	PG_RETURN_NULL();
}
