This is an automated email from the ASF dual-hosted git repository.

mtaha pushed a commit to branch PG16
in repository https://gitbox.apache.org/repos/asf/age.git


The following commit(s) were added to refs/heads/PG16 by this push:
     new 989ecd3e Add support for fuzzystrmatch and other external extensions 
(#2083) (#2096)
989ecd3e is described below

commit 989ecd3efeacc0618e5b51aa0bc0e45ab8abd6ec
Author: John Gemignani <[email protected]>
AuthorDate: Wed Sep 11 07:33:16 2024 -0700

    Add support for fuzzystrmatch and other external extensions (#2083) (#2096)
    
    Added support for the fuzzystrmatch extension and for other
    external extensions.
    
    Added support for typecasting `::pg_text` to TEXTOID.
    
    Added regression tests.
    
       modified:   sql/agtype_coercions.sql
       modified:   src/backend/parser/cypher_expr.c
       modified:   src/backend/utils/adt/agtype.c
       modified:   regress/expected/expr.out
       modified:   regress/sql/expr.sql
---
 regress/expected/expr.out        |  29 ++++
 regress/sql/expr.sql             |   9 ++
 sql/agtype_coercions.sql         |  25 +++-
 src/backend/parser/cypher_expr.c | 305 ++++++++++++++++++++++++++++++++++++++-
 src/backend/utils/adt/agtype.c   |  54 ++++++-
 5 files changed, 413 insertions(+), 9 deletions(-)

diff --git a/regress/expected/expr.out b/regress/expected/expr.out
index 616c8108..0aec6930 100644
--- a/regress/expected/expr.out
+++ b/regress/expected/expr.out
@@ -8767,9 +8767,38 @@ SELECT * FROM cypher('issue_1988', $$
  {"id": 844424930131969, "label": "Part", "properties": {"set": "set", "match": "match", "merge": "merge", "create": "create", "delete": "delete", "part_num": 123}}::vertex
 (4 rows)
 
+--
+-- Test external extension function logic for fuzzystrmatch
+--
+SELECT * FROM create_graph('fuzzystrmatch');
+NOTICE:  graph "fuzzystrmatch" has been created
+ create_graph 
+--------------
+ 
+(1 row)
+
+-- These should fail with extension not installed
+SELECT * FROM cypher('fuzzystrmatch', $$ RETURN soundex("hello world!") $$) AS (result agtype);
+ERROR:  extension fuzzystrmatch is not installed for function soundex
+LINE 1: SELECT * FROM cypher('fuzzystrmatch', $$ RETURN soundex("hel...
+                                               ^
+SELECT * FROM cypher('fuzzystrmatch', $$ RETURN difference("hello world!", "hello world!") $$) AS (result agtype);
+ERROR:  extension fuzzystrmatch is not installed for function difference
+LINE 1: SELECT * FROM cypher('fuzzystrmatch', $$ RETURN difference("...
+                                               ^
 --
 -- Cleanup
 --
+SELECT * FROM drop_graph('fuzzystrmatch', true);
+NOTICE:  drop cascades to 2 other objects
+DETAIL:  drop cascades to table fuzzystrmatch._ag_label_vertex
+drop cascades to table fuzzystrmatch._ag_label_edge
+NOTICE:  graph "fuzzystrmatch" has been dropped
+ drop_graph 
+------------
+ 
+(1 row)
+
 SELECT * FROM drop_graph('issue_1988', true);
 NOTICE:  drop cascades to 4 other objects
 DETAIL:  drop cascades to table issue_1988._ag_label_vertex
diff --git a/regress/sql/expr.sql b/regress/sql/expr.sql
index 7aae42c7..cc96b843 100644
--- a/regress/sql/expr.sql
+++ b/regress/sql/expr.sql
@@ -3536,9 +3536,18 @@ SELECT * FROM cypher('issue_1988', $$
 SELECT * FROM cypher('issue_1988', $$
     MATCH (p) RETURN p $$) as (p agtype);
 
+--
+-- Test external extension function logic for fuzzystrmatch
+--
+SELECT * FROM create_graph('fuzzystrmatch');
+-- These should fail with extension not installed
+SELECT * FROM cypher('fuzzystrmatch', $$ RETURN soundex("hello world!") $$) AS (result agtype);
+SELECT * FROM cypher('fuzzystrmatch', $$ RETURN difference("hello world!", "hello world!") $$) AS (result agtype);
+
 --
 -- Cleanup
 --
+SELECT * FROM drop_graph('fuzzystrmatch', true);
 SELECT * FROM drop_graph('issue_1988', true);
 SELECT * FROM drop_graph('issue_1953', true);
 SELECT * FROM drop_graph('expanded_map', true);
diff --git a/sql/agtype_coercions.sql b/sql/agtype_coercions.sql
index 745190fd..bdc33af8 100644
--- a/sql/agtype_coercions.sql
+++ b/sql/agtype_coercions.sql
@@ -32,6 +32,15 @@ AS 'MODULE_PATHNAME';
 CREATE CAST (agtype AS text)
     WITH FUNCTION ag_catalog.agtype_to_text(agtype);
 
+-- text -> agtype (function only; no CREATE CAST is defined for it here)
+CREATE FUNCTION ag_catalog.text_to_agtype(text)
+    RETURNS agtype
+    LANGUAGE c
+    IMMUTABLE
+RETURNS NULL ON NULL INPUT
+PARALLEL SAFE
+AS 'MODULE_PATHNAME';
+
 -- agtype -> boolean (implicit)
 CREATE FUNCTION ag_catalog.agtype_to_bool(agtype)
     RETURNS boolean
@@ -69,7 +78,7 @@ AS 'MODULE_PATHNAME';
 CREATE CAST (float8 AS agtype)
     WITH FUNCTION ag_catalog.float8_to_agtype(float8);
 
--- agtype -> float8 (implicit)
+-- agtype -> float8 (explicit)
 CREATE FUNCTION ag_catalog.agtype_to_float8(agtype)
     RETURNS float8
     LANGUAGE c
@@ -106,6 +115,18 @@ CREATE CAST (agtype AS bigint)
     WITH FUNCTION ag_catalog.agtype_to_int8(variadic "any")
 AS ASSIGNMENT;
 
+-- int4 -> agtype (explicit: the cast has no AS ASSIGNMENT / AS IMPLICIT)
+CREATE FUNCTION ag_catalog.int4_to_agtype(int4)
+    RETURNS agtype
+    LANGUAGE c
+    IMMUTABLE
+RETURNS NULL ON NULL INPUT
+PARALLEL SAFE
+AS 'MODULE_PATHNAME';
+
+CREATE CAST (int4 AS agtype)
+    WITH FUNCTION ag_catalog.int4_to_agtype(int4);
+
 -- agtype -> int4
 CREATE FUNCTION ag_catalog.agtype_to_int4(variadic "any")
     RETURNS int
@@ -151,4 +172,4 @@ PARALLEL SAFE
 AS 'MODULE_PATHNAME';
 
 CREATE CAST (agtype AS json)
-    WITH FUNCTION ag_catalog.agtype_to_json(agtype);
\ No newline at end of file
+    WITH FUNCTION ag_catalog.agtype_to_json(agtype);
diff --git a/src/backend/parser/cypher_expr.c b/src/backend/parser/cypher_expr.c
index cbd72004..654223c9 100644
--- a/src/backend/parser/cypher_expr.c
+++ b/src/backend/parser/cypher_expr.c
@@ -24,6 +24,7 @@
 
 #include "postgres.h"
 
+#include "catalog/pg_proc.h"
 #include "miscadmin.h"
 #include "nodes/nodeFuncs.h"
 #include "optimizer/optimizer.h"
@@ -34,6 +35,7 @@
 #include "parser/parse_oper.h"
 #include "parser/parse_relation.h"
 #include "utils/builtins.h"
+#include "utils/catcache.h"
 #include "utils/float.h"
 #include "utils/lsyscache.h"
 
@@ -52,6 +54,7 @@
 #define FUNC_AGTYPE_TYPECAST_PG_FLOAT8 "agtype_to_float8"
 #define FUNC_AGTYPE_TYPECAST_PG_BIGINT "agtype_to_int8"
 #define FUNC_AGTYPE_TYPECAST_BOOL "agtype_typecast_bool"
+#define FUNC_AGTYPE_TYPECAST_PG_TEXT "agtype_to_text"
 
 static Node *transform_cypher_expr_recurse(cypher_parsestate *cpstate,
                                            Node *expr);
@@ -94,6 +97,14 @@ static Node *transform_column_ref_for_indirection(cypher_parsestate *cpstate,
                                                   ColumnRef *cr);
 static Node *transform_cypher_list_comprehension(cypher_parsestate *cpstate,
                                                  cypher_unwind *expr);
+static bool is_fuzzystrmatch_function(FuncCall *fn);
+static void check_for_extension_functions(char *extension, FuncCall *fn);
+static List *cast_agtype_input_to_other_type(cypher_parsestate *cpstate,
+                                             FuncCall *fn, List *targs);
+static Node *cast_input_to_output_type(cypher_parsestate *cpstate, Node *expr,
+                                       Oid source_oid, Oid target_oid);
+static Node *wrap_text_output_to_agtype(cypher_parsestate *cpstate,
+                                        FuncExpr *fexpr);
 
 /* transform a cypher expression */
 Node *transform_cypher_expr(cypher_parsestate *cpstate, Node *expr,
@@ -1580,11 +1591,16 @@ static Node *transform_cypher_typecast(cypher_parsestate *cpstate,
     {
         fname = lappend(fname, makeString(FUNC_AGTYPE_TYPECAST_PG_BIGINT));
     }
-    else if ((pg_strcasecmp(ctypecast->typecast, "bool") == 0 || 
+    else if ((pg_strcasecmp(ctypecast->typecast, "bool") == 0 ||
              pg_strcasecmp(ctypecast->typecast, "boolean") == 0))
     {
         fname = lappend(fname, makeString(FUNC_AGTYPE_TYPECAST_BOOL));
     }
+    else if (pg_strcasecmp(ctypecast->typecast, "pg_text") == 0)
+    {
+        fname = lappend(fname, makeString(FUNC_AGTYPE_TYPECAST_PG_TEXT));
+    }
+
     /* if none was found, error out */
     else
     {
@@ -1601,6 +1617,221 @@ static Node *transform_cypher_typecast(cypher_parsestate *cpstate,
     return transform_FuncCall(cpstate, fnode);
 }
 
+/* does fn's first name element match a fuzzystrmatch function (any case)? */
+static bool is_fuzzystrmatch_function(FuncCall *fn)
+{
+    char *funcname = (((String*)linitial(fn->funcname))->sval);
+
+    if (pg_strcasecmp(funcname, "soundex") == 0 ||
+        pg_strcasecmp(funcname, "difference") == 0 ||
+        pg_strcasecmp(funcname, "daitch_mokotoff") == 0 ||
+        pg_strcasecmp(funcname, "soundex_tsvector") == 0 ||
+        pg_strcasecmp(funcname, "levenshtein") == 0 ||
+        pg_strcasecmp(funcname, "levenshtein_less_equal") == 0 ||
+        pg_strcasecmp(funcname, "metaphone") == 0 ||
+        pg_strcasecmp(funcname, "dmetaphone") == 0)
+    {
+        return true;
+    }
+    return false;
+}
+
+/*
+ * Cast a function's transformed args (targs) to the input types of the first
+ * pg_proc entry matching its name, arg count, and variadic-ness, for functions
+ * that don't take agtype. Returns targs unchanged if no entry matches.
+ */
+static List *cast_agtype_input_to_other_type(cypher_parsestate *cpstate,
+                                             FuncCall *fn, List *targs)
+{
+    char *funcname = (((String*)linitial(fn->funcname))->sval);
+    int nargs = list_length(fn->args); /* NIL-safe: args is NIL for f() */
+    CatCList *catlist = NULL;
+    List *new_targs = NIL;
+    ListCell *lc = NULL;
+    int i = 0;
+
+    /* get a list of matching functions from the sys cache */
+    catlist = SearchSysCacheList1(PROCNAMEARGSNSP, CStringGetDatum(funcname));
+
+    /* iterate through the list of functions for ones that match */
+    for (i = 0; i < catlist->n_members; i++)
+    {
+        HeapTuple proctup = &catlist->members[i]->tuple;
+        Form_pg_proc procform = (Form_pg_proc) GETSTRUCT(proctup);
+
+        /* names and arg counts must match; provariadic is an Oid, not a bool */
+        if (pg_strcasecmp(funcname, procform->proname.data) == 0 &&
+            nargs == procform->pronargs &&
+            fn->func_variadic == OidIsValid(procform->provariadic))
+        {
+            Oid *proargtypes = procform->proargtypes.values;
+            int j = 0;
+
+            /*
+             * Rebuild targs with castings to the function's input types from
+             * targ's output type.
+             */
+            foreach(lc, targs)
+            {
+                Oid poid = proargtypes[j];
+                Node *targ = lfirst(lc);
+                Oid toid = exprType(targ);
+
+                /* cast the arg. this will error out if it can't be done. */
+                targ = cast_input_to_output_type(cpstate, targ, toid, poid);
+
+                /* add it to the new argument list */
+                new_targs = lappend(new_targs, targ);
+                j++;
+            }
+
+            /* free the old list (list_free is NIL-safe) and use the new one */
+            list_free(targs);
+            targs = new_targs;
+            break;
+        }
+    }
+    /* we need to release the cache list */
+    ReleaseSysCacheList(catlist);
+    return targs;
+}
+
+/*
+ * Verify that a cypher function mapped to another extension's function is
+ * usable: a pg_proc entry with that name must exist and at least one must be
+ * in a non-temp namespace on the search path. Errors out otherwise.
+ *
+ * Note: some code borrowed from FuncnameGetCandidates
+ */
+static void check_for_extension_functions(char *extension, FuncCall *fn)
+{
+    char *funcname = (((String*)linitial(fn->funcname))->sval);
+    CatCList *catlist = NULL;
+    bool found = false;
+    int i = 0;
+
+    /* get a list of matching functions */
+    catlist = SearchSysCacheList1(PROCNAMEARGSNSP, CStringGetDatum(funcname));
+
+    /* if the catalog list is empty, the extension isn't loaded */
+    if (catlist->n_members == 0)
+    {
+        ereport(ERROR,
+                (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+                 errmsg("extension %s is not installed for function %s",
+                        extension, funcname)));
+    }
+
+    /* iterate through them and verify that they are in the search path */
+    for (i = 0; i < catlist->n_members; i++)
+    {
+        HeapTuple proctup = &catlist->members[i]->tuple;
+        Form_pg_proc procform = (Form_pg_proc) GETSTRUCT(proctup);
+        List *asp = fetch_search_path(false);
+        ListCell *nsp;
+
+        /*
+         * Consider only procs that are in the search path and are not in
+         * the temp namespace. asp is a List, so release it with list_free.
+         */
+        foreach(nsp, asp)
+        {
+            Oid oid = lfirst_oid(nsp);
+
+            if (procform->pronamespace == oid &&
+                isTempNamespace(procform->pronamespace) == false)
+            {
+                list_free(asp);
+                found = true;
+                break;
+            }
+        }
+
+        if (found)
+        {
+            break;
+        }
+
+        list_free(asp);
+    }
+
+    /* if we didn't find it, it isn't in the search path */
+    if (!found)
+    {
+        ereport(ERROR,
+                (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+                 errmsg("extension %s is not in search path for function %s",
+                        extension, funcname)));
+    }
+
+    /* release the system cache list */
+    ReleaseSysCacheList(catlist);
+}
+
+/*
+ * Explicitly coerce expr from source_oid to target_oid and return the coerced
+ * node; error out if no such coercion exists. Thanks to Taha for this idea.
+ */
+static Node *cast_input_to_output_type(cypher_parsestate *cpstate, Node *expr,
+                                       Oid source_oid, Oid target_oid)
+{
+    ParseState *pstate = &cpstate->pstate;
+
+    /* can we cast from source to target oid? */
+    if (can_coerce_type(1, &source_oid, &target_oid, COERCION_EXPLICIT))
+    {
+        /* coerce the source to the target */
+        expr = coerce_type(pstate, expr, source_oid, target_oid, -1,
+                           COERCION_EXPLICIT, COERCE_EXPLICIT_CAST, -1);
+    }
+    /* error out if we can't cast */
+    else
+    {
+        ereport(ERROR,
+                (errcode(ERRCODE_UNDEFINED_FUNCTION),
+                 errmsg("cannot cast type %s to %s", format_type_be(source_oid),
+                 format_type_be(target_oid))));
+    }
+
+    /* return the casted expression */
+    return expr;
+}
+
+/*
+ * Due to issues with creating a cast from text to agtype, wrap a text-returning
+ * function in ag_catalog.text_to_agtype. Errors out if the result isn't text.
+ */
+static Node *wrap_text_output_to_agtype(cypher_parsestate *cpstate,
+                                        FuncExpr *fexpr)
+{
+    ParseState *pstate = &cpstate->pstate;
+    Node *last_srf = pstate->p_last_srf;
+    Node *retval = NULL;
+    List *fname = NIL;
+    FuncCall *fnode = NULL;
+
+    if (fexpr->funcresulttype != TEXTOID)
+    {
+        ereport(ERROR,
+                (errcode(ERRCODE_DATA_EXCEPTION),
+                 errmsg("can only wrap text to agtype")));
+    }
+
+    /* build the schema-qualified name of the wrapper function */
+    fname = list_make2(makeString("ag_catalog"), makeString("text_to_agtype"));
+
+    /* the input function is the arg to the new function (wrapper) */
+    fnode = makeFuncCall(fname, list_make1(fexpr), COERCE_SQL_SYNTAX, -1);
+
+    /* ... and hand off to ParseFuncOrColumn to create it */
+    retval = ParseFuncOrColumn(pstate, fname, list_make1(fexpr), last_srf,
+                               fnode, false, -1);
+
+    /* return the wrapped function */
+    return retval;
+}
+
 /*
  * Code borrowed from PG's transformFuncCall and updated for AGE
  */
@@ -1612,6 +1843,7 @@ static Node *transform_FuncCall(cypher_parsestate *cpstate, FuncCall *fn)
     List *fname = NIL;
     ListCell *arg;
     Node *retval = NULL;
+    bool found = false;
 
     /* Transform the list of arguments ... */
     foreach(arg, fn->args)
@@ -1626,8 +1858,67 @@ static Node *transform_FuncCall(cypher_parsestate *cpstate, FuncCall *fn)
     Assert(!fn->agg_within_group);
 
     /*
-     * If the function name is not qualified, then it is one of ours. We need to
-     * construct its name, and qualify it, so that PG can find it.
+     * Check for cypher functions that map to the fuzzystrmatch extension and
+     * verify that the external functions exist.
+     */
+    if (is_fuzzystrmatch_function(fn))
+    {
+        /* abort if the extension isn't loaded or in the path */
+        check_for_extension_functions("fuzzystrmatch", fn);
+
+        /* everything looks good so mark found as true */
+        found = true;
+    }
+
+    /*
+     * If we found a function that is part of an extension, which is in the
+     * search_path, then cast the agtype inputs to that function's type inputs.
+     */
+    if (found)
+    {
+        FuncExpr *fexpr = NULL;
+
+        /*
+         * Coerce agtype inputs to function's inputs. this will error out if
+         * this is not possible to do.
+         */
+        targs = cast_agtype_input_to_other_type(cpstate, fn, targs);
+
+        /* now get the function node for the external function */
+        fexpr = (FuncExpr *)ParseFuncOrColumn(pstate, fn->funcname, targs,
+                                              last_srf, fn, false,
+                                              fn->location);
+
+        /*
+         * This will cast TEXT outputs to AGTYPE. It will error out if this is
+         * not possible to do. For TEXT to AGTYPE we need to wrap the output
+         * due to issues with creating a cast from TEXT to AGTYPE.
+         */
+        if (fexpr->funcresulttype == TEXTOID)
+        {
+            retval = wrap_text_output_to_agtype(cpstate, fexpr);
+        }
+        else
+        {
+            retval = (Node *)fexpr;
+        }
+
+        /* additional casts or wraps can be done here for other types */
+
+        /* flag that an aggregate was found during a transform */
+        if (retval != NULL && retval->type == T_Aggref)
+        {
+            cpstate->exprHasAgg = true;
+        }
+
+        /* we can just return it here */
+        return retval;
+    }
+
+    /*
+     * If the function name is not qualified and not from an extension, then it
+     * is one of ours. We need to construct its name, and qualify it, so that PG
+     * can find it.
      */
     if (list_length(fn->funcname) == 1)
     {
@@ -1645,7 +1936,9 @@ static Node *transform_FuncCall(cypher_parsestate *cpstate, FuncCall *fn)
          * in lower case.
          */
         for (i = 0; i < pnlen; i++)
+        {
             ag_name[i + 4] = tolower(name[i]);
+        }
 
         /* terminate it with 0 */
         ag_name[i + 4] = 0;
@@ -1661,9 +1954,9 @@ static Node *transform_FuncCall(cypher_parsestate *cpstate, FuncCall *fn)
          */
         if ((list_length(targs) != 0) &&
             (strcmp("startNode", name) == 0 ||
-              strcmp("endNode", name) == 0 ||
-              strcmp("vle", name) == 0 ||
-              strcmp("vertex_stats", name) == 0))
+             strcmp("endNode", name) == 0 ||
+             strcmp("vle", name) == 0 ||
+             strcmp("vertex_stats", name) == 0))
         {
             char *graph_name = cpstate->graph_name;
             Datum d = string_to_agtype(graph_name);
diff --git a/src/backend/utils/adt/agtype.c b/src/backend/utils/adt/agtype.c
index 3b4af7f9..91814afe 100644
--- a/src/backend/utils/adt/agtype.c
+++ b/src/backend/utils/adt/agtype.c
@@ -3207,6 +3207,49 @@ Datum agtype_to_text(PG_FUNCTION_ARGS)
     PG_RETURN_TEXT_P(text_value);
 }
 
+PG_FUNCTION_INFO_V1(text_to_agtype);
+
+/*
+ * Cast text to agtype (as an AGTV_STRING value). Returns NULL for NULL input.
+ */
+Datum text_to_agtype(PG_FUNCTION_ARGS)
+{
+    agtype *result = NULL;
+    agtype_value agtv;
+    text *text_value = NULL;
+    char *string = NULL;
+    int len = 0;
+
+    if (PG_ARGISNULL(0))
+    {
+        PG_RETURN_NULL();
+    }
+
+    /* get the text value */
+    text_value = PG_GETARG_TEXT_PP(0);
+    /* convert it to a string */
+    string = text_to_cstring(text_value);
+    /* get the length */
+    len = strlen(string);
+
+    /* create a temporary agtype string that borrows our cstring (no copy) */
+    agtv.type = AGTV_STRING;
+    agtv.val.string.len = len;
+    agtv.val.string.val = string;
+
+    /* convert to agtype while the borrowed string is still alive */
+    result = agtype_value_to_agtype(&agtv);
+
+    /* free the string; NOTE(review): assumes the result holds its own copy */
+    pfree(string);
+
+    /* free the input arg if necessary */
+    PG_FREE_IF_COPY(text_value, 0);
+
+    /* return our result */
+    PG_RETURN_POINTER(result);
+}
+
 PG_FUNCTION_INFO_V1(agtype_to_json);
 
 /*
@@ -3271,13 +3314,22 @@ Datum float8_to_agtype(PG_FUNCTION_ARGS)
 PG_FUNCTION_INFO_V1(int8_to_agtype);
 
 /*
- * Cast float8 to agtype.
+ * Cast int8 to agtype.
  */
 Datum int8_to_agtype(PG_FUNCTION_ARGS)
 {
     return integer_to_agtype(PG_GETARG_INT64(0));
 }
 
+PG_FUNCTION_INFO_V1(int4_to_agtype);
+/*
+ * Cast int4 to agtype by widening the argument to int64 for integer_to_agtype.
+ */
+Datum int4_to_agtype(PG_FUNCTION_ARGS)
+{
+    return integer_to_agtype((int64)PG_GETARG_INT32(0));
+}
+
 PG_FUNCTION_INFO_V1(agtype_to_int4_array);
 
 /*

Reply via email to