/*-------------------------------------------------------------------------
 *
 * pg_stat_statements.c
 *
 *-------------------------------------------------------------------------
 */
#include "postgres.h"

#include "access/hash.h"
#include "funcapi.h"
#include "mb/pg_wchar.h"
#include "miscadmin.h"
#include "optimizer/planner.h"
#include "portability/instr_time.h"
#include "storage/bufmgr.h"
#include "storage/spin.h"
#include "tcop/tcopprot.h"
#include "utils/builtins.h"
#include "tcop/pquery.h"

extern PGDLLIMPORT bool	log_executor_stats;
extern PGDLLIMPORT bool	log_statement_stats;
extern PGDLLIMPORT int	pgstat_track_activity_query_size;
extern PGDLLIMPORT long int ReadBufferCount;
extern PGDLLIMPORT long int ReadLocalBufferCount;
extern PGDLLIMPORT long int BufferHitCount;
extern PGDLLIMPORT long int LocalBufferHitCount;
extern PGDLLIMPORT long int BufferFlushCount;
extern PGDLLIMPORT long int LocalBufferFlushCount;

/*
 * FIXME: Multi-statement query strings and RULEs are currently accounted
 * for incorrectly. They produce multiple plans and executions that all
 * share the same top-level query text, so their statistics are summed
 * under that query text instead of being attributed to the individual
 * queries or to the queries in the RULE definitions.
 */

/* #define PGSS_DEBUG */

PG_MODULE_MAGIC;

/*
 * Build the full name of one of our custom GUCs by prefixing "statistics.".
 * This relies on string-literal concatenation, so `name` must be a string
 * literal.
 */
#define PGSS_GUC(name)		("statistics." name)

/* Value given to max_statements when it is still 0 after GUC registration. */
#define MAX_STATEMENTS_DEFAULT	1000
/*
 * Bytes reserved for each query-text chunk: the activity query size rounded
 * up with MAXALIGN.
 */
#define CHUNK_SIZE				MAXALIGN(pgstat_track_activity_query_size)

/*
 * Weights added to Entry.usage; entry_dealloc() evicts the entries with the
 * smallest usage.
 *
 * XXX: Should USAGE_PLANNED reflect plan cost?
 * XXX: Should USAGE_EXECUTED reflect execution time and/or buffer usage?
 */
#define USAGE_INITIAL			(1.0)	/* including initial planning */
#define USAGE_PLANNED			(1.0)	/* usage per planning */
#define USAGE_EXECUTED			(1.0)	/* usage per execution */
#define USAGE_DECREASE_FACTOR	(0.99)	/* decreased every entry_dealloc */
#define USAGE_DEALLOC_PERCENT	5		/* free this % of entries at once */

/*
 * Chunk buffer for query text.
 *
 * A chunk is either in use, holding a NUL-terminated query string, or on the
 * free list, in which case its storage holds the link to the next free chunk.
 * Chunks are laid out back to back in pgStatStmt.chunks with a stride of
 * pgStatStmt.chunk_size bytes (see chunk_reset), not sizeof(Chunk).
 */
typedef union Chunk Chunk;
union Chunk
{
	char	query[1];	/* query text. true size is [CHUNK_SIZE] */
	Chunk  *next;		/* next free chunk (valid only while on free list) */
};

/*
 * Statistics per statement
 *
 * One Entry is stored per distinct tag in the shared hash table. The tag is
 * the hash key (pgss_init sets keysize to sizeof(uint32)), so entries whose
 * (userid, dbid, query) hash to the same tag share a single Entry.
 * NOTE(review): tag is presumably required to stay the first member because
 * it is the hash key -- confirm against dynahash's requirements.
 */
typedef struct Entry
{
	uint32		tag;		/* hash key: pgss_tag(userid, dbid, query) */
	Oid			userid;		/* user oid */
	Oid			dbid;		/* database oid */
	Chunk	   *chunk;		/* query text (possibly truncated) */
	int64		planned;	/* # of planned */
	int64		calls;		/* # of executed */
	double		total_cost;	/* total plan cost */
	instr_time	total_time;	/* total execution time */
	int64		gets;		/* # of buffer gets during execution */
	int64		reads;		/* # of buffer reads during execution */
	int64		writes;		/* # of buffer writes during execution */
	double		usage;		/* usage factor, used to pick eviction victims */
	slock_t		mutex;		/* protect above fields */
} Entry;

/*
 * Global shared state
 *
 * Lives in shared memory (allocated by pgss_init). The chunks array is
 * declared with one element but really holds max_statements chunks of
 * chunk_size bytes each; pgss_memsize() sizes the allocation accordingly.
 */
typedef struct pgStatStmt
{
	LWLockId	lock;				/* protect fields and hash */
	int			num_statements;		/* # of entries in hash */
	int			chunk_size;			/* max query length in bytes */
	Chunk	   *free_chunks;		/* single linked list of free chunks */
	Chunk		chunks[1];			/* chunk buffers (variable length) */
} pgStatStmt;

/*---- GUC variables ----*/

/* "statistics.track_statements": collect statistics only while true */
static bool			track_statements = true;
/*
 * "statistics.max_statements": capacity of the hash table. 0 means "not
 * configured yet"; pgss_memsize() then substitutes MAX_STATEMENTS_DEFAULT.
 */
static int			max_statements = 0;

/*---- Local variables ----*/

/* Hooks that were active before ours; set by _PG_init, never NULL afterwards */
static planner_hook_type		prev_planner_hook = NULL;
static ExecutorRun_hook_type	prev_ExecutorRun_hook = NULL;
/* Shared state and hash table; assigned by pgss_init() */
static pgStatStmt			   *pgss;
static HTAB					   *pgss_hash;

/*
 * To collect only top-level queries, pgss_toplevel is set to false
 * during planning or executing some queries.
 */
static bool	pgss_toplevel = true;

/*---- Function declarations ----*/

void	_PG_init(void);
void	_PG_fini(void);
Datum	pg_stat_statements_reset(PG_FUNCTION_ARGS);
Datum	pg_stat_statements(PG_FUNCTION_ARGS);

PG_FUNCTION_INFO_V1(pg_stat_statements_reset);
PG_FUNCTION_INFO_V1(pg_stat_statements);

static uint32 pgss_tag(Oid userid, Oid dbid, const char *query);
static PlannedStmt *pgss_planner(Query *parse, int cursorOptions,
                                 ParamListInfo boundParams);
static TupleTableSlot *pgss_ExecutorRun(QueryDesc *queryDesc,
                                        ScanDirection direction,
                                        long count);
static void pgss_planned(const char *query, PlannedStmt *plan);
static void pgss_executed(const char *query, PlannedStmt *plan,
		instr_time duration, long gets, long reads, long writes);

static Size	pgss_memsize(void);
static void pgss_init(void);
static volatile Entry *entry_alloc(uint32 tag, Oid userid, Oid dbid,
							const char *query, double cost, bool *found);
static void entry_dealloc(void);
static Chunk *chunk_alloc(void);
static void chunk_free(Chunk *chunk);
static void chunk_reset(void);

static bool max_statements_assign(int newval, bool doit, GucSource source);
static bool track_statements_assign(bool newval, bool doit, GucSource source);

/*---- pg_stat_statements functions ----*/

/*
 * Module load callback.
 *
 * Requests the shared memory and the LWLock used by the statistics, then
 * puts our planner and executor hooks in front of whatever was installed
 * before. When no hook was installed, the standard implementation is
 * remembered instead, so the saved "previous" pointers are never NULL.
 */
void
_PG_init(void)
{
	/* shared resources consumed later by pgss_init() */
	RequestAddinShmemSpace(pgss_memsize());
	RequestAddinLWLocks(1);

	/* chain the planner hook */
	prev_planner_hook = planner_hook ? planner_hook : standard_planner;
	planner_hook = pgss_planner;

	/* chain the ExecutorRun hook */
	prev_ExecutorRun_hook =
		ExecutorRun_hook ? ExecutorRun_hook : standard_ExecutorRun;
	ExecutorRun_hook = pgss_ExecutorRun;
}

/*
 * Module unload callback.
 *
 * Restores the hooks saved by _PG_init. A saved pointer equal to the
 * standard implementation means no hook was installed before us, so the
 * hook variable goes back to NULL.
 */
void
_PG_fini(void)
{
	planner_hook =
		(prev_planner_hook != standard_planner) ? prev_planner_hook : NULL;

	ExecutorRun_hook =
		(prev_ExecutorRun_hook != standard_ExecutorRun) ?
		prev_ExecutorRun_hook : NULL;
}

/*
 * Compute the hash-table key for a statement: the XOR of the hashes of the
 * user OID, the database OID and the full query string.
 */
static uint32
pgss_tag(Oid userid, Oid dbid, const char *query)
{
	uint32		tag;

	tag = DatumGetUInt32(hash_any((const unsigned char *) query,
								  (int) strlen(query)));
	tag ^= oid_hash(&dbid, sizeof(Oid));
	tag ^= oid_hash(&userid, sizeof(Oid));

	return tag;
}

/*
 * Record that `query` has been planned once, with the cost taken from
 * `plan`.
 *
 * Does nothing when tracking is disabled. Attaches to shared memory on
 * first use.
 */
static void
pgss_planned(const char *query, PlannedStmt *plan)
{
	volatile Entry *stats;
	Oid				user_oid;
	Oid				db_oid;
	uint32			key;
	bool			existed = true;

	Assert(query != NULL);
	Assert(plan != NULL);

	if (!track_statements)
		return;
	if (pgss == NULL)
		pgss_init();

	user_oid = GetUserId();
	db_oid = MyDatabaseId;
	key = pgss_tag(user_oid, db_oid, query);

	/* Look the statement up under a shared lock first. */
	LWLockAcquire(pgss->lock, LW_SHARED);
	stats = hash_search(pgss_hash, &key, HASH_FIND, NULL);

	if (stats == NULL)
	{
		/*
		 * Not there: trade the shared lock for an exclusive one and create
		 * the entry. Another backend may have created it in between, which
		 * entry_alloc reports through `existed`.
		 */
		LWLockRelease(pgss->lock);
		LWLockAcquire(pgss->lock, LW_EXCLUSIVE);

		stats = entry_alloc(key, user_oid, db_oid, query,
							plan->planTree->total_cost, &existed);
	}

	/*
	 * A freshly created entry already accounts for this planning, so only
	 * pre-existing entries are bumped.
	 */
	if (existed)
	{
		SpinLockAcquire(&stats->mutex);
		stats->planned += 1;
		stats->total_cost += plan->planTree->total_cost;
		stats->usage += USAGE_PLANNED;
		SpinLockRelease(&stats->mutex);
	}

	LWLockRelease(pgss->lock);
}

/*
 * Record one execution of `query`: elapsed time plus the buffer gets,
 * reads and writes observed by the caller.
 *
 * Does nothing when tracking is disabled. Attaches to shared memory on
 * first use. If the statement has no entry yet, one is created using
 * the plan's total cost.
 */
static void
pgss_executed(const char *query, PlannedStmt *plan, instr_time duration,
			  long gets, long reads, long writes)
{
	volatile Entry *stats;
	Oid				user_oid;
	Oid				db_oid;
	uint32			key;

	Assert(query != NULL);
	Assert(plan != NULL);

	if (!track_statements)
		return;
	if (pgss == NULL)
		pgss_init();

	user_oid = GetUserId();
	db_oid = MyDatabaseId;
	key = pgss_tag(user_oid, db_oid, query);

	/* Try the cheap path first: lookup under a shared lock. */
	LWLockAcquire(pgss->lock, LW_SHARED);
	stats = hash_search(pgss_hash, &key, HASH_FIND, NULL);

	if (stats == NULL)
	{
		bool		existed;

		/* Need an exclusive lock to insert; re-acquire in that mode. */
		LWLockRelease(pgss->lock);
		LWLockAcquire(pgss->lock, LW_EXCLUSIVE);

		stats = entry_alloc(key, user_oid, db_oid, query,
							plan->planTree->total_cost, &existed);
	}

	/* Fold this execution into the per-statement counters. */
	SpinLockAcquire(&stats->mutex);
	INSTR_TIME_ADD(stats->total_time, duration);
	stats->calls += 1;
	stats->gets += gets;
	stats->reads += reads;
	stats->writes += writes;
	stats->usage += USAGE_EXECUTED;
	SpinLockRelease(&stats->mutex);

	LWLockRelease(pgss->lock);
}

/*
 * planner_hook: run the previous planner and, for top-level statements,
 * record the planning in the statistics.
 *
 * The statement text is taken from debug_query_string. While the previous
 * planner runs, pgss_toplevel is cleared so that nested plannings and
 * executions are passed through without being counted; the flag is
 * restored on both the normal and the error path.
 */
static PlannedStmt *
pgss_planner(Query *parse, int cursorOptions, ParamListInfo boundParams)
{
	PlannedStmt	   *result;
	const char	   *query = debug_query_string;

	Assert(prev_planner_hook != NULL);

	if (pgss_toplevel && track_statements && query)
	{

#ifdef PGSS_DEBUG
		elog(NOTICE, "PLAN: %s", query);
#endif

		/* Disable our hooks temporarily during the top-level query. */
		pgss_toplevel = false;
		PG_TRY();
		{
			result = prev_planner_hook(parse, cursorOptions, boundParams);
			pgss_planned(query, result);
		}
		PG_CATCH();
		{
			pgss_toplevel = true;
			PG_RE_THROW();
		}
		PG_END_TRY();
		pgss_toplevel = true;
	}
	else
	{
		/*
		 * Debug trace only: without the guard every ignored planning would
		 * send a NOTICE to the client.
		 */
#ifdef PGSS_DEBUG
		elog(NOTICE, "PLAN(ignore): %s", query ? query : "(NULL)");
#endif

		/* ignore recursive plannings, that are typically function calls */
		result = prev_planner_hook(parse, cursorOptions, boundParams);
	}

	return result;
}

/*
 * ExecutorRun_hook: run the previous executor and, for top-level
 * statements, record elapsed time and buffer activity.
 *
 * The statement text is taken from ActivePortal->sourceText. While the
 * previous executor runs, pgss_toplevel is cleared so that nested
 * plannings and executions are passed through without being counted; the
 * flag is restored on both the normal and the error path.
 *
 * Buffer figures are read from the global buffer counters. "reads" is
 * reported as gets minus hits.
 */
static TupleTableSlot *
pgss_ExecutorRun(QueryDesc *queryDesc, ScanDirection direction, long count)
{
	TupleTableSlot *result;
	const char	   *query = (ActivePortal ? ActivePortal->sourceText : NULL);

	Assert(prev_ExecutorRun_hook != NULL);

	if (pgss_toplevel && track_statements && query)
	{
		instr_time		starttime;
		instr_time		duration;
		long			gets;
		long			hits;
		long			writes;

#ifdef PGSS_DEBUG
		elog(NOTICE, "EXEC: %s", query);
#endif

		/* Disable our hooks temporarily during the top-level query. */
		pgss_toplevel = false;
		PG_TRY();
		{
			/*
			 * Reset buffer stats if needed. The counters are left alone
			 * when either stats-logging GUC is on.
			 * NOTE(review): in that case the values read below are
			 * whatever has accumulated since the last reset elsewhere --
			 * confirm this is acceptable.
			 */
			if (!log_executor_stats && !log_statement_stats)
				ResetBufferUsage();
			INSTR_TIME_SET_CURRENT(starttime);

			result = prev_ExecutorRun_hook(queryDesc, direction, count);

			INSTR_TIME_SET_CURRENT(duration);
			INSTR_TIME_SUBTRACT(duration, starttime);
			gets = ReadBufferCount + ReadLocalBufferCount;
			hits = BufferHitCount + LocalBufferHitCount;
			writes = BufferFlushCount + LocalBufferFlushCount;
			pgss_executed(query, queryDesc->plannedstmt, duration,
				gets, gets - hits, writes);
		}
		PG_CATCH();
		{
			pgss_toplevel = true;
			PG_RE_THROW();
		}
		PG_END_TRY();
		pgss_toplevel = true;
	}
	else
	{
		/*
		 * Debug trace only: without the guard every ignored execution would
		 * send a NOTICE to the client.
		 */
#ifdef PGSS_DEBUG
		elog(NOTICE, "EXEC(ignore): %s", query ? query : "(NULL)");
#endif

		/* ignore recursive executions, that are typically function calls */
		result = prev_ExecutorRun_hook(queryDesc, direction, count);
	}

	return result;
}

/*
 * SQL-callable: discard all collected statement statistics.
 *
 * Removes every entry from the shared hash table, zeroes the entry count
 * and rebuilds the free list of query-text chunks. Always returns true.
 */
Datum
pg_stat_statements_reset(PG_FUNCTION_ARGS)
{
	HASH_SEQ_STATUS		hash_seq;
	Entry			   *entry;

	/*
	 * Attach to shared memory if this backend has not done so yet (e.g.
	 * tracking is turned off in this session). Otherwise the reset would
	 * silently do nothing even though other backends have collected
	 * statistics.
	 */
	if (!pgss)
		pgss_init();

	LWLockAcquire(pgss->lock, LW_EXCLUSIVE);

	hash_seq_init(&hash_seq, pgss_hash);
	while ((entry = hash_seq_search(&hash_seq)) != NULL)
	{
		SpinLockFree(&entry->mutex);
		hash_search(pgss_hash, &entry->tag, HASH_REMOVE, NULL);
	}

	pgss->num_statements = 0;
	/* every chunk goes back on the free list at once */
	chunk_reset();

	LWLockRelease(pgss->lock);

	PG_RETURN_BOOL(true);
}

/*
 * SQL-callable set-returning function: return one row per tracked
 * statement, in materialize mode.
 *
 * Columns, in order: userid, dbid, query text, planned, calls, total_cost,
 * total_time (microseconds), gets, reads, writes.
 *
 * The query text is shown only to superusers and to the user that owns
 * the entry; everyone else sees "<insufficient privilege>".
 *
 * Raises an error if the caller cannot accept a materialized set.
 */
Datum
pg_stat_statements(PG_FUNCTION_ARGS)
{
	ReturnSetInfo	   *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
	TupleDesc			tupdesc;
	Tuplestorestate    *tupstore;
	MemoryContext		per_query_ctx;
	MemoryContext		oldcontext;
	Oid					userid = GetUserId();
	bool				is_superuser = superuser();
	HASH_SEQ_STATUS		hash_seq;
	Entry			   *entry;

	if (!pgss)
		pgss_init();

	/* check to see if caller supports us returning a tuplestore */
	if (rsinfo == NULL || !IsA(rsinfo, ReturnSetInfo))
		ereport(ERROR,
				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
				 errmsg("set-valued function called in context that cannot accept a set")));
	if (!(rsinfo->allowedModes & SFRM_Materialize))
		ereport(ERROR,
				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
				 errmsg("materialize mode required, but it is not " \
						"allowed in this context")));

	/* the tuplestore and its descriptor must live in per-query memory */
	per_query_ctx = rsinfo->econtext->ecxt_per_query_memory;
	oldcontext = MemoryContextSwitchTo(per_query_ctx);

	tupdesc = CreateTupleDescCopy(rsinfo->expectedDesc);
	tupstore = tuplestore_begin_heap(true, false, work_mem);
	rsinfo->returnMode = SFRM_Materialize;
	rsinfo->setResult = tupstore;
	rsinfo->setDesc = tupdesc;

	LWLockAcquire(pgss->lock, LW_SHARED);

	hash_seq_init(&hash_seq, pgss_hash);
	while ((entry = hash_seq_search(&hash_seq)) != NULL)
	{
		volatile Entry *e = entry;
		Datum	values[10];
		bool	nulls[10] = { 0 };
		int		i = 0;
		int64	planned;
		int64	calls;
		double	total_cost;
		int64	total_usec;
		int64	gets;
		int64	reads;
		int64	writes;

		/* generate junk in short-term context */
		MemoryContextSwitchTo(oldcontext);

		values[i++] = ObjectIdGetDatum(entry->userid);
		values[i++] = ObjectIdGetDatum(entry->dbid);

		if (is_superuser || entry->userid == userid)
			values[i++] = CStringGetTextDatum(entry->chunk->query);
		else
			values[i++] = CStringGetTextDatum("<insufficient privilege>");

		/*
		 * Copy the counters out while holding the spinlock and build the
		 * Datums only afterwards. Nothing that might allocate memory or
		 * raise an error may run under a spinlock, and the *Fast macros
		 * may store a pointer to their argument, which therefore must not
		 * point into shared memory that can change after the lock is
		 * released.
		 */
		SpinLockAcquire(&e->mutex);
		planned = e->planned;
		calls = e->calls;
		total_cost = e->total_cost;
		total_usec = INSTR_TIME_GET_MICROSEC(e->total_time);
		gets = e->gets;
		reads = e->reads;
		writes = e->writes;
		SpinLockRelease(&e->mutex);

		values[i++] = Int64GetDatumFast(planned);
		values[i++] = Int64GetDatumFast(calls);
		values[i++] = Float8GetDatumFast(total_cost);
		values[i++] = Int64GetDatumFast(total_usec);
		values[i++] = Int64GetDatumFast(gets);
		values[i++] = Int64GetDatumFast(reads);
		values[i++] = Int64GetDatumFast(writes);

		/* switch to appropriate context while storing the tuple */
		MemoryContextSwitchTo(per_query_ctx);
		tuplestore_putvalues(tupstore, tupdesc, values, nulls);
	}

	LWLockRelease(pgss->lock);

	/* clean up and return the tuplestore */
	tuplestore_donestoring(tupstore);

	MemoryContextSwitchTo(oldcontext);

	return (Datum) 0;
}

/*---- Memory management functions ----*/

static Size
pgss_memsize(void)
{
	if (max_statements == 0)
	{
		DefineCustomIntVariable(
			PGSS_GUC("max_statements"),
			"Sets the maximum number of statements tracked by pg_stat_statements.",
			NULL,
			&max_statements, 100, INT_MAX,
			PGC_USERSET, max_statements_assign, NULL);

		/* use default if not specified. */
		if (max_statements == 0)
			max_statements = MAX_STATEMENTS_DEFAULT;

		DefineCustomBoolVariable(
			PGSS_GUC("track_statements"),
			"Collects information about executed statements.",
			NULL,
			&track_statements,
			PGC_USERSET, track_statements_assign, NULL);
	}

	return offsetof(pgStatStmt, chunks) +
		mul_size(max_statements, CHUNK_SIZE) + 
		hash_estimate_size(max_statements, sizeof(Entry));
}

/*
 * Attach this backend to the module's shared state, creating and
 * initializing it if no backend has done so yet.
 *
 * On success both pgss and pgss_hash are set. On failure an error is
 * raised and pgss is left NULL so that a later call retries instead of
 * running with a half-initialized state.
 */
static void
pgss_init(void)
{
	bool		found;
	Size		size;
	HASHCTL		info = { 0 };

	Assert(pgss == NULL);

	size = pgss_memsize();

	LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);

	pgss = ShmemInitStruct("pg_stat_statements", size, &found);
	if (!pgss)
		elog(ERROR, "out of shared memory");

	if (!found)
	{
		/* first backend to get here: set up the shared header */
		pgss->lock = LWLockAssign();
		pgss->num_statements = 0;
		pgss->chunk_size = CHUNK_SIZE;
		chunk_reset();
	}

	/* keyed by the uint32 tag stored at the start of each Entry */
	info.keysize = sizeof(uint32);
	info.entrysize = sizeof(Entry);
	info.hash = oid_hash;
	pgss_hash = ShmemInitHash("pg_stat_statements hash",
							  max_statements, max_statements,
							  &info,
							  HASH_ELEM | HASH_FUNCTION);
	if (!pgss_hash)
	{
		/*
		 * Callers test only pgss to decide whether to initialize, so it
		 * must not stay set while pgss_hash is NULL.
		 */
		pgss = NULL;
		elog(ERROR, "out of shared memory");
	}

	LWLockRelease(AddinShmemInitLock);
}

/*
 * Find or create the hash entry for `tag`.
 *
 * The caller must hold an exclusive lock on pgss->lock.
 *
 * If the table is full, entries are evicted first. *found is set to true
 * when the entry already existed, in which case it is returned untouched.
 * A newly created entry gets a copy of the query text (clipped with
 * pg_mbcliplen to fit one chunk) and counters that already account for one
 * planning with the given cost.
 */
static volatile Entry *
entry_alloc(uint32 tag, Oid userid, Oid dbid, const char *query,
			double cost, bool *found)
{
	volatile Entry *entry;

	Assert(query != NULL);
	Assert(found != NULL);

	/* make room; each eviction also frees a query-text chunk */
	while (pgss->num_statements >= max_statements)
		entry_dealloc();

	entry = hash_search(pgss_hash, &tag, HASH_ENTER, found);

	if (! *found)
	{
		Size	len;

		pgss->num_statements++;

		entry->userid = userid;
		entry->dbid = dbid;

		/* store the (possibly truncated) query text, NUL-terminated */
		entry->chunk = chunk_alloc();
		len = strlen(query);
		len = pg_mbcliplen(query, len, pgss->chunk_size - 1);
		memcpy(entry->chunk->query, query, len);
		entry->chunk->query[len] = '\0';

		/*
		 * Initialize every counter explicitly: the slot may be recycled
		 * from a removed entry and still contain its old values.
		 */
		entry->planned = 1;
		entry->calls = 0;
		entry->total_cost = cost;
		INSTR_TIME_SET_ZERO(entry->total_time);
		entry->gets = 0;
		entry->reads = 0;
		entry->writes = 0;
		entry->usage = USAGE_INITIAL;
		SpinLockInit(&entry->mutex);
	}

	return entry;
}

/*
 * Free least recently used entries.
 * Caller must be hold an exclusive lock on pgStateStmt->lock.
 */
static int
cmp_usage(const void *lhs, const void *rhs)
{
	double	l_usage = (*(const Entry **)lhs)->usage;
	double	r_usage = (*(const Entry **)rhs)->usage;

	if (l_usage < r_usage)
		return -1;
	else if (l_usage > r_usage)
		return +1;
	else
		return 0;
}

/*
 * Deallocate least used entries.
 *
 * The caller must hold an exclusive lock on pgss->lock.
 *
 * Every call multiplies the usage of all entries by USAGE_DECREASE_FACTOR,
 * then removes the entries with the lowest usage: USAGE_DEALLOC_PERCENT of
 * them, but at least 10 and never more than exist. The query-text chunk of
 * each victim goes back on the free list.
 */
static void
entry_dealloc(void)
{
	HASH_SEQ_STATUS		hash_seq;
	Entry			  **entries;
	Entry			   *entry;
	int					nvictims;
	int					i;

	/* Sort entries by usage and deallocate USAGE_DEALLOC_PERCENT of them. */

	entries = palloc(pgss->num_statements * sizeof(Entry *));

	i = 0;
	hash_seq_init(&hash_seq, pgss_hash);
	while ((entry = hash_seq_search(&hash_seq)) != NULL)
	{
		entries[i++] = entry;
		entry->usage *= USAGE_DECREASE_FACTOR;
	}

	qsort(entries, i, sizeof(Entry *), cmp_usage);

	/* never pick more victims than were collected into the array */
	nvictims = Max(10, i * USAGE_DEALLOC_PERCENT / 100);
	nvictims = Min(nvictims, i);

	for (i = 0; i < nvictims; i++)
	{
		chunk_free(entries[i]->chunk);
		SpinLockFree(&entries[i]->mutex);
		hash_search(pgss_hash, &entries[i]->tag, HASH_REMOVE, NULL);
		pgss->num_statements--;
	}

	pfree(entries);
}

static Chunk *
chunk_alloc(void)
{
	Chunk *chunk;

	Assert(pgss->free_chunks != NULL);

	chunk = pgss->free_chunks;
	pgss->free_chunks = pgss->free_chunks->next;
	return chunk;
}

/*
 * Push `chunk` onto the front of the shared free list. Its former contents
 * are overwritten by the list link.
 */
static void
chunk_free(Chunk *chunk)
{
	Chunk	   *old_head;

	Assert(chunk != NULL);

	old_head = pgss->free_chunks;
	chunk->next = old_head;
	pgss->free_chunks = chunk;
}

static void
chunk_reset(void)
{
	Chunk  *chunk;
	Chunk  *next;
	int				i;

	chunk = pgss->free_chunks = pgss->chunks;
	for (i = 0; i < max_statements - 1; i++)
	{
		next = (Chunk *) ((char *)chunk + pgss->chunk_size);
		chunk->next = next;
		chunk = next;
	}
	chunk->next = NULL;
}

/*---- GUC functions ----*/

/*
 * GUC assign hook for max_statements.
 *
 * Emulates PGC_POSTMASTER: once max_statements holds a non-zero value, any
 * attempt to actually assign a different value is rejected with a
 * complaint. Calls with doit == false are always accepted.
 */
static bool
max_statements_assign(int newval, bool doit, GucSource source)
{
	/*
	 * Accept when nothing is being applied, when no value has been set
	 * yet, or when the value does not change.
	 */
	if (!doit || max_statements == 0 || max_statements == newval)
		return true;

	ereport(GUC_complaint_elevel(source),
			(errcode(ERRCODE_CANT_CHANGE_RUNTIME_PARAM),
			 errmsg("parameter \"%s\" cannot be changed after server start",
					PGSS_GUC("max_statements"))));
	return false;
}

/*
 * GUC assign hook for track_statements.
 *
 * Emulates PGC_SUSET: a setting whose source is PGC_S_CLIENT or later is
 * accepted only from a superuser; anything else is rejected with a
 * complaint.
 */
static bool
track_statements_assign(bool newval, bool doit, GucSource source)
{
	if (source < PGC_S_CLIENT || superuser())
		return true;

	ereport(GUC_complaint_elevel(source),
			(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
			 errmsg("permission denied to set parameter \"%s\"",
					PGSS_GUC("track_statements"))));
	return false;
}
