diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c index 8b22c30559..1268ea92b6 100644 --- a/src/backend/optimizer/plan/planner.c +++ b/src/backend/optimizer/plan/planner.c @@ -1774,9 +1774,10 @@ grouping_planner(PlannerInfo *root, double tuple_fraction, sort_input_target = linitial_node(PathTarget, sort_input_targets); Assert(!linitial_int(sort_input_targets_contain_srfs)); /* likewise for grouping_target vs. scanjoin_target */ - split_pathtarget_at_srfs(root, grouping_target, scanjoin_target, - &grouping_targets, - &grouping_targets_contain_srfs); + split_pathtarget_at_srfs_grouping(root, + grouping_target, scanjoin_target, + &grouping_targets, + &grouping_targets_contain_srfs); grouping_target = linitial_node(PathTarget, grouping_targets); Assert(!linitial_int(grouping_targets_contain_srfs)); /* scanjoin_target will not have any SRFs precomputed for it */ diff --git a/src/backend/optimizer/util/tlist.c b/src/backend/optimizer/util/tlist.c index af7b19be5c..88eb26bf1c 100644 --- a/src/backend/optimizer/util/tlist.c +++ b/src/backend/optimizer/util/tlist.c @@ -19,6 +19,7 @@ #include "optimizer/cost.h" #include "optimizer/optimizer.h" #include "optimizer/tlist.h" +#include "rewrite/rewriteManip.h" /* @@ -45,6 +46,8 @@ typedef struct typedef struct { + PlannerInfo *root; + bool is_grouping_target; /* true if processing grouping target */ /* This is a List of bare expressions: */ List *input_target_exprs; /* exprs available from input */ /* These are Lists of Lists of split_pathtarget_items: */ @@ -59,6 +62,12 @@ typedef struct Index current_sgref; /* current subexpr's sortgroupref, or 0 */ } split_pathtarget_context; +static void split_pathtarget_at_srfs_extended(PlannerInfo *root, + PathTarget *target, + PathTarget *input_target, + List **targets, + List **targets_contain_srfs, + bool is_grouping_target); static bool split_pathtarget_walker(Node *node, split_pathtarget_context *context); static void add_sp_item_to_pathtarget(PathTarget *target, @@ -822,6 +831,51 @@ apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target) /* * split_pathtarget_at_srfs + * Split given PathTarget into multiple levels to position SRFs safely, + * performing exact matching against input_target. + * + * This is a wrapper for split_pathtarget_at_srfs_extended() that is used when + * both targets are on the same side of the grouping boundary (i.e., both are + * pre-grouping or both are post-grouping). In this case, no special handling + * for the grouping nulling bit is required. + * + * See split_pathtarget_at_srfs_extended() for more details. + */ +void +split_pathtarget_at_srfs(PlannerInfo *root, + PathTarget *target, PathTarget *input_target, + List **targets, List **targets_contain_srfs) +{ + split_pathtarget_at_srfs_extended(root, target, input_target, + targets, targets_contain_srfs, + false); +} + +/* + * split_pathtarget_at_srfs_grouping + * Split given PathTarget into multiple levels to position SRFs safely, + * ignoring the grouping nulling bit when matching against input_target. + * + * This variant is used when the targets cross the grouping boundary (i.e., + * target is post-grouping while input_target is pre-grouping). In this case, + * we need to ignore the grouping nulling bit when checking for expression + * availability to avoid incorrectly re-evaluating SRFs that have already been + * computed in input_target. + * + * See split_pathtarget_at_srfs_extended() for more details. + */ +void +split_pathtarget_at_srfs_grouping(PlannerInfo *root, + PathTarget *target, PathTarget *input_target, + List **targets, List **targets_contain_srfs) +{ + split_pathtarget_at_srfs_extended(root, target, input_target, + targets, targets_contain_srfs, + true); +} + +/* + * split_pathtarget_at_srfs_extended * Split given PathTarget into multiple levels to position SRFs safely * * The executor can only handle set-returning functions that appear at the @@ -860,6 +914,13 @@ apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target) * already meant as a reference to a lower subexpression). So, don't expand * any tlist expressions that appear in input_target, if that's not NULL. * + * This check requires extra care when processing the grouping target + * (indicated by the is_grouping_target flag). In this case input_target is + * pre-grouping while target is post-grouping, so the latter may carry + * nullingrels bits from the grouping step that are absent in the former. We + * must ignore those bits to correctly recognize that the tlist expressions are + * available in input_target. + * * It's also important that we preserve any sortgroupref annotation appearing * in the given target, especially on expressions matching input_target items. * @@ -877,10 +938,11 @@ apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target) * are only a few possible patterns for which levels contain SRFs. * But this representation decouples callers from that knowledge. */ -void -split_pathtarget_at_srfs(PlannerInfo *root, - PathTarget *target, PathTarget *input_target, - List **targets, List **targets_contain_srfs) +static void +split_pathtarget_at_srfs_extended(PlannerInfo *root, + PathTarget *target, PathTarget *input_target, + List **targets, List **targets_contain_srfs, + bool is_grouping_target) { split_pathtarget_context context; int max_depth; @@ -905,7 +967,12 @@ split_pathtarget_at_srfs(PlannerInfo *root, return; } - /* Pass any input_target exprs down to split_pathtarget_walker() */ + /* + * Pass 'root', the is_grouping_target flag, and any input_target exprs + * down to split_pathtarget_walker(). + */ + context.root = root; + context.is_grouping_target = is_grouping_target; context.input_target_exprs = input_target ? input_target->exprs : NIL; /* @@ -1076,9 +1143,27 @@ split_pathtarget_at_srfs(PlannerInfo *root, static bool split_pathtarget_walker(Node *node, split_pathtarget_context *context) { + Node *sanitized_node = node; + if (node == NULL) return false; + /* + * If we are crossing the grouping boundary (post-grouping target vs + * pre-grouping input_target), we must ignore the grouping nulling bit to + * correctly check if the subexpression is available in input_target. This + * aligns with the matching logic in set_upper_references(). + */ + if (context->is_grouping_target && + context->root->parse->hasGroupRTE && + context->root->parse->groupingSets != NIL) + { + sanitized_node = + remove_nulling_relids(node, + bms_make_singleton(context->root->group_rtindex), + NULL); + } + /* * A subexpression that matches an expression already computed in * input_target can be treated like a Var (which indeed it will be after @@ -1087,7 +1172,7 @@ split_pathtarget_walker(Node *node, split_pathtarget_context *context) * substructure. (Note in particular that this preserves the identity of * any expressions that appear as sortgrouprefs in input_target.) */ - if (list_member(context->input_target_exprs, node)) + if (list_member(context->input_target_exprs, sanitized_node)) { split_pathtarget_item *item = palloc_object(split_pathtarget_item); diff --git a/src/include/optimizer/tlist.h b/src/include/optimizer/tlist.h index 9c355bf117..8d2c1ec08b 100644 --- a/src/include/optimizer/tlist.h +++ b/src/include/optimizer/tlist.h @@ -48,6 +48,11 @@ extern void apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target); extern void split_pathtarget_at_srfs(PlannerInfo *root, PathTarget *target, PathTarget *input_target, List **targets, List **targets_contain_srfs); +extern void split_pathtarget_at_srfs_grouping(PlannerInfo *root, + PathTarget *target, + PathTarget *input_target, + List **targets, + List **targets_contain_srfs); /* Convenience macro to get a PathTarget with valid cost/width fields */ #define create_pathtarget(root, tlist) \ diff --git a/src/test/regress/expected/groupingsets.out b/src/test/regress/expected/groupingsets.out index 39d35a195b..f047db7c58 100644 --- a/src/test/regress/expected/groupingsets.out +++ b/src/test/regress/expected/groupingsets.out @@ -2563,4 +2563,71 @@ group by grouping sets((a, b), (a)); 2 | 2 | 4 (4 rows) +-- test handling of SRFs with grouping sets +explain (verbose, costs off) +select generate_series(1, a) as g +from (values (1, 1), (2, 2)) as t (a, b) +group by rollup(g) +order by 1; + QUERY PLAN +-------------------------------------------------------------- + Sort + Output: (generate_series(1, "*VALUES*".column1)) + Sort Key: (generate_series(1, "*VALUES*".column1)) + -> MixedAggregate + Output: (generate_series(1, "*VALUES*".column1)) + Hash Key: generate_series(1, "*VALUES*".column1) + Group Key: () + -> ProjectSet + Output: generate_series(1, "*VALUES*".column1) + -> Values Scan on "*VALUES*" + Output: "*VALUES*".column1 +(11 rows) + +select generate_series(1, a) as g +from (values (1, 1), (2, 2)) as t (a, b) +group by rollup(g) +order by 1; + g +--- + 1 + 2 + +(3 rows) + +explain (verbose, costs off) +select generate_series(1, a) as g, a+b as ab +from (values (1, 1), (2, 2)) as t (a, b) +group by rollup(a, ab) +order by 1, 2; + QUERY PLAN +------------------------------------------------------------------------------------------------------------------------- + Sort + Output: (generate_series(1, "*VALUES*".column1)), (("*VALUES*".column1 + "*VALUES*".column2)), "*VALUES*".column1 + Sort Key: (generate_series(1, "*VALUES*".column1)), (("*VALUES*".column1 + "*VALUES*".column2)) + -> ProjectSet + Output: generate_series(1, "*VALUES*".column1), (("*VALUES*".column1 + "*VALUES*".column2)), "*VALUES*".column1 + -> MixedAggregate + Output: "*VALUES*".column1, (("*VALUES*".column1 + "*VALUES*".column2)) + Hash Key: "*VALUES*".column1, ("*VALUES*".column1 + "*VALUES*".column2) + Hash Key: "*VALUES*".column1 + Group Key: () + -> Values Scan on "*VALUES*" + Output: ("*VALUES*".column1 + "*VALUES*".column2), "*VALUES*".column1 +(12 rows) + +select generate_series(1, a) as g, a+b as ab +from (values (1, 1), (2, 2)) as t (a, b) +group by rollup(a, ab) +order by 1, 2; + g | ab +---+---- + 1 | 2 + 1 | 4 + 1 | + 1 | + 2 | 4 + 2 | +(6 rows) + -- end diff --git a/src/test/regress/sql/groupingsets.sql b/src/test/regress/sql/groupingsets.sql index 6d875475fa..3e010961fa 100644 --- a/src/test/regress/sql/groupingsets.sql +++ b/src/test/regress/sql/groupingsets.sql @@ -721,4 +721,27 @@ select a, b, row_number() over (order by a, b nulls first) from (values (1, 1), (2, 2)) as t (a, b) where a = b group by grouping sets((a, b), (a)); +-- test handling of SRFs with grouping sets +explain (verbose, costs off) +select generate_series(1, a) as g +from (values (1, 1), (2, 2)) as t (a, b) +group by rollup(g) +order by 1; + +select generate_series(1, a) as g +from (values (1, 1), (2, 2)) as t (a, b) +group by rollup(g) +order by 1; + +explain (verbose, costs off) +select generate_series(1, a) as g, a+b as ab +from (values (1, 1), (2, 2)) as t (a, b) +group by rollup(a, ab) +order by 1, 2; + +select generate_series(1, a) as g, a+b as ab +from (values (1, 1), (2, 2)) as t (a, b) +group by rollup(a, ab) +order by 1, 2; + -- end