Changeset: 431b36d9951d for MonetDB URL: https://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=431b36d9951d Modified Files: sql/backends/monet5/rel_weld.c Branch: rel-weld Log Message:
weld impl for sys.avg diffs (152 lines): diff --git a/sql/backends/monet5/rel_weld.c b/sql/backends/monet5/rel_weld.c --- a/sql/backends/monet5/rel_weld.c +++ b/sql/backends/monet5/rel_weld.c @@ -136,10 +136,15 @@ get_weld_cmp(int cmp) { } } +static str +get_func_name(sql_subfunc *f) { + return f->func->imp ? f->func->imp : f->func->base.name; +} + static str get_weld_func(sql_subfunc *f) { - str name = f->func->imp ? f->func->imp : f->func->base.name; - if (strcmp(name, "+") == 0 || strcmp(name, "sum") == 0 || strcmp(name, "count") == 0) + str name = get_func_name(f); + if (strcmp(name, "+") == 0 || strcmp(name, "sum") == 0 || strcmp(name, "count") == 0 || strcmp(name, "avg") == 0) return "+"; else if (strcmp(name, "-") == 0) return "-"; @@ -736,11 +741,11 @@ groupby_produce(backend *be, sql_rel *re { char new_builder[STR_BUF_SIZE]; str col_name; - int len = 0, i, col_count, aggr_count; + int len = 0, i, col_count, aggr_count, result_var, have_avg = 0; node *en; sql_exp *exp; str aggr_func = NULL; - produce_func input_produce; + produce_func input_produce; list *group_by_exps = sa_list(wstate->sa); list *aggr_exps = sa_list(wstate->sa); list *project_exps = sa_list(wstate->sa); @@ -754,6 +759,16 @@ groupby_produce(backend *be, sql_rel *re exp = en->data; if (exp->type == e_aggr) { list_append(aggr_exps, exp); + if (aggr_func == NULL) { + aggr_func = get_weld_func(exp->f); + } else if (aggr_func != get_weld_func(exp->f)) { + /* Currently Weld only supports a single operation for mergers */ + wstate->error = 1; + goto cleanup; + } + if (strcmp(get_func_name(exp->f), "avg") == 0) { + have_avg = 1; + } } else { if (exp->type == e_column && exp->r && exp->name && strcmp((str)exp->r, exp->name) != 0) { list_append(project_exps, exp); @@ -764,7 +779,7 @@ groupby_produce(backend *be, sql_rel *re } /* Create a new builder */ wstate->num_parens = wstate->num_loops = 0; - int result_var = wstate->next_var++; + result_var = wstate->next_var++; wprintf(wstate, "let v%d = (", 
result_var); wstate->num_parens++; len = 0; @@ -792,25 +807,21 @@ groupby_produce(backend *be, sql_rel *re } else { len += sprintf(new_builder + len, "merger["); } - if (list_length(aggr_exps) > 1) { + if (list_length(aggr_exps) > 1 || have_avg) { len += sprintf(new_builder + len, "{"); /* value is a struct */ } for (en = aggr_exps->h; en; en = en->next) { exp = en->data; int type = exp_subtype(exp)->type->localtype; - if (aggr_func == NULL) { - aggr_func = get_weld_func(exp->f); - } else if (aggr_func != get_weld_func(exp->f)) { - /* Currently Weld only supports a single operation for mergers */ - wstate->error = 1; - goto cleanup; + len += sprintf(new_builder + len, "%s", getWeldType(type)); + if (strcmp(get_func_name(exp->f), "avg") == 0) { + len += sprintf(new_builder + len, ", i64"); } - len += sprintf(new_builder + len, "%s", getWeldType(type)); if (en->next != NULL) { len += sprintf(new_builder + len, ", "); } } - if (list_length(aggr_exps) > 1) { + if (list_length(aggr_exps) > 1 || have_avg) { len += sprintf(new_builder + len, "}"); /* value is a struct */ } len += sprintf(new_builder + len, ", %s]", aggr_func); @@ -846,7 +857,7 @@ groupby_produce(backend *be, sql_rel *re } wprintf(wstate, ", "); } - if (list_length(aggr_exps) > 1) { + if (list_length(aggr_exps) > 1 || have_avg) { wprintf(wstate, "{"); /* value is a struct */ } for (en = aggr_exps->h; en; en = en->next) { @@ -856,11 +867,14 @@ groupby_produce(backend *be, sql_rel *re wprintf(wstate, "%s(", weld_type); exp_to_weld(be, wstate, exp); wprintf(wstate, ")"); + if (strcmp(get_func_name(exp->f), "avg") == 0) { + wprintf(wstate, ", 1L"); + } if (en->next != NULL) { wprintf(wstate, ", "); } } - if (list_length(aggr_exps) > 1) { + if (list_length(aggr_exps) > 1 || have_avg) { wprintf(wstate, "}"); /* value is a struct */ } if (list_length(group_by_exps) > 0) { @@ -876,7 +890,7 @@ groupby_produce(backend *be, sql_rel *re wstate->num_parens = old_num_parens; wstate->num_loops = old_num_loops; 
wstate->builder = old_builder; - char struct_mbr[64]; + char struct_mbr[128]; col_count = aggr_count = 0; if (group_by_exps->h) { wstate->num_loops++; @@ -903,15 +917,23 @@ groupby_produce(backend *be, sql_rel *re } else if (list_find(aggr_exps, exp, NULL) && list_length(group_by_exps) > 0) { /* Set the name of an aggregate if we used a dictmerger */ len = sprintf(struct_mbr, "n%d.$1", wstate->num_loops); - if (list_length(aggr_exps) > 1) { + if (list_length(aggr_exps) > 1 || have_avg) { len += sprintf(struct_mbr + len, ".$%d", aggr_count++); } + if (strcmp(get_func_name(exp->f), "avg") == 0) { + /* Divide by the count */ + len += sprintf(struct_mbr + len, " / f64(n%d.$1.$%d)", wstate->num_loops, aggr_count++); + } } else { /* Set the name of an aggregate if we used a merger */ len = sprintf(struct_mbr, "v%d", wstate->next_var); - if (list_length(aggr_exps) > 1) { + if (list_length(aggr_exps) > 1 || have_avg) { len += sprintf(struct_mbr + len, ".$%d", col_count++); } + if (strcmp(get_func_name(exp->f), "avg") == 0) { + /* Divide by the count, which directly follows the sum in the merger struct (same col_count sequence) */ + len += sprintf(struct_mbr + len, " / f64(v%d.$%d)", wstate->next_var, col_count++); + } } if (exp_subtype(exp)->type->localtype == TYPE_str) { wprintf(wstate, "let %s = strslice(%s_strcol, i64(%s) + %s_stroffset);", _______________________________________________ checkin-list mailing list checkin-list@monetdb.org https://www.monetdb.org/mailman/listinfo/checkin-list