Changeset: 431b36d9951d for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=431b36d9951d
Modified Files:
        sql/backends/monet5/rel_weld.c
Branch: rel-weld
Log Message:

weld impl for sys.avg


diffs (152 lines):

diff --git a/sql/backends/monet5/rel_weld.c b/sql/backends/monet5/rel_weld.c
--- a/sql/backends/monet5/rel_weld.c
+++ b/sql/backends/monet5/rel_weld.c
@@ -136,10 +136,15 @@ get_weld_cmp(int cmp) {
        }
 }
 
+static str 
+get_func_name(sql_subfunc *f) {
+       return f->func->imp ? f->func->imp : f->func->base.name;
+}
+
 static str
 get_weld_func(sql_subfunc *f) {
-       str name = f->func->imp ? f->func->imp : f->func->base.name;
-       if (strcmp(name, "+") == 0 || strcmp(name, "sum") == 0 || strcmp(name, "count") == 0)
+       str name = get_func_name(f);
+       if (strcmp(name, "+") == 0 || strcmp(name, "sum") == 0 || strcmp(name, "count") == 0 || strcmp(name, "avg") == 0)
                return "+";
        else if (strcmp(name, "-") == 0)
                return "-";
@@ -736,11 +741,11 @@ groupby_produce(backend *be, sql_rel *re
 {
        char new_builder[STR_BUF_SIZE];
        str col_name;
-       int len = 0, i, col_count, aggr_count;
+       int len = 0, i, col_count, aggr_count, result_var, have_avg = 0;
        node *en;
        sql_exp *exp;
        str aggr_func = NULL;
-       produce_func  input_produce;
+       produce_func input_produce;
        list *group_by_exps = sa_list(wstate->sa);
        list *aggr_exps = sa_list(wstate->sa);
        list *project_exps = sa_list(wstate->sa);
@@ -754,6 +759,16 @@ groupby_produce(backend *be, sql_rel *re
                exp = en->data;
                if (exp->type == e_aggr) {
                        list_append(aggr_exps, exp);
+                       if (aggr_func == NULL) {
+                               aggr_func = get_weld_func(exp->f);
+                       } else if (aggr_func != get_weld_func(exp->f)) {
+                               /* Currently Weld only supports a single operation for mergers */
+                               wstate->error = 1;
+                               goto cleanup;
+                       }
+                       if (strcmp(get_func_name(exp->f), "avg") == 0) {
+                               have_avg = 1;
+                       }
                } else {
                         if (exp->type == e_column && exp->r && exp->name && strcmp((str)exp->r, exp->name) != 0) {
                                list_append(project_exps, exp);
@@ -764,7 +779,7 @@ groupby_produce(backend *be, sql_rel *re
        }
        /* Create a new builder */
        wstate->num_parens = wstate->num_loops = 0;
-       int result_var = wstate->next_var++;
+       result_var = wstate->next_var++;
        wprintf(wstate, "let v%d = (", result_var);
        wstate->num_parens++;
        len = 0;
@@ -792,25 +807,21 @@ groupby_produce(backend *be, sql_rel *re
        } else {
                len += sprintf(new_builder + len, "merger[");
        }
-       if (list_length(aggr_exps) > 1) {
+       if (list_length(aggr_exps) > 1 || have_avg) {
                len += sprintf(new_builder + len, "{"); /* value is a struct */
        }
        for (en = aggr_exps->h; en; en = en->next) {
                exp = en->data;
                int type = exp_subtype(exp)->type->localtype;
-               if (aggr_func == NULL) {
-                       aggr_func = get_weld_func(exp->f);
-               } else if (aggr_func != get_weld_func(exp->f)) {
-                       /* Currently Weld only supports a single operation for mergers */
-                       wstate->error = 1;
-                       goto cleanup;
+               len += sprintf(new_builder + len, "%s", getWeldType(type));
+               if (strcmp(get_func_name(exp->f), "avg") == 0) {
+                       len += sprintf(new_builder + len, ", i64");
                }
-               len += sprintf(new_builder + len, "%s", getWeldType(type));
                if (en->next != NULL) {
                        len += sprintf(new_builder + len, ", ");
                }
        }
-       if (list_length(aggr_exps) > 1) {
+       if (list_length(aggr_exps) > 1 || have_avg) {
                len += sprintf(new_builder + len, "}"); /* value is a struct */
        }
        len += sprintf(new_builder + len, ", %s]", aggr_func);
@@ -846,7 +857,7 @@ groupby_produce(backend *be, sql_rel *re
                }
                wprintf(wstate, ", ");
        }
-       if (list_length(aggr_exps) > 1) {
+       if (list_length(aggr_exps) > 1 || have_avg) {
                wprintf(wstate, "{"); /* value is a struct */
        }
        for (en = aggr_exps->h; en; en = en->next) {
@@ -856,11 +867,14 @@ groupby_produce(backend *be, sql_rel *re
                wprintf(wstate, "%s(", weld_type);
                exp_to_weld(be, wstate, exp);
                wprintf(wstate, ")");
+               if (strcmp(get_func_name(exp->f), "avg") == 0) {
+                       wprintf(wstate, ", 1L");
+               }
                if (en->next != NULL) {
                        wprintf(wstate, ", ");
                }
        }
-       if (list_length(aggr_exps) > 1) {
+       if (list_length(aggr_exps) > 1 || have_avg) {
                wprintf(wstate, "}"); /* value is a struct */
        }
        if (list_length(group_by_exps) > 0) {
@@ -876,7 +890,7 @@ groupby_produce(backend *be, sql_rel *re
        wstate->num_parens = old_num_parens;
        wstate->num_loops = old_num_loops;
        wstate->builder = old_builder;
-       char struct_mbr[64];
+       char struct_mbr[128];
        col_count = aggr_count = 0;
        if (group_by_exps->h) {
                wstate->num_loops++;
@@ -903,15 +917,23 @@ groupby_produce(backend *be, sql_rel *re
                 } else if (list_find(aggr_exps, exp, NULL) && list_length(group_by_exps) > 0) {
                         /* Set the name of an aggregate if we used a dictmerger */
                        len = sprintf(struct_mbr, "n%d.$1", wstate->num_loops);
-                       if (list_length(aggr_exps) > 1) {
+                       if (list_length(aggr_exps) > 1 || have_avg) {
                                 len += sprintf(struct_mbr + len, ".$%d", aggr_count++);
                        }
+                       if (strcmp(get_func_name(exp->f), "avg") == 0) {
+                               /* Divide by the count */
+                               len += sprintf(struct_mbr + len, " / f64(n%d.$1.$%d)", wstate->num_loops, aggr_count++);
+                       }
                } else {
                        /* Set the name of an aggregate if we used a merger */
                        len = sprintf(struct_mbr, "v%d", wstate->next_var);
-                       if (list_length(aggr_exps) > 1) {
+                       if (list_length(aggr_exps) > 1 || have_avg) {
                                 len += sprintf(struct_mbr + len, ".$%d", col_count++);
                        }
+                       if (strcmp(get_func_name(exp->f), "avg") == 0) {
+                               /* Divide by the count */
+                               len += sprintf(struct_mbr + len, " / f64(v%d.$%d)", wstate->next_var, aggr_count++);
+                       }
                }
                if (exp_subtype(exp)->type->localtype == TYPE_str) {
                         wprintf(wstate, "let %s = strslice(%s_strcol, i64(%s) + %s_stroffset);", 
_______________________________________________
checkin-list mailing list
checkin-list@monetdb.org
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to