>>>>> "Merlin" == Merlin Moncure <mmonc...@gmail.com> writes:

 Merlin> Note, the biggest pain point I have with COPY is not being able
 Merlin> to parameterize the filename argument.

Second proof of concept attached. This goes so far as to allow
statements like:

do $$
  declare t text := 'bar'; f text := '/tmp/copytest.dat';
  begin copy (select t, now()) to (f) csv header; end;
$$;

Also "copy foo to $1" or "copy (select * from foo where x=$1) to $2" and
so on should work from PQexecParams or in a plpgsql EXECUTE.

(I haven't tried to parameterize anything other than the filename and
query. Also, it does not accept arbitrary expressions, only a $n
parameter, a '...' string literal, or a column reference. $n and '...'
may be written with or without parentheses, but a column reference must
be parenthesized because of conflicts with the unreserved keywords
PROGRAM, STDIN and STDOUT. This could be hacked around in other ways, I
guess, if the parens are too ugly.)

-- 
Andrew (irc:RhodiumToad)

diff --git a/src/backend/commands/copy.c b/src/backend/commands/copy.c
index 3201476..97debb7 100644
--- a/src/backend/commands/copy.c
+++ b/src/backend/commands/copy.c
@@ -37,6 +37,8 @@
 #include "optimizer/clauses.h"
 #include "optimizer/planner.h"
 #include "nodes/makefuncs.h"
+#include "nodes/nodeFuncs.h"
+#include "parser/analyze.h"
 #include "rewrite/rewriteHandler.h"
 #include "storage/fd.h"
 #include "tcop/tcopprot.h"
@@ -279,13 +281,13 @@ static const char BinarySignature[11] = "PGCOPY\n\377\r\n\0";
 
 
 /* non-export function prototypes */
-static CopyState BeginCopy(bool is_from, Relation rel, Node *raw_query,
+static CopyState BeginCopy(bool is_from, Relation rel, Node *raw_query, ParamListInfo params,
 		  const char *queryString, const Oid queryRelId, List *attnamelist,
 		  List *options);
 static void EndCopy(CopyState cstate);
 static void ClosePipeToProgram(CopyState cstate);
-static CopyState BeginCopyTo(Relation rel, Node *query, const char *queryString,
-			const Oid queryRelId, const char *filename, bool is_program,
+static CopyState BeginCopyTo(Relation rel, Node *query, ParamListInfo params, const char *queryString,
+			const Oid queryRelId, Node *filename_expr, bool is_program,
 			List *attnamelist, List *options);
 static void EndCopyTo(CopyState cstate);
 static uint64 DoCopyTo(CopyState cstate);
@@ -767,6 +769,61 @@ CopyLoadRawBuf(CopyState cstate)
 }
 
 
+/*
+ * CopyEvalFilename
+ *		Evaluate a COPY filename expression to a C string.
+ *
+ * expr is the parse-analyzed filename expression, which must already be of
+ * type cstring, or NULL for STDIN/STDOUT (in which case NULL is returned).
+ * If qd is supplied, the expression is evaluated in qd's executor state;
+ * otherwise a temporary executor state carrying params is set up and torn
+ * down here.  The returned string is palloc'd in the caller's memory context.
+ * An expression that evaluates to NULL is reported as an error.
+ */
+static char *
+CopyEvalFilename(QueryDesc *qd, Node *expr, ParamListInfo params)
+{
+	char *filename = NULL;
+
+	if (expr)
+	{
+		EState *estate = qd ? qd->estate : CreateExecutorState();
+		MemoryContext oldcontext = MemoryContextSwitchTo(estate->es_query_cxt);
+
+		Assert(exprType(expr) == CSTRINGOID);
+
+		if (qd == NULL)
+			estate->es_param_list_info = params;
+
+		/*
+		 * Inner scope: the ExprContext is created only after the parameter
+		 * list is installed above (presumably CreateExprContext captures it).
+		 */
+		{
+			ExprContext *ecxt = CreateExprContext(estate);
+			/* prepare a copy, leaving the statement's own tree untouched */
+			ExprState *exprstate = ExecPrepareExpr(copyObject(expr), estate);
+			bool isnull;
+			Datum result = ExecEvalExprSwitchContext(exprstate, ecxt, &isnull, NULL);
+			if (!isnull)
+				filename = MemoryContextStrdup(oldcontext, DatumGetCString(result));
+			FreeExprContext(ecxt, true);
+		}
+
+		MemoryContextSwitchTo(oldcontext);
+
+		if (qd == NULL)
+			FreeExecutorState(estate);
+
+		if (!filename)
+			ereport(ERROR,
+					(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
+					 errmsg("COPY filename expression must not return NULL")));
+	}
+
+	return filename;
+}
+
 /*
  *	 DoCopy executes the SQL COPY statement
  *
@@ -787,7 +826,7 @@ CopyLoadRawBuf(CopyState cstate)
  * the table or the specifically requested columns.
  */
 Oid
-DoCopy(const CopyStmt *stmt, const char *queryString, uint64 *processed)
+DoCopy(const CopyStmt *stmt, const char *queryString, ParamListInfo params, uint64 *processed)
 {
 	CopyState	cstate;
 	bool		is_from = stmt->is_from;
@@ -906,7 +945,7 @@ DoCopy(const CopyStmt *stmt, const char *queryString, uint64 *processed)
 			select->targetList = list_make1(target);
 			select->fromClause = list_make1(from);
 
-			query = (Node *) select;
+			query = (Node *) parse_analyze((Node *) select, queryString, NULL, 0);
 
 			/*
 			 * Close the relation for now, but keep the lock on it to prevent
@@ -929,6 +968,8 @@ DoCopy(const CopyStmt *stmt, const char *queryString, uint64 *processed)
 
 	if (is_from)
 	{
+		char *filename;
+
 		Assert(rel);
 
 		/* check read-only transaction and parallel mode */
@@ -936,15 +977,20 @@ DoCopy(const CopyStmt *stmt, const char *queryString, uint64 *processed)
 			PreventCommandIfReadOnly("COPY FROM");
 		PreventCommandIfParallelMode("COPY FROM");
 
-		cstate = BeginCopyFrom(rel, stmt->filename, stmt->is_program,
+		filename = CopyEvalFilename(NULL, stmt->filename, params);
+
+		cstate = BeginCopyFrom(rel, filename, stmt->is_program,
 							   stmt->attlist, stmt->options);
 		cstate->range_table = range_table;
 		*processed = CopyFrom(cstate);	/* copy from file to database */
 		EndCopyFrom(cstate);
+
+		if (filename)
+			pfree(filename);
 	}
 	else
 	{
-		cstate = BeginCopyTo(rel, query, queryString, relid,
+		cstate = BeginCopyTo(rel, query, params, queryString, relid,
 							 stmt->filename, stmt->is_program,
 							 stmt->attlist, stmt->options);
 		*processed = DoCopyTo(cstate);	/* copy from database to file */
@@ -1321,6 +1367,7 @@ static CopyState
 BeginCopy(bool is_from,
 		  Relation rel,
 		  Node *raw_query,
+		  ParamListInfo params,
 		  const char *queryString,
 		  const Oid queryRelId,
 		  List *attnamelist,
@@ -1391,8 +1438,7 @@ BeginCopy(bool is_from,
 		 * function and is executed repeatedly.  (See also the same hack in
 		 * DECLARE CURSOR and PREPARE.)  XXX FIXME someday.
 		 */
-		rewritten = pg_analyze_and_rewrite((Node *) copyObject(raw_query),
-										   queryString, NULL, 0);
+		rewritten = QueryRewrite(copyObject(raw_query));
 
 		/* check that we got back something we can work with */
 		if (rewritten == NIL)
@@ -1453,7 +1499,7 @@ BeginCopy(bool is_from,
 		}
 
 		/* plan the query */
-		plan = pg_plan_query(query, 0, NULL);
+		plan = pg_plan_query(query, 0, params);
 
 		/*
 		 * With row level security and a user using "COPY relation TO", we
@@ -1495,7 +1541,7 @@ BeginCopy(bool is_from,
 		cstate->queryDesc = CreateQueryDesc(plan, queryString,
 											GetActiveSnapshot(),
 											InvalidSnapshot,
-											dest, NULL, 0);
+											dest, params, 0);
 
 		/*
 		 * Call ExecutorStart to prepare the plan for execution.
@@ -1682,15 +1728,16 @@ EndCopy(CopyState cstate)
 static CopyState
 BeginCopyTo(Relation rel,
 			Node *query,
+			ParamListInfo params,
 			const char *queryString,
 			const Oid queryRelId,
-			const char *filename,
+			Node *filename_expr,
 			bool is_program,
 			List *attnamelist,
 			List *options)
 {
 	CopyState	cstate;
-	bool		pipe = (filename == NULL);
+	bool		pipe = (filename_expr == NULL);
 	MemoryContext oldcontext;
 
 	if (rel != NULL && rel->rd_rel->relkind != RELKIND_RELATION)
@@ -1725,7 +1772,7 @@ BeginCopyTo(Relation rel,
 							RelationGetRelationName(rel))));
 	}
 
-	cstate = BeginCopy(false, rel, query, queryString, queryRelId, attnamelist,
+	cstate = BeginCopy(false, rel, query, params, queryString, queryRelId, attnamelist,
 					   options);
 	oldcontext = MemoryContextSwitchTo(cstate->copycontext);
 
@@ -1737,7 +1784,7 @@ BeginCopyTo(Relation rel,
 	}
 	else
 	{
-		cstate->filename = pstrdup(filename);
+		cstate->filename = CopyEvalFilename(cstate->queryDesc, filename_expr, params);
 		cstate->is_program = is_program;
 
 		if (is_program)
@@ -1758,7 +1805,7 @@ BeginCopyTo(Relation rel,
 			 * Prevent write to relative path ... too easy to shoot oneself in
 			 * the foot by overwriting a database file ...
 			 */
-			if (!is_absolute_path(filename))
+			if (!is_absolute_path(cstate->filename))
 				ereport(ERROR,
 						(errcode(ERRCODE_INVALID_NAME),
 					  errmsg("relative path not allowed for COPY to file")));
@@ -2670,7 +2717,7 @@ BeginCopyFrom(Relation rel,
 	MemoryContext oldcontext;
 	bool		volatile_defexprs;
 
-	cstate = BeginCopy(true, rel, NULL, NULL, InvalidOid, attnamelist, options);
+	cstate = BeginCopy(true, rel, NULL, NULL, NULL, InvalidOid, attnamelist, options);
 	oldcontext = MemoryContextSwitchTo(cstate->copycontext);
 
 	/* Initialize state variables */
diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c
index 20e38f0..f78404e 100644
--- a/src/backend/nodes/copyfuncs.c
+++ b/src/backend/nodes/copyfuncs.c
@@ -2983,7 +2983,7 @@ _copyCopyStmt(const CopyStmt *from)
 	COPY_NODE_FIELD(attlist);
 	COPY_SCALAR_FIELD(is_from);
 	COPY_SCALAR_FIELD(is_program);
-	COPY_STRING_FIELD(filename);
+	COPY_NODE_FIELD(filename);
 	COPY_NODE_FIELD(options);
 
 	return newnode;
diff --git a/src/backend/nodes/equalfuncs.c b/src/backend/nodes/equalfuncs.c
index c5ccc42..4af21cb 100644
--- a/src/backend/nodes/equalfuncs.c
+++ b/src/backend/nodes/equalfuncs.c
@@ -1144,7 +1144,7 @@ _equalCopyStmt(const CopyStmt *a, const CopyStmt *b)
 	COMPARE_NODE_FIELD(attlist);
 	COMPARE_SCALAR_FIELD(is_from);
 	COMPARE_SCALAR_FIELD(is_program);
-	COMPARE_STRING_FIELD(filename);
+	COMPARE_NODE_FIELD(filename);
 	COMPARE_NODE_FIELD(options);
 
 	return true;
diff --git a/src/backend/parser/analyze.c b/src/backend/parser/analyze.c
index 29c8c4e..f0a3f60 100644
--- a/src/backend/parser/analyze.c
+++ b/src/backend/parser/analyze.c
@@ -36,12 +36,14 @@
 #include "parser/parse_coerce.h"
 #include "parser/parse_collate.h"
 #include "parser/parse_cte.h"
+#include "parser/parse_expr.h"
 #include "parser/parse_oper.h"
 #include "parser/parse_param.h"
 #include "parser/parse_relation.h"
 #include "parser/parse_target.h"
 #include "parser/parsetree.h"
 #include "rewrite/rewriteManip.h"
+#include "utils/builtins.h"
 #include "utils/rel.h"
 
 
@@ -74,6 +76,7 @@ static Query *transformCreateTableAsStmt(ParseState *pstate,
 						   CreateTableAsStmt *stmt);
 static void transformLockingClause(ParseState *pstate, Query *qry,
 					   LockingClause *lc, bool pushedDown);
+static Query *transformCopyStmt(ParseState *pstate, CopyStmt *stmt);
 #ifdef RAW_EXPRESSION_COVERAGE_TEST
 static bool test_raw_expression_coverage(Node *node, void *context);
 #endif
@@ -290,6 +293,11 @@ transformStmt(ParseState *pstate, Node *parseTree)
 											(CreateTableAsStmt *) parseTree);
 			break;
 
+		case T_CopyStmt:
+			result = transformCopyStmt(pstate,
+										  (CopyStmt *) parseTree);
+			break;
+
 		default:
 
 			/*
@@ -347,6 +355,11 @@ analyze_requires_snapshot(Node *parseTree)
 			result = true;
 			break;
 
+		case T_CopyStmt:
+			/* maybe, because we might have a contained statement */
+			result = ((CopyStmt *)parseTree)->query != NULL;
+			break;
+
 		default:
 			/* other utility statements don't have any real parse analysis */
 			result = false;
@@ -356,6 +369,57 @@ analyze_requires_snapshot(Node *parseTree)
 	return result;
 }
 
+/*
+ * transformCopyStmt -
+ *	  transforms a Copy Statement
+ *
+ * COPY stays a utility statement: the returned Query merely wraps the
+ * CopyStmt, which is modified in place.  A filename expression, if present,
+ * is transformed and coerced to cstring for evaluation at execution time;
+ * a contained query, if present, is replaced by its transformed Query.
+ */
+static Query *
+transformCopyStmt(ParseState *pstate, CopyStmt *stmt)
+{
+	Query *result = makeNode(Query);
+
+	result->commandType = CMD_UTILITY;
+	result->utilityStmt = (Node *) stmt;
+
+	/*
+	 * NOTE(review): the filename is transformed before the query, presumably
+	 * so that a column reference in it cannot resolve against the query's
+	 * FROM list; keep this ordering.
+	 */
+	if (stmt->filename)
+	{
+		Node *expr1 = transformExpr(pstate, stmt->filename, EXPR_KIND_OTHER);
+		Node *expr2 = coerce_to_target_type(pstate, expr1, exprType(expr1),
+											CSTRINGOID, -1,
+											COERCION_EXPLICIT,
+											COERCE_IMPLICIT_CAST,
+											exprLocation(expr1));
+
+		if (!expr2)
+			ereport(ERROR,
+				(errcode(ERRCODE_CANNOT_COERCE),
+				 errmsg("cannot cast type %s to %s",
+						format_type_be(exprType(expr1)),
+						format_type_be(CSTRINGOID)),
+				 parser_errposition(pstate, exprLocation(expr1))));
+
+		/* like other standalone expressions, resolve collations before use */
+		assign_expr_collations(pstate, expr2);
+
+		stmt->filename = expr2;
+	}
+
+	if (stmt->query != NULL)
+		stmt->query = (Node *) transformStmt(pstate, stmt->query);
+
+	return result;
+}
+
 /*
  * transformDeleteStmt -
  *	  transforms a Delete Statement
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 18ec5f0..6854255 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -313,8 +313,8 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
 %type <defelt>	event_trigger_when_item
 %type <chr>		enable_trigger
 
-%type <str>		copy_file_name
-				database_name access_method_clause access_method attr_name
+%type <node>	copy_file_name param_opt_indirection
+%type <str>		database_name access_method_clause access_method attr_name
 				name cursor_name file_name
 				index_name opt_index_name cluster_index_specification
 
@@ -2651,7 +2651,11 @@ opt_program:
  * stdout. We silently correct the "typo".)		 - AY 9/94
  */
 copy_file_name:
-			Sconst									{ $$ = $1; }
+			Sconst									{ $$ = makeStringConst($1,@1); }
+			| '(' Sconst ')'						{ $$ = makeStringConst($2,@2); }
+			| param_opt_indirection					{ $$ = $1; }
+			| '(' columnref ')'						{ $$ = $2; }	/* parens required: a bare columnref conflicts with the unreserved keywords PROGRAM, STDIN, STDOUT */
+			| '(' param_opt_indirection ')'			{ $$ = $2; }
 			| STDIN									{ $$ = NULL; }
 			| STDOUT								{ $$ = NULL; }
 		;
@@ -12049,21 +12053,7 @@ b_expr:		c_expr
  */
 c_expr:		columnref								{ $$ = $1; }
 			| AexprConst							{ $$ = $1; }
-			| PARAM opt_indirection
-				{
-					ParamRef *p = makeNode(ParamRef);
-					p->number = $1;
-					p->location = @1;
-					if ($2)
-					{
-						A_Indirection *n = makeNode(A_Indirection);
-						n->arg = (Node *) p;
-						n->indirection = check_indirection($2, yyscanner);
-						$$ = (Node *) n;
-					}
-					else
-						$$ = (Node *) p;
-				}
+			| param_opt_indirection					{ $$ = $1; }
 			| '(' a_expr ')' opt_indirection
 				{
 					if ($4)
@@ -12192,6 +12182,23 @@ c_expr:		columnref								{ $$ = $1; }
 			  }
 		;
 
+param_opt_indirection: PARAM opt_indirection	/* $n reference; shared by c_expr and copy_file_name */
+				{
+					ParamRef *p = makeNode(ParamRef);
+					p->number = $1;
+					p->location = @1;
+					if ($2)		/* wrap in A_Indirection only when indirection is present */
+					{
+						A_Indirection *n = makeNode(A_Indirection);
+						n->arg = (Node *) p;
+						n->indirection = check_indirection($2, yyscanner);
+						$$ = (Node *) n;
+					}
+					else
+						$$ = (Node *) p;
+				}
+		;
+
 func_application: func_name '(' ')'
 				{
 					$$ = (Node *) makeFuncCall($1, NIL, @1);
diff --git a/src/backend/tcop/utility.c b/src/backend/tcop/utility.c
index ac50c2a..dba2f5e 100644
--- a/src/backend/tcop/utility.c
+++ b/src/backend/tcop/utility.c
@@ -540,7 +540,7 @@ standard_ProcessUtility(Node *parsetree,
 			{
 				uint64		processed;
 
-				DoCopy((CopyStmt *) parsetree, queryString, &processed);
+				DoCopy((CopyStmt *) parsetree, queryString, params, &processed);
 				if (completionTag)
 					snprintf(completionTag, COMPLETION_TAG_BUFSIZE,
 							 "COPY " UINT64_FORMAT, processed);
@@ -1769,6 +1769,16 @@ UtilityContainsQuery(Node *parsetree)
 				return UtilityContainsQuery(qry->utilityStmt);
 			return qry;
 
+		case T_CopyStmt:
+			qry = (Query *) ((CopyStmt *) parsetree)->query;
+			if (qry)
+			{
+				Assert(IsA(qry, Query));
+				if (qry->commandType == CMD_UTILITY)
+					return UtilityContainsQuery(qry->utilityStmt);
+			}
+			return qry;
+
 		case T_CreateTableAsStmt:
 			qry = (Query *) ((CreateTableAsStmt *) parsetree)->query;
 			Assert(IsA(qry, Query));
diff --git a/src/include/commands/copy.h b/src/include/commands/copy.h
index 314d1f7..44c2c66 100644
--- a/src/include/commands/copy.h
+++ b/src/include/commands/copy.h
@@ -21,7 +21,7 @@
 /* CopyStateData is private in commands/copy.c */
 typedef struct CopyStateData *CopyState;
 
-extern Oid DoCopy(const CopyStmt *stmt, const char *queryString,
+extern Oid DoCopy(const CopyStmt *stmt, const char *queryString, ParamListInfo params,
 	   uint64 *processed);
 
 extern void ProcessCopyOptions(CopyState cstate, bool is_from, List *options);
diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h
index 714cf15..049ac4a 100644
--- a/src/include/nodes/parsenodes.h
+++ b/src/include/nodes/parsenodes.h
@@ -1692,7 +1692,7 @@ typedef struct CopyStmt
 								 * for all columns */
 	bool		is_from;		/* TO or FROM */
 	bool		is_program;		/* is 'filename' a program to popen? */
-	char	   *filename;		/* filename, or NULL for STDIN/STDOUT */
+	Node	   *filename;		/* filename, or NULL for STDIN/STDOUT */
 	List	   *options;		/* List of DefElem nodes */
 } CopyStmt;
 
diff --git a/src/test/regress/expected/copyselect.out b/src/test/regress/expected/copyselect.out
index 72865fe..a02a199 100644
--- a/src/test/regress/expected/copyselect.out
+++ b/src/test/regress/expected/copyselect.out
@@ -44,7 +44,9 @@ c
 -- This should fail
 --
 copy (select t into temp test3 from test1 where id=3) to stdout;
-ERROR:  COPY (SELECT INTO) is not supported
+ERROR:  SELECT ... INTO is not allowed here
+LINE 1: copy (select t into temp test3 from test1 where id=3) to std...
+                                 ^
 --
 -- This should fail
 --
diff --git a/src/test/regress/expected/select_into.out b/src/test/regress/expected/select_into.out
index b577d1b..6b6482e 100644
--- a/src/test/regress/expected/select_into.out
+++ b/src/test/regress/expected/select_into.out
@@ -90,7 +90,9 @@ ERROR:  SELECT ... INTO is not allowed here
 LINE 1: DECLARE foo CURSOR FOR SELECT 1 INTO b;
                                              ^
 COPY (SELECT 1 INTO frak UNION SELECT 2) TO 'blob';
-ERROR:  COPY (SELECT INTO) is not supported
+ERROR:  SELECT ... INTO is not allowed here
+LINE 1: COPY (SELECT 1 INTO frak UNION SELECT 2) TO 'blob';
+                            ^
 SELECT * FROM (SELECT 1 INTO f) bar;
 ERROR:  SELECT ... INTO is not allowed here
 LINE 1: SELECT * FROM (SELECT 1 INTO f) bar;
-- 
Sent via pgsql-hackers mailing list (pgsql-hackers@postgresql.org)
To make changes to your subscription:
http://www.postgresql.org/mailpref/pgsql-hackers

Reply via email to