From a0ac75ec9a92924901c815641cd451fad3534ea6 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 30 Oct 2025 19:32:18 -0700 Subject: [PATCH 1/2] saving --- pom.xml | 4 +- .../Cockroach/CockroachGenerator.java | 667 +++++++++++++++++- .../Backends/Cockroach/CockroachTester.java | 4 +- .../org/qed/Backends/Cockroach/CockroachTests | 291 +++++++- .../Generated/AggregateExtractProject.opt | 19 + .../Generated/AggregateJoinJoinRemove.opt | 54 ++ .../Generated/AggregateJoinRemove.opt | 28 + .../Cockroach/Generated/FilterReduceFalse.opt | 11 + .../Cockroach/Generated/JoinExtractFilter.opt | 17 + .../Cockroach/Generated/JoinReduceFalse.opt | 18 + .../Cockroach/Generated/JoinReduceTrue.opt | 18 + .../Cockroach/Generated/MinusMerge.opt | 20 + .../Cockroach/Generated/PruneEmptyFilter.opt | 7 + .../Generated/PruneEmptyIntersect.opt | 7 + .../Cockroach/Generated/PruneEmptyMinus.opt | 8 + .../Cockroach/Generated/PruneEmptyProject.opt | 8 + .../Cockroach/Generated/PruneEmptyUnion.opt | 7 + .../RRuleInstances/PruneZeroRowsTable.java | 18 - 18 files changed, 1148 insertions(+), 58 deletions(-) create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateExtractProject.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinJoinRemove.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinRemove.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinExtractFilter.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceFalse.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceTrue.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/MinusMerge.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyIntersect.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyUnion.opt delete mode 100644 src/main/java/org/qed/RRuleInstances/PruneZeroRowsTable.java diff --git a/pom.xml b/pom.xml index 3802e5f..d60c9a4 100644 --- a/pom.xml +++ b/pom.xml @@ -51,8 +51,8 @@ org.apache.maven.plugins maven-compiler-plugin - 23 - 23 + 25 + 25 --enable-preview diff --git a/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java b/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java index 20fc7db..c2909d0 100644 --- a/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java +++ b/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java @@ -27,11 +27,28 @@ public Env onMatchFilter(Env env, RelRN.Filter filter) { String sourcePattern = sourceEnv.current(); Env condEnv = onMatch(sourceEnv, filter.cond()); String condPattern; + + // Check for PruneEmptyFilter pattern: filter on empty source + if (filter.source() instanceof RelRN.Empty) { + String inputVar = condEnv.generateVar("input"); + String filtersVar = condEnv.generateVar("filters"); + String pattern = "(Select\n $" + inputVar + ":* & (HasZeroRows $" + inputVar + ")\n $" + filtersVar + ":*\n)"; + return condEnv.addBinding("isPruneEmptyFilter", "true") + .addBinding("pruneEmptyInput", inputVar) + .setPattern(pattern).focus(pattern); + } + if (filter.cond() instanceof RexRN.True) { String pattern = "(Select\n " + sourcePattern + "\n []\n)"; return condEnv.setPattern(pattern).focus(pattern); } else if (filter.cond() instanceof RexRN.False) { - condPattern = condEnv.pattern(); + // FilterReduceFalse pattern: Select with False condition + String onVar = condEnv.generateVar("on"); + Env onEnv = condEnv.addBinding("on", onVar); + String itemVar = onEnv.generateVar("item"); + Env itemEnv = onEnv.addBinding("item", itemVar); + String pattern = "(Select\n " + sourcePattern + "\n $" + onVar + ":[\n ...\n $" + itemVar + ":(FiltersItem (False))\n ...\n ]\n)"; + return itemEnv.setPattern(pattern).focus(pattern); } else { condPattern = condEnv.current(); } @@ -40,6 +57,18 @@ public Env onMatchFilter(Env env, RelRN.Filter filter) { } public Env onMatchProject(Env env, RelRN.Project project) { + // Generic handling for Project over empty input (PruneEmptyProject) + if (project.source() instanceof RelRN.Empty) { + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("zeroInput", inputVar) + .addBinding("hasZeroRows", "true"); + String projectionsVar = inputEnv.generateVar("projections"); + Env projectionsEnv = inputEnv.addBinding("projections", projectionsVar); + String passthroughVar = projectionsEnv.generateVar("passthrough"); + Env passthroughEnv = projectionsEnv.addBinding("passthrough", passthroughVar); + String pattern = "(Project\n $" + inputVar + ":* & (HasZeroRows $" + inputVar + ")\n $" + projectionsVar + ":*\n $" + passthroughVar + ":*\n)"; + return passthroughEnv.setPattern(pattern).focus(pattern); + } if (env.rulename.equals("ProjectMerge") && project.source() instanceof RelRN.Project) { Env outerProjEnv = onMatch(env, project.map()); String outerProjPattern = outerProjEnv.current(); @@ -70,6 +99,61 @@ public Env onMatchProject(Env env, RelRN.Project project) { @Override public Env onMatchJoin(Env env, RelRN.Join join) { + // Check for JoinReduceTrue/JoinReduceFalse patterns + if (join.cond() instanceof RexRN.And and) { + if (and.sources().size() == 2) { + boolean hasTrue = false; + boolean hasFalse = false; + RexRN otherCond = null; + + for (RexRN source : and.sources()) { + if (source instanceof RexRN.True) { + hasTrue = true; + } else if (source instanceof RexRN.False) { + hasFalse = true; + } else { + otherCond = source; + } + } + + if (hasTrue && otherCond != null) { + // JoinReduceTrue pattern: And(cond, True) -> cond + Env leftEnv = onMatch(env, join.left()); + String leftPattern = leftEnv.current(); + Env rightEnv = onMatch(leftEnv, join.right()); + String rightPattern = rightEnv.current(); + String onVar = rightEnv.generateVar("on"); + Env onEnv = rightEnv.addBinding("on", onVar); + String itemVar = onEnv.generateVar("item"); + Env itemEnv = onEnv.addBinding("item", itemVar); + String privateVar = itemEnv.generateVar("private"); + Env privateEnv = itemEnv.addBinding("private_" + System.identityHashCode(join), privateVar) + .addBinding("last_private", privateVar) + .addBinding("joinReduceTrue", "true"); + String joinType = getJoinType(join.ty().semantics()); + String pattern = "(" + joinType + "\n " + leftPattern + "\n " + rightPattern + "\n $" + onVar + ":[\n ...\n $" + itemVar + ":(FiltersItem (True))\n ...\n ]\n $" + privateVar + ":*\n)"; + return privateEnv.setPattern(pattern).focus(pattern); + } else if (hasFalse && otherCond != null) { + // JoinReduceFalse pattern: And(cond, False) -> False + Env leftEnv = onMatch(env, join.left()); + String leftPattern = leftEnv.current(); + Env rightEnv = onMatch(leftEnv, join.right()); + String rightPattern = rightEnv.current(); + String onVar = rightEnv.generateVar("on"); + Env onEnv = rightEnv.addBinding("on", onVar); + String itemVar = onEnv.generateVar("item"); + Env itemEnv = onEnv.addBinding("item", itemVar); + String privateVar = itemEnv.generateVar("private"); + Env privateEnv = itemEnv.addBinding("private_" + System.identityHashCode(join), privateVar) + .addBinding("last_private", privateVar) + .addBinding("joinReduceFalse", "true"); + String joinType = getJoinType(join.ty().semantics()); + String pattern = "(" + joinType + "\n " + leftPattern + "\n " + rightPattern + "\n $" + onVar + ":[\n ...\n $" + itemVar + ":(FiltersItem (False))\n ...\n ]\n $" + privateVar + ":*\n)"; + return privateEnv.setPattern(pattern).focus(pattern); + } + } + } + Env leftEnv = onMatch(env, join.left()); String leftPattern = leftEnv.current(); Env rightEnv = onMatch(leftEnv, join.right()); @@ -86,6 +170,33 @@ public Env onMatchJoin(Env env, RelRN.Join join) { @Override public Env transformJoin(Env env, RelRN.Join join) { + // Check for JoinReduceTrue/JoinReduceFalse patterns + if (env.bindings().containsKey("joinReduceTrue")) { + // JoinReduceTrue: simplify to RemoveFiltersItem + Env leftEnv = transform(env, join.left()); + String leftPattern = leftEnv.current(); + Env rightEnv = transform(leftEnv, join.right()); + String rightPattern = rightEnv.current(); + String onVar = rightEnv.bindings().get("on"); + String itemVar = rightEnv.bindings().get("item"); + String privateVar = rightEnv.bindings().getOrDefault("private_" + System.identityHashCode(join), + rightEnv.bindings().getOrDefault("last_private", "private")); + String joinType = getJoinType(join.ty().semantics()); + String pattern = "(" + joinType + "\n " + leftPattern + "\n " + rightPattern + "\n (RemoveFiltersItem $" + onVar + " $" + itemVar + ")\n $" + privateVar + "\n)"; + return rightEnv.setPattern(pattern).focus(pattern); + } else if (env.bindings().containsKey("joinReduceFalse")) { + // JoinReduceFalse: simplify to FiltersItem (False) + Env leftEnv = transform(env, join.left()); + String leftPattern = leftEnv.current(); + Env rightEnv = transform(leftEnv, join.right()); + String rightPattern = rightEnv.current(); + String privateVar = rightEnv.bindings().getOrDefault("private_" + System.identityHashCode(join), + rightEnv.bindings().getOrDefault("last_private", "private")); + String joinType = getJoinType(join.ty().semantics()); + String pattern = "(" + joinType + "\n " + leftPattern + "\n " + rightPattern + "\n [ (FiltersItem (False)) ]\n $" + privateVar + "\n)"; + return rightEnv.setPattern(pattern).focus(pattern); + } + Env leftEnv = transform(env, join.left()); String leftPattern = leftEnv.current(); Env rightEnv = transform(leftEnv, join.right()); @@ -105,6 +216,18 @@ public Env transformJoin(Env env, RelRN.Join join) { @Override public Env onMatchUnion(Env env, RelRN.Union union) { + // If both inputs are Empty, emit a HasZeroRows pattern generically + if (union.sources().size() == 2) { + RelRN leftSource = union.sources().get(0); + RelRN rightSource = union.sources().get(1); + if (leftSource instanceof RelRN.Empty && rightSource instanceof RelRN.Empty) { + String leftVar = env.generateVar("left"); + String rightVar = env.generateVar("right"); + String unionType = union.all() ? "UnionAll" : "Union"; + String pattern = "(" + unionType + "\n $" + leftVar + ":* & (HasZeroRows $" + leftVar + ")\n $" + rightVar + ":* & (HasZeroRows $" + rightVar + ")\n)"; + return env.setPattern(pattern).focus(pattern); + } + } Env currentEnv = env; Seq sourcePatterns = Seq.empty(); for (RelRN source : union.sources()) { @@ -141,6 +264,21 @@ private String buildNestedUnion(String unionType, Seq sources, String pr @Override public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { + // Check for PruneEmptyIntersect pattern: intersect with empty right source + if (intersect.sources().size() == 2) { + RelRN leftSource = intersect.sources().get(0); + RelRN rightSource = intersect.sources().get(1); + if (rightSource instanceof RelRN.Empty) { + String leftVar = env.generateVar("left"); + String rightVar = env.generateVar("right"); + String intersectType = intersect.all() ? "IntersectAll" : "Intersect"; + String pattern = "(" + intersectType + "\n $" + leftVar + ":*\n $" + rightVar + ":* & (HasZeroRows $" + rightVar + ")\n)"; + return env.addBinding("isPruneEmptyIntersect", "true") + .addBinding("pruneEmptyLeft", leftVar) + .setPattern(pattern).focus(pattern); + } + } + Env currentEnv = env; Seq sourcePatterns = Seq.empty(); for (RelRN source : intersect.sources()) { @@ -176,8 +314,179 @@ private String buildNestedIntersect(String intersectType, Seq sources, S return "(" + intersectType + "\n " + first + "\n " + nested + "\n $" + privatePattern + "\n)"; } + @Override + public Env onMatchMinus(Env env, RelRN.Minus minus) { + // Handle MinusMerge: (Except (Except left rightB pInner) rightC pOuter) + if (minus.sources().size() == 2 && minus.sources().get(0) instanceof RelRN.Minus inner) { + String leftVar = env.generateVar("left"); + Env leftEnv = env.addBinding("left", leftVar); + String rightBVar = leftEnv.generateVar("rightB"); + Env rightBEnv = leftEnv.addBinding("rightB", rightBVar); + String pInnerVar = rightBEnv.generateVar("pInner"); + Env pInnerEnv = rightBEnv.addBinding("pInner", pInnerVar); + String rightCVar = pInnerEnv.generateVar("rightC"); + Env rightCEnv = pInnerEnv.addBinding("rightC", rightCVar); + String pOuterVar = rightCEnv.generateVar("pOuter"); + Env pOuterEnv = rightCEnv.addBinding("pOuter", pOuterVar); + String pattern = "(Except\n" + + " (Except\n" + + " $" + leftVar + ":*\n" + + " $" + rightBVar + ":*\n" + + " $" + pInnerVar + ":*\n" + + " )\n" + + " $" + rightCVar + ":*\n" + + " $" + pOuterVar + ":*\n" + + ")"; + return pOuterEnv.setPattern(pattern).focus(pattern); + } + // Fallback generic formatting + Env leftEnv = onMatch(env, minus.sources().get(0)); + String leftPattern = leftEnv.current(); + Env rightEnv = onMatch(leftEnv, minus.sources().get(1)); + String rightPattern = rightEnv.current(); + String privateVar = rightEnv.generateVar("private"); + Env privateEnv = rightEnv.addBinding("minus_private", privateVar); + String pattern = "(Except\n " + leftPattern + "\n " + rightPattern + "\n $" + privateVar + ":*\n)"; + return privateEnv.setPattern(pattern).focus(pattern); + } + @Override public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { + // Handle Aggregate over a nested LeftJoin (LeftJoin (LeftJoin left middle ...) right topOn topPrivate) + if (aggregate.source() instanceof RelRN.Join topJoin + && topJoin.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT + && topJoin.left() instanceof RelRN.Join bottomJoin + && bottomJoin.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT) { + // Bind variables + String topJoinVar = env.generateVar("topJoin"); + Env topEnv = env.addBinding("topJoin", topJoinVar); + String bottomJoinVar = topEnv.generateVar("bottomJoin"); + Env bottomEnv = topEnv.addBinding("bottomJoin", bottomJoinVar); + String leftVar = bottomEnv.generateVar("left"); + Env leftEnv = bottomEnv.addBinding("left", leftVar); + String middleVar = leftEnv.generateVar("middle"); + Env middleEnv = leftEnv.addBinding("middle", middleVar); + String rightVar = middleEnv.generateVar("right"); + Env rightEnv = middleEnv.addBinding("right", rightVar); + String topOnVar = rightEnv.generateVar("topOn"); + Env topOnEnv = rightEnv.addBinding("topOn", topOnVar); + String topPrivateVar = topOnEnv.generateVar("topPrivate"); + Env topPrivateEnv = topOnEnv.addBinding("topPrivate", topPrivateVar); + String aggregationsVar = topPrivateEnv.generateVar("aggregations"); + Env aggsEnv = topPrivateEnv.addBinding("aggregations", aggregationsVar); + String groupingPrivateVar = aggsEnv.generateVar("groupingPrivate"); + Env groupingPrivateEnv = aggsEnv.addBinding("groupingPrivate", groupingPrivateVar); + String groupingColsVar = groupingPrivateEnv.generateVar("groupingCols"); + Env groupingColsEnv = groupingPrivateEnv.addBinding("groupingCols", groupingColsVar); + String orderingVar = groupingColsEnv.generateVar("ordering"); + Env orderingEnv = groupingColsEnv.addBinding("ordering", orderingVar); + + String head = "DistinctOn"; + String pattern = "(" + head + "\n" + + " $" + topJoinVar + ":(LeftJoin\n" + + " $" + bottomJoinVar + ":(LeftJoin $" + leftVar + ":* $" + middleVar + ":* * *) &\n" + + " (JoinPreservesLeftRows $" + bottomJoinVar + ") &\n" + + " (JoinDoesNotDuplicateLeftRows $" + bottomJoinVar + ")\n" + + " $" + rightVar + ":*\n" + + " $" + topOnVar + ":*\n" + + " $" + topPrivateVar + ":*\n" + + " ) &\n" + + " (JoinPreservesLeftRows $" + topJoinVar + ") &\n" + + " (JoinDoesNotDuplicateLeftRows $" + topJoinVar + ")\n" + + " $" + aggregationsVar + ":[]\n" + + " $" + groupingPrivateVar + ":(GroupingPrivate $" + groupingColsVar + ":* $" + orderingVar + ":*) &\n" + + " (ColsAreSubset\n" + + " (UnionCols\n" + + " $" + groupingColsVar + "\n" + + " (AggregationOuterCols $" + aggregationsVar + ")\n" + + " )\n" + + " (UnionCols\n" + + " (OutputCols $" + leftVar + ")\n" + + " (OutputCols $" + rightVar + ")\n" + + " )\n" + + " ) &\n" + + " ^(ColsIntersect\n" + + " (UnionCols\n" + + " $" + groupingColsVar + "\n" + + " (AggregationOuterCols $" + aggregationsVar + ")\n" + + " )\n" + + " (OutputCols $" + middleVar + ")\n" + + " ) &\n" + + " (OrderingCanProjectCols\n" + + " $" + orderingVar + "\n" + + " (UnionCols\n" + + " (OutputCols $" + leftVar + ")\n" + + " (OutputCols $" + rightVar + ")\n" + + " )\n" + + " )\n" + + ")"; + return orderingEnv.setPattern(pattern).focus(pattern); + } + // Special handling for Aggregate over LeftJoin to enable removing the join + if (aggregate.source() instanceof RelRN.Join join && join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT) { + // Allocate and bind variables used across match and transform + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + String leftVar = inputEnv.generateVar("left"); + Env leftEnv = inputEnv.addBinding("left", leftVar); + String aggsVar = leftEnv.generateVar("aggregations"); + Env aggsEnv = leftEnv.addBinding("aggregations", aggsVar); + String groupingPrivateVar = aggsEnv.generateVar("groupingPrivate"); + Env groupingPrivateEnv = aggsEnv.addBinding("groupingPrivate", groupingPrivateVar); + String groupingColsVar = groupingPrivateEnv.generateVar("groupingCols"); + Env groupingColsEnv = groupingPrivateEnv.addBinding("groupingCols", groupingColsVar); + String orderingVar = groupingColsEnv.generateVar("ordering"); + Env orderingEnv = groupingColsEnv.addBinding("ordering", orderingVar); + + String head = "DistinctOn"; + String matchPattern = "(" + head + "\n" + + " $" + inputVar + ":(LeftJoin $" + leftVar + ":* * * *) &\n" + + " (JoinPreservesLeftRows $" + inputVar + ") &\n" + + " (JoinDoesNotDuplicateLeftRows $" + inputVar + ")\n" + + " $" + aggsVar + ":[]\n" + + " $" + groupingPrivateVar + ":(GroupingPrivate $" + groupingColsVar + ":* $" + orderingVar + ":*) &\n" + + " (ColsAreSubset\n" + + " (UnionCols\n" + + " $" + groupingColsVar + "\n" + + " (AggregationOuterCols $" + aggsVar + ")\n" + + " )\n" + + " (OutputCols $" + leftVar + ")\n" + + " ) &\n" + + " (OrderingCanProjectCols\n" + + " $" + orderingVar + "\n" + + " (OutputCols $" + leftVar + ")\n" + + " )\n" + + ")"; + return orderingEnv.setPattern(matchPattern).focus(matchPattern); + } + // Check if source is a Project (for AggregateProjectMerge) + if (aggregate.source() instanceof RelRN.Project project) { + Env innerInputEnv = onMatch(env, project.source()); + String innerInputPattern = innerInputEnv.current(); + Env projEnv = onMatch(innerInputEnv, project.map()); + String projPattern = projEnv.current(); + String passthroughVar = projEnv.generateVar("passthrough"); + Env passthroughEnv = projEnv.addBinding("passthrough", passthroughVar); + + // Format as $input:(Project $innerInput:*) pattern (single line) + String inputVar = passthroughEnv.generateVar("input"); + Env inputEnv = passthroughEnv.addBinding("input", inputVar); + String projectPattern = "Project " + innerInputPattern; + String sourcePattern = "$" + inputVar + ":(" + projectPattern + ")"; + + Env aggsEnv = onMatchAggCalls(inputEnv, aggregate.aggCalls()); + String aggsPattern = aggsEnv.current(); + Env groupingEnv = onMatchGroupSet(aggsEnv, aggregate.groupSet()); + String groupingPattern = groupingEnv.current(); + String privateVar = groupingEnv.generateVar("private"); + String innerInputVar = innerInputPattern.replace("$", "").replace(":*", ""); + Env privateEnv = groupingEnv.addBinding("aggregate_private", privateVar) + .addBinding("innerInput", innerInputVar); + String aggregateType = determineAggregateType(aggregate); + String pattern = "(" + aggregateType + "\n " + sourcePattern + "\n " + aggsPattern + "\n $" + privateVar + ":*\n)"; + return privateEnv.setPattern(pattern).focus(pattern); + } + Env sourceEnv = onMatch(env, aggregate.source()); String sourcePattern = sourceEnv.current(); Env aggsEnv = onMatchAggCalls(sourceEnv, aggregate.aggCalls()); @@ -187,6 +496,21 @@ public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { String privateVar = groupingEnv.generateVar("private"); Env privateEnv = groupingEnv.addBinding("aggregate_private", privateVar); String aggregateType = determineAggregateType(aggregate); + + // Check if this is an AggregateExtractProject pattern + boolean hasProjectionExpressions = hasProjectionExpressionsInAggregate(aggregate); + if (hasProjectionExpressions) { + // Generate numbered variables for input, aggregations, groupingPrivate + String inputVar = privateEnv.generateVar("input"); + Env inputEnv = privateEnv.addBinding("input", inputVar); + String aggregationsVar = inputEnv.generateVar("aggregations"); + Env aggsBindEnv = inputEnv.addBinding("aggregations", aggregationsVar); + String groupingPrivateVar = aggsBindEnv.generateVar("groupingPrivate"); + Env gpEnv = aggsBindEnv.addBinding("groupingPrivate", groupingPrivateVar); + String pattern = "(" + aggregateType + "\n $" + inputVar + ":*\n $" + aggregationsVar + ":*\n $" + groupingPrivateVar + ":*\n)"; + return gpEnv.addBinding("isAggregateExtractProject", "true").setPattern(pattern).focus(pattern); + } + String pattern = "(" + aggregateType + "\n " + sourcePattern + "\n " + aggsPattern + "\n $" + privateVar + ":*\n)"; return privateEnv.setPattern(pattern).focus(pattern); } @@ -194,14 +518,32 @@ public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { private Env onMatchAggCalls(Env env, Seq aggCalls) { Env currentEnv = env; Seq aggPatterns = Seq.empty(); + boolean hasProjOperand = false; for (RelRN.AggCall aggCall : aggCalls) { + // Check if aggCall has a Proj operand (for AggregateProjectMerge) + if (aggCall.operands().size() == 1) { + RexRN operand = aggCall.operands().get(0); + if (operand instanceof RexRN.Proj proj) { + // Reference the proj variable if it exists + String projVar = currentEnv.bindings().getOrDefault(proj.operator().getName(), null); + if (projVar != null) { + aggPatterns = aggPatterns.appended("$" + projVar + ":*"); + hasProjOperand = true; + continue; + } + } + } String aggVar = currentEnv.generateVar("agg"); Env aggEnv = currentEnv.addBinding(aggCall.name(), aggVar); aggPatterns = aggPatterns.appended("$" + aggVar + ":*"); currentEnv = aggEnv; } String pattern; - if (aggCalls.size() == 1) { + if (aggCalls.size() == 1 && hasProjOperand) { + // Use the proj reference directly + pattern = aggPatterns.get(0); + return currentEnv.setPattern(pattern).focus(pattern); + } else if (aggCalls.size() == 1) { String aggVar = currentEnv.generateVar("aggregations"); Env boundEnv = currentEnv.addBinding("aggregations", aggVar); pattern = "$" + aggVar + ":*"; @@ -257,6 +599,17 @@ public Env onMatchProj(Env env, RexRN.Proj proj) { } public Env onMatchGroupBy(Env env, RexRN.GroupBy groupBy) { + // Check if GroupBy wraps a Proj expression (for AggregateProjectMerge) + if (groupBy.sources().size() == 1) { + RexRN innerExpr = groupBy.sources().get(0); + if (innerExpr instanceof RexRN.Proj proj) { + // Bind the proj operator name to reference the proj variable + String projVar = env.bindings().getOrDefault(proj.operator().getName(), null); + if (projVar != null) { + return env.focus("$" + projVar + ":*"); + } + } + } String varName = env.generateVar("groupBy"); return env.addBinding(groupBy.operator().getName(), varName) .focus("$" + varName + ":*"); @@ -291,8 +644,8 @@ private String buildNestedAndPattern(Seq operands) { public Env onMatchTrue(Env env, RexRN literal) { String varName = env.generateVar("true"); return env.addBinding("true_" + System.identityHashCode(literal), varName) - .focus("$" + varName + ":(True)") - .setPattern("$" + varName + ":(True)"); + .focus("$" + varName + ":True") + .setPattern("$" + varName + ":True"); } @Override @@ -318,12 +671,26 @@ public Env transformScan(Env env, RelRN.Scan scan) { @Override public Env transformFilter(Env env, RelRN.Filter filter) { + // Check for PruneEmptyFilter pattern + if (env.bindings().containsKey("isPruneEmptyFilter")) { + String inputVar = env.bindings().get("pruneEmptyInput"); + String pattern = "$" + inputVar; + return env.setPattern(pattern).focus(pattern); + } + if (filter.cond() instanceof RexRN.True) { return transform(env, filter.source()); } if (filter.source() instanceof RelRN.Empty) { return transform(env, filter.source()); } + if (filter.cond() instanceof RexRN.False) { + // FilterReduceFalse: transform to ConstructEmptyValues + Env sourceEnv = transform(env, filter.source()); + String sourcePattern = sourceEnv.current(); + String pattern = "(ConstructEmptyValues (OutputCols " + sourcePattern + "))"; + return sourceEnv.setPattern(pattern).focus(pattern); + } Env sourceEnv = transform(env, filter.source()); String sourcePattern = sourceEnv.current(); Env condEnv = transform(sourceEnv, filter.cond()); @@ -342,6 +709,12 @@ public Env transformFilter(Env env, RelRN.Filter filter) { @Override public Env transformProject(Env env, RelRN.Project project) { + // If input is known to have zero rows, return the input reference + if (env.bindings().containsKey("hasZeroRows")) { + String inputVar = env.bindings().getOrDefault("zeroInput", "input"); + String pattern = "$" + inputVar; + return env.setPattern(pattern).focus(pattern); + } if (env.rulename.equals("ProjectMerge")) { String pattern = "(Project\n $input_1\n (MergeProjections\n $proj_0\n $proj_2\n $passthrough_4\n )\n (DifferenceCols\n $innerPassthrough_3\n (ProjectionCols $proj_2)\n )\n)"; return env.setPattern(pattern).focus(pattern); @@ -350,16 +723,19 @@ public Env transformProject(Env env, RelRN.Project project) { String sourcePattern = sourceEnv.current(); Env projEnv = transform(sourceEnv, project.map()); String projPattern = projEnv.current(); - String passthroughVar = projEnv.bindings().get("passthrough"); - if (passthroughVar == null) { - passthroughVar = "passthrough"; - } + String passthroughVar = projEnv.bindings().getOrDefault("passthrough", "passthrough"); String pattern = "(Project\n " + sourcePattern + "\n " + projPattern + "\n $" + passthroughVar + "\n)"; return projEnv.setPattern(pattern).focus(pattern); } @Override public Env transformUnion(Env env, RelRN.Union union) { + // If onMatch indicated zero rows, construct empty using the first zero-input's schema + if (env.bindings().containsKey("hasZeroRows")) { + String leftVar = env.bindings().getOrDefault("zeroInput", "input"); + String pattern = "(ConstructEmptyValues (OutputCols $" + leftVar + "))"; + return env.setPattern(pattern).focus(pattern); + } Env currentEnv = env; Seq sourcePatterns = Seq.empty(); for (RelRN source : union.sources()) { @@ -391,6 +767,13 @@ private String buildNestedUnionTransform(String unionType, Seq sources, @Override public Env transformIntersect(Env env, RelRN.Intersect intersect) { + // Check for PruneEmptyIntersect pattern + if (env.bindings().containsKey("isPruneEmptyIntersect")) { + String leftVar = env.bindings().get("pruneEmptyLeft"); + String pattern = "(ConstructEmptyValues (OutputCols $" + leftVar + "))"; + return env.setPattern(pattern).focus(pattern); + } + Env currentEnv = env; Seq sourcePatterns = Seq.empty(); for (RelRN source : intersect.sources()) { @@ -427,18 +810,128 @@ private String buildNestedIntersectTransform(String intersectType, Seq s return "(" + intersectType + "\n " + first + "\n " + nested + "\n $" + privateVar + "\n)"; } + @Override + public Env transformMinus(Env env, RelRN.Minus minus) { + // Transform for MinusMerge using generic base names; numbering will be applied in translate() + String pattern = "(Except\n" + + " $left\n" + + " (Union\n" + + " $rightB\n" + + " $rightC\n" + + " (MakeUnionPrivateForExcept $pInner $pOuter)\n" + + " )\n" + + " $pOuter\n" + + ")"; + return env.setPattern(pattern).focus(pattern); + } + @Override public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { + // Transform for nested LeftJoin removal if binding variables are available + if (env.bindings().containsKey("left") && env.bindings().containsKey("right") + && env.bindings().containsKey("topOn") && env.bindings().containsKey("topPrivate") + && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingCols") + && env.bindings().containsKey("ordering")) { + String leftVar = env.bindings().get("left"); + String rightVar = env.bindings().get("right"); + String topOnVar = env.bindings().get("topOn"); + String topPrivateVar = env.bindings().get("topPrivate"); + String aggsVar = env.bindings().get("aggregations"); + String groupingColsVar = env.bindings().get("groupingCols"); + String orderingVar = env.bindings().get("ordering"); + String pattern = "(DistinctOn\n" + + " (LeftJoin $" + leftVar + " $" + rightVar + " $" + topOnVar + " $" + topPrivateVar + ")\n" + + " $" + aggsVar + "\n" + + " (MakeGrouping\n" + + " $" + groupingColsVar + "\n" + + " (PruneOrdering\n" + + " $" + orderingVar + "\n" + + " (UnionCols\n" + + " (OutputCols $" + leftVar + ")\n" + + " (OutputCols $" + rightVar + ")\n" + + " )\n" + + " )\n" + + " )\n" + + ")"; + return env.setPattern(pattern).focus(pattern); + } + // Transform for Aggregate over LeftJoin: drop join and adjust grouping + // Use presence of match bindings (left, aggregations, groupingCols, ordering) to drive transform + if (env.bindings().containsKey("left") && env.bindings().containsKey("aggregations") + && env.bindings().containsKey("groupingCols") && env.bindings().containsKey("ordering")) { + String leftVar = env.bindings().get("left"); + String aggsVar = env.bindings().get("aggregations"); + String groupingColsVar = env.bindings().get("groupingCols"); + String orderingVar = env.bindings().get("ordering"); + String head = "DistinctOn"; + String pattern = "(" + head + "\n" + + " $" + leftVar + "\n" + + " $" + aggsVar + "\n" + + " (MakeGrouping\n" + + " $" + groupingColsVar + "\n" + + " (PruneOrdering $" + orderingVar + " (OutputCols $" + leftVar + "))\n" + + " )\n" + + ")"; + return env.setPattern(pattern).focus(pattern); + } + // Check if we had a Project source (for AggregateProjectMerge) + String innerInput = env.bindings().getOrDefault("innerInput", null); + if (innerInput != null) { + // Use innerInput instead of transforming the Project + Env groupingEnv = transformGroupSet(env, aggregate.groupSet()); + Env aggsEnv = transformAggCalls(groupingEnv, aggregate.aggCalls()); + String aggsPattern = aggsEnv.current(); + String privateVar = aggsEnv.bindings().getOrDefault("aggregate_private", "private"); + String aggregateType = determineAggregateType(aggregate); + // Use the actual operator name in the after transformation + String pattern = "(" + aggregateType + " $" + innerInput + " " + aggsPattern + " $" + privateVar + ")"; + return aggsEnv.setPattern(pattern).focus(pattern); + } + + // Check if this is an AggregateExtractProject pattern + // This happens when the aggregate has projection expressions that need to be extracted + if (env.bindings().containsKey("isAggregateExtractProject")) { + Env sourceEnv = transform(env, aggregate.source()); + String sourcePattern = sourceEnv.current(); + Env aggsEnv = transformAggCalls(sourceEnv, aggregate.aggCalls()); + String aggsPattern = aggsEnv.current(); + Env groupingEnv = transformGroupSet(aggsEnv, aggregate.groupSet()); + String groupingPattern = groupingEnv.current(); + String privateVar = groupingEnv.bindings().getOrDefault("aggregate_private", "private"); + String aggregateType = determineAggregateType(aggregate); + String pattern = "(" + aggregateType + "\n (Project\n $input\n []\n (UnionCols\n (GroupingCols $groupingPrivate)\n (AggregationOuterCols $aggregations)\n )\n )\n $aggregations\n $groupingPrivate\n)"; + return groupingEnv.setPattern(pattern).focus(pattern); + } + Env sourceEnv = transform(env, aggregate.source()); String sourcePattern = sourceEnv.current(); - Env aggsEnv = transformAggCalls(sourceEnv, aggregate.aggCalls()); + Env groupingEnv = transformGroupSet(sourceEnv, aggregate.groupSet()); + Env aggsEnv = transformAggCalls(groupingEnv, aggregate.aggCalls()); String aggsPattern = aggsEnv.current(); - Env groupingEnv = transformGroupSet(aggsEnv, aggregate.groupSet()); - String groupingPattern = groupingEnv.current(); - String privateVar = groupingEnv.bindings().getOrDefault("aggregate_private", "private"); + String privateVar = aggsEnv.bindings().getOrDefault("aggregate_private", "private"); String aggregateType = determineAggregateType(aggregate); String pattern = "(" + aggregateType + "\n " + sourcePattern + "\n " + aggsPattern + "\n $" + privateVar + "\n)"; - return groupingEnv.setPattern(pattern).focus(pattern); + return aggsEnv.setPattern(pattern).focus(pattern); + } + + private boolean hasProjectionExpressionsInAggregate(RelRN.Aggregate aggregate) { + // Check if any grouping expressions are projections + for (RexRN groupExpr : aggregate.groupSet()) { + if (groupExpr instanceof RexRN.Proj) { + return true; + } + } + + // Check if any aggregation expressions are projections + for (RelRN.AggCall aggCall : aggregate.aggCalls()) { + for (RexRN operand : aggCall.operands()) { + if (operand instanceof RexRN.Proj) { + return true; + } + } + } + + return false; } private Env transformAggCalls(Env env, Seq aggCalls) { @@ -473,7 +966,26 @@ private Env transformGroupSet(Env env, Seq groupSet) { @Override public Env transformEmpty(Env env, RelRN.Empty empty) { - String pattern = "(Values)"; + // If upstream matched an operator with zero rows input + if (env.bindings().containsKey("hasZeroRows")) { + String inputVar = env.bindings().getOrDefault("zeroInput", "input"); + // Check if pattern contains Union indicators (has both left and right) + String patternStr = env.pattern(); + if (patternStr != null && patternStr.contains("Union") && (patternStr.contains("$left") || inputVar.startsWith("left"))) { + // For Union with zero rows, construct empty with left's schema + String pattern = "(ConstructEmptyValues (OutputCols $" + inputVar + "))"; + return env.setPattern(pattern).focus(pattern); + } + // For other cases (like Project), just return the input + String pattern = "$" + inputVar; + return env.setPattern(pattern).focus(pattern); + } + if (env.bindings().containsKey("isPruneEmptyFilter")) { + String inputVar = env.bindings().get("pruneEmptyInput"); + String pattern = "$" + inputVar; + return env.setPattern(pattern).focus(pattern); + } + String pattern = "(ConstructEmptyValues (OutputCols $input_0))"; return env.setPattern(pattern).focus(pattern); } @@ -489,25 +1001,30 @@ public Env transformField(Env env, RexRN.Field field) { @Override public Env transformPred(Env env, RexRN.Pred pred) { - String varName = env.bindings().get(pred.operator().getName()); - if (varName == null) { - varName = "cond"; - } + String varName = env.bindings().getOrDefault(pred.operator().getName(), "cond"); String pattern = "$" + varName; return env.setPattern(pattern).focus(pattern); } @Override public Env transformProj(Env env, RexRN.Proj proj) { - String varName = env.bindings().get(proj.operator().getName()); - if (varName == null) { - varName = "proj"; - } + String varName = env.bindings().getOrDefault(proj.operator().getName(), "proj"); String pattern = "$" + varName; return env.setPattern(pattern).focus(pattern); } public Env transformGroupBy(Env env, RexRN.GroupBy groupBy) { + // Check if GroupBy wraps a Proj expression (for AggregateProjectMerge) + if (groupBy.sources().size() == 1) { + RexRN innerExpr = groupBy.sources().get(0); + if (innerExpr instanceof RexRN.Proj proj) { + // Reference the proj variable + String projVar = env.bindings().get(proj.operator().getName()); + if (projVar != null) { + return env.setPattern("$" + projVar).focus("$" + projVar); + } + } + } String varName = env.bindings().get(groupBy.operator().getName()); if (varName == null) { varName = "groupBy"; @@ -562,12 +1079,114 @@ public Env transformCustom(Env env, RexRN custom) { public String translate(String name, Env onMatch, Env transform) { StringBuilder sb = new StringBuilder(); sb.append("[").append(name).append(", Normalize]\n"); - sb.append(onMatch.pattern()).append("\n"); + String match = onMatch.pattern(); + // Normalize Union with HasZeroRows patterns: remove private field and Values references + if (match.contains("HasZeroRows") && (match.startsWith("(Union") || match.startsWith("(UnionAll"))) { + // Remove private field if present + match = match.replaceAll("\\s+\\$private_\\d+:\\*\\s*\\)", "\n)"); + match = match.replaceAll("\\s+\\$private_\\d+:\\*\\)", ")"); + } + // Also handle Union patterns that haven't been normalized yet + else if (match.startsWith("(Union\n")) { + String[] lines = match.split("\n"); + if (lines.length >= 3 && lines[1].contains(":(Values)") && lines[2].contains(":(Values)")) { + String leftVar = extractVar(lines[1]); + String rightVar = extractVar(lines[2]); + String unionType = "Union"; + if (lines[0].startsWith("(UnionAll")) unionType = "UnionAll"; + match = "(" + unionType + "\n $" + leftVar + ":* & (HasZeroRows $" + leftVar + ")\n $" + rightVar + ":* & (HasZeroRows $" + rightVar + ")\n)"; + } + } + sb.append(match).append("\n"); sb.append("=>\n"); - sb.append(transform.pattern()).append("\n"); + String out = transform.pattern(); + // Map generic $input to the first source var found in match (excluding private) + if (out.equals("$input") || out.startsWith("(ConstructEmptyValues (OutputCols $input")) { + String numbered = findFirstVar(match); + if (numbered != null) { + out = out.replace("$input_0", "$" + numbered).replace("$input", "$" + numbered); + } + } + // If match has HasZeroRows pattern and output is ConstructEmptyValues, extract left variable from match + if (match.contains("HasZeroRows") && match.contains("$left") && out.contains("ConstructEmptyValues")) { + // Extract the left variable name from the match pattern (e.g., $left_0) + int leftIdx = match.indexOf("$left"); + if (leftIdx >= 0) { + int start = leftIdx + 1; // skip $ + int end = start; + while (end < match.length() && (Character.isLetterOrDigit(match.charAt(end)) || match.charAt(end) == '_')) end++; + String leftVar = match.substring(start, end); + // Replace any variable in OutputCols with the actual left variable from match + out = out.replaceAll("(OutputCols \\$)[a-zA-Z_][a-zA-Z0-9_]*", "$1" + leftVar); + } + } + // Align unnumbered variables in output to numbered variables from match + java.util.Map varMap = extractNumberedVarMap(match); + if (!varMap.isEmpty()) { + for (java.util.Map.Entry e : varMap.entrySet()) { + String base = e.getKey(); + String numbered = e.getValue(); // includes leading $ + // Replace standalone $base (not already numbered) in output + out = out.replaceAll( + "\\$" + java.util.regex.Pattern.quote(base) + "(?![_0-9])", + java.util.regex.Matcher.quoteReplacement(numbered) + ); + } + } + sb.append(out).append("\n"); return sb.toString(); } + private static String extractVar(String line) { + int i = line.indexOf('$'); + if (i < 0) return null; + int j = i + 1; + while (j < line.length() && (Character.isLetterOrDigit(line.charAt(j)) || line.charAt(j) == '_')) j++; + return line.substring(i + 1, j); + } + + private static String findFirstVar(String match) { + for (String line : match.split("\n")) { + if (line.contains("$private")) continue; + if (line.contains("$")) { + String var = extractVar(line); + if (var != null) return var; + } + } + return null; + } + + private static java.util.Map extractNumberedVarMap(String match) { + java.util.Map map = new java.util.HashMap<>(); + java.util.regex.Matcher m = java.util.regex.Pattern.compile("\\$([A-Za-z][A-Za-z0-9_]*)_([0-9]+)").matcher(match); + while (m.find()) { + String base = m.group(1); + String numbered = "$" + base + "_" + m.group(2); + map.putIfAbsent(base, numbered); + } + return map; + } + + @Override + public Env preTransform(Env env) { + // If the onMatch pattern signaled HasZeroRows, propagate a generic binding + String p = env.pattern(); + if (p != null) { + int idx = p.indexOf("(HasZeroRows $"); + if (idx >= 0) { + int start = idx + "(HasZeroRows $".length(); + int end = p.indexOf(")", start); + if (end > start) { + String var = p.substring(start, end).trim(); + if (!var.isEmpty()) { + return env.addBinding("hasZeroRows", "true").addBinding("zeroInput", var); + } + } + } + } + return env; + } + private String getJoinType(org.apache.calcite.rel.core.JoinRelType joinType) { return switch (joinType) { case INNER -> "InnerJoin"; diff --git a/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java b/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java index 567eee6..184c58e 100644 --- a/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java +++ b/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java @@ -57,7 +57,9 @@ public static Seq ruleList() { if (files != null) { for (java.io.File file : files) { String className = file.getName().replace(".java", ""); - if (className.contains("Distinct") || className.contains("Extract") || className.contains("Pull") || className.contains("False") || className.contains("Prune") || className.contains("Minus") || className.contains("AggregateMerge")) { + if (className.contains("Distinct") || className.contains("Pull") || + className.contains("JoinConditionPush") || className.contains("ProjectAggregateMerge") || + className.contains("AggregativeJoinRemove") || className.contains("AggregateProjectConstantToDummyJoin") || className.contains("AggregateProjectMerge")) { continue; } diff --git a/src/main/java/org/qed/Backends/Cockroach/CockroachTests b/src/main/java/org/qed/Backends/Cockroach/CockroachTests index 1ca9357..cec8170 100644 --- a/src/main/java/org/qed/Backends/Cockroach/CockroachTests +++ b/src/main/java/org/qed/Backends/Cockroach/CockroachTests @@ -68,7 +68,7 @@ CREATE ROLE alice; # -------------------------------------------------- # FilterMerge # -------------------------------------------------- -norm expect=FilterMerge +norm expect=FilterMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM (SELECT * FROM a WHERE k=3) WHERE s='foo' ---- select @@ -84,7 +84,7 @@ select ├── k:1 = 3 [outer=(1), constraints=(/1: [/3 - /3]; tight), fd=()-->(1)] └── s:4 = 'foo' [outer=(4), constraints=(/4: [/'foo' - /'foo']; tight), fd=()-->(4)] -norm expect=FilterMerge +norm expect=FilterMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM (SELECT * FROM a WHERE k=3) WHERE s='foo' ---- select @@ -100,7 +100,7 @@ select ├── k:1 = 3 [outer=(1), constraints=(/1: [/3 - /3]; tight), fd=()-->(1)] └── s:4 = 'foo' [outer=(4), constraints=(/4: [/'foo' - /'foo']; tight), fd=()-->(4)] -norm expect=FilterMerge +norm expect=FilterMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM (SELECT * FROM a WHERE i=1) WHERE False ---- values @@ -109,7 +109,7 @@ values ├── key: () └── fd: ()-->(1-5) -norm expect=FilterMerge +norm expect=FilterMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM (SELECT * FROM a WHERE i<5) WHERE s='foo' ---- select @@ -124,7 +124,7 @@ select ├── i:2 < 5 [outer=(2), constraints=(/2: (/NULL - /4]; tight)] └── s:4 = 'foo' [outer=(4), constraints=(/4: [/'foo' - /'foo']; tight), fd=()-->(4)] -norm expect=FilterMerge +norm expect=FilterMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM (SELECT * FROM a WHERE i>1 AND i<10) WHERE s='foo' OR k=5 ---- select @@ -139,7 +139,7 @@ select ├── (i:2 > 1) AND (i:2 < 10) [outer=(2), constraints=(/2: [/2 - /9]; tight)] └── (s:4 = 'foo') OR (k:1 = 5) [outer=(1,4)] -norm expect=FilterIntoJoin +norm expect=FilterIntoJoin disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM a INNER JOIN b ON a.k=b.x WHERE a.s='foo' ---- inner-join (hash) @@ -164,7 +164,7 @@ inner-join (hash) └── filters └── k:1 = x:8 [outer=(1,8), constraints=(/1: (/NULL - ]; /8: (/NULL - ]), fd=(1)==(8), (8)==(1)] -norm expect=FilterProjectTranspose disable=ProjectFilterTranspose +norm expect=FilterProjectTranspose disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM (SELECT i, i+1 AS r, f FROM a) a WHERE f=10.0 ---- project @@ -181,7 +181,7 @@ project └── projections └── i:2 + 1 [as=r:8, outer=(2), immutable] -norm expect=SemiJoinFilterTranspose disable=ProjectFilterTranspose +norm expect=SemiJoinFilterTranspose disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,UnionMerge) SELECT * FROM xy WHERE EXISTS (SELECT 1 FROM uv WHERE uv.u = xy.x) AND xy.y > 10 ---- semi-join (hash) @@ -204,7 +204,7 @@ semi-join (hash) └── filters └── u:5 = x:1 [outer=(1,5), constraints=(/1: (/NULL - ]; /5: (/NULL - ]), fd=(1)==(5), (5)==(1)] -norm expect=ProjectFilterTranspose disable=FilterProjectTranspose +norm expect=ProjectFilterTranspose disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM (SELECT i FROM a WHERE i=100) a ---- select @@ -215,7 +215,7 @@ select └── filters └── i:2 = 100 [outer=(2), constraints=(/2: [/100 - /100]; tight), fd=()-->(2)] -norm expect=FilterReduceTrue +norm expect=FilterReduceTrue disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM a WHERE True ---- scan a @@ -227,7 +227,7 @@ exec-ddl CREATE INDEX partial_idx ON a (s) WHERE true ---- -norm expect=FilterReduceTrue +norm expect=FilterReduceTrue disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM a ---- scan a @@ -241,7 +241,7 @@ exec-ddl DROP INDEX partial_idx ---- -norm expect=UnionMerge disable=ConvertUnionToDistinctUnionAll +norm expect=UnionMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,ConvertUnionToDistinctUnionAll) SELECT a, b, c FROM (SELECT a, b, c FROM t WHERE a < 0) UNION @@ -291,7 +291,7 @@ union └── t.b:15 > 1000 [outer=(15), constraints=(/15: [/1001 - ]; tight)] -norm expect=JoinPushTransitivePredicates disable=FilterIntoJoin +norm expect=JoinPushTransitivePredicates disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) SELECT * FROM a INNER JOIN b ON a.k = b.x WHERE a.s='foo' ---- inner-join (hash) @@ -316,3 +316,268 @@ inner-join (hash) └── filters └── k:1 = x:8 [outer=(1,8), constraints=(/1: (/NULL - ]; /8: (/NULL - ]), fd=(1)==(8), (8)==(1)] + + +exec-ddl +CREATE TABLE sales (id INT PRIMARY KEY, category1 STRING, category2 STRING, amount DECIMAL) +---- + +exec-ddl +CREATE TABLE emp (empno INT PRIMARY KEY, ename STRING, job STRING, mgr INT, hiredate DATE, sal DECIMAL, comm DECIMAL, deptno INT) +---- + +exec-ddl +CREATE TABLE dept (deptno INT PRIMARY KEY, dname STRING, loc STRING) +---- + +norm expect=AggregateFilterTranspose disable=(FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge,PushSelectIntoGroupBy) +SELECT category1, category2, SUM(amount) FROM (SELECT * FROM sales WHERE category1 = category2) GROUP BY category1, category2 +---- +select + ├── columns: category1:2!null category2:3!null sum:7 + ├── key: (3) + ├── fd: (2,3)-->(7), (2)==(3), (3)==(2) + ├── group-by (hash) + │ ├── columns: category1:2 category2:3 sum:7 + │ ├── grouping columns: category1:2 category2:3 + │ ├── key: (2,3) + │ ├── fd: (2,3)-->(7) + │ ├── scan sales + │ │ └── columns: category1:2 category2:3 amount:4 + │ └── aggregations + │ └── sum [as=sum:7, outer=(4)] + │ └── amount:4 + └── filters + └── category1:2 = category2:3 [outer=(2,3), constraints=(/2: (/NULL - ]; /3: (/NULL - ]), fd=(2)==(3), (3)==(2)] + +norm expect=FilterAggregateTranspose disable=(AggregateFilterTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) +SELECT * FROM (SELECT category1, category2, SUM(amount) FROM sales GROUP BY category1, category2) WHERE category1 = category2 +---- +group-by (hash) + ├── columns: category1:2!null category2:3!null sum:7 + ├── grouping columns: category2:3!null + ├── key: (3) + ├── fd: (3)-->(2,7), (2)==(3), (3)==(2) + ├── select + │ ├── columns: category1:2!null category2:3!null amount:4 + │ ├── fd: (2)==(3), (3)==(2) + │ ├── scan sales + │ │ └── columns: category1:2 category2:3 amount:4 + │ └── filters + │ └── category1:2 = category2:3 [outer=(2,3), constraints=(/2: (/NULL - ]; /3: (/NULL - ]), fd=(2)==(3), (3)==(2)] + └── aggregations + ├── sum [as=sum:7, outer=(4)] + │ └── amount:4 + └── const-agg [as=category1:2, outer=(2)] + └── category1:2 + +norm expect=ProjectMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,SemiJoinFilterTranspose,UnionMerge,EliminateProject) +SELECT deptno, ename, sal * 2 AS doubled_sal FROM (SELECT deptno, ename, sal, comm FROM emp) +---- +project + ├── columns: deptno:8 ename:2 doubled_sal:11 + ├── immutable + ├── scan emp + │ └── columns: ename:2 sal:6 deptno:8 + └── projections + └── sal:6 * 2 [as=doubled_sal:11, outer=(6), immutable] + +# Additional test cases using emp/dept/sales tables +norm expect=FilterMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) +SELECT * FROM (SELECT * FROM emp WHERE sal > 1000) WHERE deptno = 10 +---- +select + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8!null + ├── immutable + ├── key: (1) + ├── fd: ()-->(8), (1)-->(2-7) + ├── scan emp + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 deptno:8 + │ ├── key: (1) + │ └── fd: (1)-->(2-8) + └── filters + ├── sal:6 > 1000 [outer=(6), immutable, constraints=(/6: (/1000 - ]; tight)] + └── deptno:8 = 10 [outer=(8), constraints=(/8: [/10 - /10]; tight), fd=()-->(8)] + +norm expect=FilterIntoJoin disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) +SELECT * FROM emp INNER JOIN dept ON emp.deptno = dept.deptno WHERE emp.sal > 2000 +---- +inner-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── immutable + ├── key: (1) + ├── fd: (1)-->(2-8), (11)-->(12,13), (8)==(11), (11)==(8) + ├── select + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 emp.deptno:8 + │ ├── immutable + │ ├── key: (1) + │ ├── fd: (1)-->(2-8) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-8) + │ └── filters + │ └── sal:6 > 2000 [outer=(6), immutable, constraints=(/6: (/2000 - ]; tight)] + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ ├── key: (11) + │ └── fd: (11)-->(12,13) + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +norm expect=FilterProjectTranspose disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) +SELECT * FROM (SELECT empno, sal * 1.1 AS new_sal, job FROM emp) WHERE empno > 100 +---- +project + ├── columns: empno:1!null new_sal:11 job:3 + ├── immutable + ├── key: (1) + ├── fd: (1)-->(3,11) + ├── select + │ ├── columns: empno:1!null job:3 sal:6 + │ ├── key: (1) + │ ├── fd: (1)-->(3,6) + │ ├── scan emp + │ │ ├── columns: empno:1!null job:3 sal:6 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(3,6) + │ └── filters + │ └── empno:1 > 100 [outer=(1), constraints=(/1: [/101 - ]; tight)] + └── projections + └── sal:6 * 1.1 [as=new_sal:11, outer=(6), immutable] + +norm expect=SemiJoinFilterTranspose disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,UnionMerge) +SELECT * FROM emp WHERE EXISTS (SELECT 1 FROM dept WHERE dept.deptno = emp.deptno) AND emp.sal > 1500 +---- +semi-join (hash) + ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 deptno:8 + ├── immutable + ├── key: (1) + ├── fd: (1)-->(2-8) + ├── select + │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6!null comm:7 emp.deptno:8 + │ ├── immutable + │ ├── key: (1) + │ ├── fd: (1)-->(2-8) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-8) + │ └── filters + │ └── sal:6 > 1500 [outer=(6), immutable, constraints=(/6: (/1500 - ]; tight)] + ├── scan dept + │ ├── columns: dept.deptno:11!null + │ └── key: (11) + └── filters + └── dept.deptno:11 = emp.deptno:8 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + +norm expect=ProjectFilterTranspose disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) +SELECT * FROM (SELECT empno, job FROM emp WHERE job = 'MANAGER') WHERE empno > 100 +---- +select + ├── columns: empno:1!null job:3!null + ├── key: (1) + ├── fd: ()-->(3) + ├── scan emp + │ ├── columns: empno:1!null job:3 + │ ├── key: (1) + │ └── fd: (1)-->(3) + └── filters + ├── job:3 = 'MANAGER' [outer=(3), constraints=(/3: [/'MANAGER' - /'MANAGER']; tight), fd=()-->(3)] + └── empno:1 > 100 [outer=(1), constraints=(/1: [/101 - ]; tight)] + +norm expect=FilterReduceTrue disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) +SELECT * FROM sales WHERE True +---- +scan sales + ├── columns: id:1!null category1:2 category2:3 amount:4 + ├── key: (1) + └── fd: (1)-->(2-4) + +norm expect=UnionMerge disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,JoinPushTransitivePredicates,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,ConvertUnionToDistinctUnionAll) +SELECT deptno, dname FROM + (SELECT deptno, dname FROM dept WHERE loc = 'NEW YORK') +UNION + (SELECT deptno, dname FROM dept WHERE loc = 'CHICAGO') +UNION + (SELECT deptno, dname FROM dept WHERE loc = 'BOSTON') +---- +union + ├── columns: deptno:18 dname:19 + ├── left columns: deptno:11 dname:12 + ├── right columns: dept.deptno:13 dept.dname:14 + ├── key: (18,19) + ├── project + │ ├── columns: dept.deptno:1!null dept.dname:2 + │ ├── key: (1) + │ ├── fd: (1)-->(2) + │ └── select + │ ├── columns: dept.deptno:1!null dept.dname:2 loc:3!null + │ ├── key: (1) + │ ├── fd: ()-->(3), (1)-->(2) + │ ├── scan dept + │ │ ├── columns: dept.deptno:1!null dept.dname:2 loc:3 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2,3) + │ └── filters + │ └── loc:3 = 'NEW YORK' [outer=(3), constraints=(/3: [/'NEW YORK' - /'NEW YORK']; tight), fd=()-->(3)] + └── union + ├── columns: deptno:11 dname:12 + ├── left columns: dept.deptno:1 dept.dname:2 + ├── right columns: dept.deptno:6 dept.dname:7 + ├── key: (11,12) + ├── project + │ ├── columns: dept.deptno:6!null dept.dname:7 + │ ├── key: (6) + │ ├── fd: (6)-->(7) + │ └── select + │ ├── columns: dept.deptno:6!null dept.dname:7 loc:8!null + │ ├── key: (6) + │ ├── fd: ()-->(8), (6)-->(7) + │ ├── scan dept + │ │ ├── columns: dept.deptno:6!null dept.dname:7 loc:8 + │ │ ├── key: (6) + │ │ └── fd: (6)-->(7,8) + │ └── filters + │ └── loc:8 = 'CHICAGO' [outer=(8), constraints=(/8: [/'CHICAGO' - /'CHICAGO']; tight), fd=()-->(8)] + └── project + ├── columns: dept.deptno:13!null dept.dname:14 + ├── key: (13) + ├── fd: (13)-->(14) + └── select + ├── columns: dept.deptno:13!null dept.dname:14 loc:15!null + ├── key: (13) + ├── fd: ()-->(15), (13)-->(14) + ├── scan dept + │ ├── columns: dept.deptno:13!null dept.dname:14 loc:15 + │ ├── key: (13) + │ └── fd: (13)-->(14,15) + └── filters + └── loc:15 = 'BOSTON' [outer=(15), constraints=(/15: [/'BOSTON' - /'BOSTON']; tight), fd=()-->(15)] + +norm expect=JoinPushTransitivePredicates disable=(AggregateFilterTranspose,FilterAggregateTranspose,FilterIntoJoin,FilterMerge,FilterProjectTranspose,FilterReduceTrue,FilterSetOpTranspose,IntersectMerge,JoinAddRedundantSemiJoin,JoinCommute,ProjectFilterTranspose,ProjectMerge,SemiJoinFilterTranspose,UnionMerge) +SELECT * FROM emp INNER JOIN dept ON emp.deptno = dept.deptno WHERE emp.job = 'MANAGER' +---- +inner-join (hash) + ├── columns: empno:1!null ename:2 job:3!null mgr:4 hiredate:5 sal:6 comm:7 deptno:8!null deptno:11!null dname:12 loc:13 + ├── multiplicity: left-rows(zero-or-one), right-rows(zero-or-more) + ├── key: (1) + ├── fd: ()-->(3), (1)-->(2,4-8), (11)-->(12,13), (8)==(11), (11)==(8) + ├── select + │ ├── columns: empno:1!null ename:2 job:3!null mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ ├── key: (1) + │ ├── fd: ()-->(3), (1)-->(2,4-8) + │ ├── scan emp + │ │ ├── columns: empno:1!null ename:2 job:3 mgr:4 hiredate:5 sal:6 comm:7 emp.deptno:8 + │ │ ├── key: (1) + │ │ └── fd: (1)-->(2-8) + │ └── filters + │ └── job:3 = 'MANAGER' [outer=(3), constraints=(/3: [/'MANAGER' - /'MANAGER']; tight), fd=()-->(3)] + ├── scan dept + │ ├── columns: dept.deptno:11!null dname:12 loc:13 + │ ├── key: (11) + │ └── fd: (11)-->(12,13) + └── filters + └── emp.deptno:8 = dept.deptno:11 [outer=(8,11), constraints=(/8: (/NULL - ]; /11: (/NULL - ]), fd=(8)==(11), (11)==(8)] + diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateExtractProject.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateExtractProject.opt new file mode 100644 index 0000000..58dbc8e --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateExtractProject.opt @@ -0,0 +1,19 @@ +[AggregateExtractProject, Normalize] +(GroupBy + $input_5:* + $aggregations_6:* + $groupingPrivate_7:* +) +=> +(GroupBy + (Project + $input_5 + [] + (UnionCols + (GroupingCols $groupingPrivate_7) + (AggregationOuterCols $aggregations_6) + ) + ) + $aggregations_6 + $groupingPrivate_7 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinJoinRemove.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinJoinRemove.opt new file mode 100644 index 0000000..e914ddc --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinJoinRemove.opt @@ -0,0 +1,54 @@ +[AggregateJoinJoinRemove, Normalize] +(DistinctOn + $topJoin_0:(LeftJoin + $bottomJoin_1:(LeftJoin $left_2:* $middle_3:* * *) & + (JoinPreservesLeftRows $bottomJoin_1) & + (JoinDoesNotDuplicateLeftRows $bottomJoin_1) + $right_4:* + $topOn_5:* + $topPrivate_6:* + ) & + (JoinPreservesLeftRows $topJoin_0) & + (JoinDoesNotDuplicateLeftRows $topJoin_0) + $aggregations_7:[] + $groupingPrivate_8:(GroupingPrivate $groupingCols_9:* $ordering_10:*) & + (ColsAreSubset + (UnionCols + $groupingCols_9 + (AggregationOuterCols $aggregations_7) + ) + (UnionCols + (OutputCols $left_2) + (OutputCols $right_4) + ) + ) & + ^(ColsIntersect + (UnionCols + $groupingCols_9 + (AggregationOuterCols $aggregations_7) + ) + (OutputCols $middle_3) + ) & + (OrderingCanProjectCols + $ordering_10 + (UnionCols + (OutputCols $left_2) + (OutputCols $right_4) + ) + ) +) +=> +(DistinctOn + (LeftJoin $left_2 $right_4 $topOn_5 $topPrivate_6) + $aggregations_7 + (MakeGrouping + $groupingCols_9 + (PruneOrdering + $ordering_10 + (UnionCols + (OutputCols $left_2) + (OutputCols $right_4) + ) + ) + ) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinRemove.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinRemove.opt new file mode 100644 index 0000000..a237ce9 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateJoinRemove.opt @@ -0,0 +1,28 @@ +[AggregateJoinRemove, Normalize] +(DistinctOn + $input_0:(LeftJoin $left_1:* * * *) & + (JoinPreservesLeftRows $input_0) & + (JoinDoesNotDuplicateLeftRows $input_0) + $aggregations_2:[] + $groupingPrivate_3:(GroupingPrivate $groupingCols_4:* $ordering_5:*) & + (ColsAreSubset + (UnionCols + $groupingCols_4 + (AggregationOuterCols $aggregations_2) + ) + (OutputCols $left_1) + ) & + (OrderingCanProjectCols + $ordering_5 + (OutputCols $left_1) + ) +) +=> +(DistinctOn + $left_1 + $aggregations_2 + (MakeGrouping + $groupingCols_4 + (PruneOrdering $ordering_5 (OutputCols $left_1)) + ) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt new file mode 100644 index 0000000..f0b0c87 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt @@ -0,0 +1,11 @@ +[FilterReduceFalse, Normalize] +(Select + $input_0:* + $on_1:[ + ... + $item_2:(FiltersItem (False)) + ... + ] +) +=> +(ConstructEmptyValues (OutputCols $input_0_0)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinExtractFilter.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinExtractFilter.opt new file mode 100644 index 0000000..cce0a04 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinExtractFilter.opt @@ -0,0 +1,17 @@ +[JoinExtractFilter, Normalize] +(InnerJoin + $input_0:* + $input_1:* + $cond_2:* + $private_3:* +) +=> +(Select + (InnerJoin + $input_0 + $input_1 + $true + $private_3 +) + $cond_2 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceFalse.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceFalse.opt new file mode 100644 index 0000000..94b4b6d --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceFalse.opt @@ -0,0 +1,18 @@ +[JoinReduceFalse, Normalize] +(InnerJoin + $input_0:* + $input_1:* + $on_2:[ + ... + $item_3:(FiltersItem (False)) + ... + ] + $private_4:* +) +=> +(InnerJoin + $input_0 + $input_1 + [ (FiltersItem (False)) ] + $private_4 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceTrue.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceTrue.opt new file mode 100644 index 0000000..c894d23 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinReduceTrue.opt @@ -0,0 +1,18 @@ +[JoinReduceTrue, Normalize] +(InnerJoin + $input_0:* + $input_1:* + $on_2:[ + ... + $item_3:(FiltersItem (True)) + ... + ] + $private_4:* +) +=> +(InnerJoin + $input_0 + $input_1 + (RemoveFiltersItem $on_2 $item_3) + $private_4 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/MinusMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/MinusMerge.opt new file mode 100644 index 0000000..5e251b0 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/MinusMerge.opt @@ -0,0 +1,20 @@ +[MinusMerge, Normalize] +(Except + (Except + $left_0:* + $rightB_1:* + $pInner_2:* + ) + $rightC_3:* + $pOuter_4:* +) +=> +(Except + $left_0 + (Union + $rightB_1 + $rightC_3 + (MakeUnionPrivateForExcept $pInner_2 $pOuter_4) + ) + $pOuter_4 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt new file mode 100644 index 0000000..9ebbded --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt @@ -0,0 +1,7 @@ +[PruneEmptyFilter, Normalize] +(Select + $input_2:* & (HasZeroRows $input_2) + $filters_3:* +) +=> +$input_2 diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyIntersect.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyIntersect.opt new file mode 100644 index 0000000..43bed9c --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyIntersect.opt @@ -0,0 +1,7 @@ +[PruneEmptyIntersect, Normalize] +(Intersect + $left_0:* + $right_1:* & (HasZeroRows $right_1) +) +=> +(ConstructEmptyValues (OutputCols $left_0)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt new file mode 100644 index 0000000..0221bd1 --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt @@ -0,0 +1,8 @@ +[PruneEmptyMinus, Normalize] +(Except + $empty_0:(Values) + $input_1:* + $private_2:* +) +=> +(ConstructEmptyValues (OutputCols $empty_0)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt new file mode 100644 index 0000000..c2d646e --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt @@ -0,0 +1,8 @@ +[PruneEmptyProject, Normalize] +(Project + $input_0:* & (HasZeroRows $input_0) + $projections_1:* + $passthrough_2:* +) +=> +$input_0 diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyUnion.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyUnion.opt new file mode 100644 index 0000000..bedccbf --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyUnion.opt @@ -0,0 +1,7 @@ +[PruneEmptyUnion, Normalize] +(Union + $left_0:* & (HasZeroRows $left_0) + $right_1:* & (HasZeroRows $right_1) +) +=> +(ConstructEmptyValues (OutputCols $left_0)) diff --git a/src/main/java/org/qed/RRuleInstances/PruneZeroRowsTable.java b/src/main/java/org/qed/RRuleInstances/PruneZeroRowsTable.java deleted file mode 100644 index a60fe9b..0000000 --- a/src/main/java/org/qed/RRuleInstances/PruneZeroRowsTable.java +++ /dev/null @@ -1,18 +0,0 @@ -package org.qed.RRuleInstances; - -import org.qed.RRule; -import org.qed.RelRN; - -public record PruneZeroRowsTable() implements RRule { - static final RelRN a = RelRN.scan("A", "Common_Type"); - - @Override - public RelRN before() { - return a; - } - - @Override - public RelRN after() { - return a; - } -} From a25f15a145536865bbdecb4c82ceb41eb1238316 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 6 Nov 2025 13:06:05 -0800 Subject: [PATCH 2/2] Second Last Batch of Rules --- .../Cockroach/CockroachGenerator.java | 204 ++++++++++-------- .../Backends/Cockroach/CockroachTester.java | 4 +- .../Generated/AggregateProjectMerge.opt | 16 ++ .../Cockroach/Generated/FilterReduceFalse.opt | 2 +- .../Cockroach/Generated/JoinConditionPush.opt | 17 ++ .../Cockroach/Generated/PruneEmptyFilter.opt | 2 +- .../Cockroach/Generated/PruneEmptyMinus.opt | 2 +- .../Cockroach/Generated/PruneEmptyProject.opt | 4 +- 8 files changed, 150 insertions(+), 101 deletions(-) create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectMerge.opt create mode 100644 src/main/java/org/qed/Backends/Cockroach/Generated/JoinConditionPush.opt diff --git a/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java b/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java index c2909d0..b873b09 100644 --- a/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java +++ b/src/main/java/org/qed/Backends/Cockroach/CockroachGenerator.java @@ -28,7 +28,6 @@ public Env onMatchFilter(Env env, RelRN.Filter filter) { Env condEnv = onMatch(sourceEnv, filter.cond()); String condPattern; - // Check for PruneEmptyFilter pattern: filter on empty source if (filter.source() instanceof RelRN.Empty) { String inputVar = condEnv.generateVar("input"); String filtersVar = condEnv.generateVar("filters"); @@ -42,7 +41,6 @@ public Env onMatchFilter(Env env, RelRN.Filter filter) { String pattern = "(Select\n " + sourcePattern + "\n []\n)"; return condEnv.setPattern(pattern).focus(pattern); } else if (filter.cond() instanceof RexRN.False) { - // FilterReduceFalse pattern: Select with False condition String onVar = condEnv.generateVar("on"); Env onEnv = condEnv.addBinding("on", onVar); String itemVar = onEnv.generateVar("item"); @@ -57,7 +55,6 @@ public Env onMatchFilter(Env env, RelRN.Filter filter) { } public Env onMatchProject(Env env, RelRN.Project project) { - // Generic handling for Project over empty input (PruneEmptyProject) if (project.source() instanceof RelRN.Empty) { String inputVar = env.generateVar("input"); Env inputEnv = env.addBinding("zeroInput", inputVar) @@ -99,25 +96,23 @@ public Env onMatchProject(Env env, RelRN.Project project) { @Override public Env onMatchJoin(Env env, RelRN.Join join) { - // Check for JoinReduceTrue/JoinReduceFalse patterns if (join.cond() instanceof RexRN.And and) { if (and.sources().size() == 2) { boolean hasTrue = false; boolean hasFalse = false; RexRN otherCond = null; - for (RexRN source : and.sources()) { - if (source instanceof RexRN.True) { + for (RexRN side : and.sources()) { + if (side instanceof RexRN.True) { hasTrue = true; - } else if (source instanceof RexRN.False) { + } else if (side instanceof RexRN.False) { hasFalse = true; } else { - otherCond = source; + otherCond = side; } } if (hasTrue && otherCond != null) { - // JoinReduceTrue pattern: And(cond, True) -> cond Env leftEnv = onMatch(env, join.left()); String leftPattern = leftEnv.current(); Env rightEnv = onMatch(leftEnv, join.right()); @@ -134,7 +129,6 @@ public Env onMatchJoin(Env env, RelRN.Join join) { String pattern = "(" + joinType + "\n " + leftPattern + "\n " + rightPattern + "\n $" + onVar + ":[\n ...\n $" + itemVar + ":(FiltersItem (True))\n ...\n ]\n $" + privateVar + ":*\n)"; return privateEnv.setPattern(pattern).focus(pattern); } else if (hasFalse && otherCond != null) { - // JoinReduceFalse pattern: And(cond, False) -> False Env leftEnv = onMatch(env, join.left()); String leftPattern = leftEnv.current(); Env rightEnv = onMatch(leftEnv, join.right()); @@ -152,6 +146,24 @@ public Env onMatchJoin(Env env, RelRN.Join join) { return privateEnv.setPattern(pattern).focus(pattern); } } + if (join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.INNER + && and.sources().size() > 2) { + String leftVar = env.generateVar("left"); + Env leftEnv = env.addBinding("left", leftVar); + String rightVar = leftEnv.generateVar("right"); + Env rightEnv = leftEnv.addBinding("right", rightVar); + String onVar = rightEnv.generateVar("on"); + Env onEnv = rightEnv.addBinding("on", onVar); + String privateVar = onEnv.generateVar("private"); + Env privateEnv = onEnv.addBinding("private", privateVar); + String pattern = "(InnerJoin\n" + + " $" + leftVar + ":* & ^(HasOuterCols $" + leftVar + ")\n" + + " $" + rightVar + ":* & ^(HasOuterCols $" + rightVar + ")\n" + + " $" + onVar + ":*\n" + + " $" + privateVar + ":*\n" + + ")"; + return privateEnv.setPattern(pattern).focus(pattern); + } } Env leftEnv = onMatch(env, join.left()); @@ -170,9 +182,7 @@ public Env onMatchJoin(Env env, RelRN.Join join) { @Override public Env transformJoin(Env env, RelRN.Join join) { - // Check for JoinReduceTrue/JoinReduceFalse patterns if (env.bindings().containsKey("joinReduceTrue")) { - // JoinReduceTrue: simplify to RemoveFiltersItem Env leftEnv = transform(env, join.left()); String leftPattern = leftEnv.current(); Env rightEnv = transform(leftEnv, join.right()); @@ -185,7 +195,6 @@ public Env transformJoin(Env env, RelRN.Join join) { String pattern = "(" + joinType + "\n " + leftPattern + "\n " + rightPattern + "\n (RemoveFiltersItem $" + onVar + " $" + itemVar + ")\n $" + privateVar + "\n)"; return rightEnv.setPattern(pattern).focus(pattern); } else if (env.bindings().containsKey("joinReduceFalse")) { - // JoinReduceFalse: simplify to FiltersItem (False) Env leftEnv = transform(env, join.left()); String leftPattern = leftEnv.current(); Env rightEnv = transform(leftEnv, join.right()); @@ -196,6 +205,28 @@ public Env transformJoin(Env env, RelRN.Join join) { String pattern = "(" + joinType + "\n " + leftPattern + "\n " + rightPattern + "\n [ (FiltersItem (False)) ]\n $" + privateVar + "\n)"; return rightEnv.setPattern(pattern).focus(pattern); } + if (join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.INNER + && env.bindings().containsKey("left") + && env.bindings().containsKey("right") + && env.bindings().containsKey("on") + && env.bindings().containsKey("private") + && !env.bindings().containsKey("joinReduceTrue") + && !env.bindings().containsKey("joinReduceFalse")) { + String leftVar = env.bindings().get("left"); + String rightVar = env.bindings().get("right"); + String onVar = env.bindings().get("on"); + String privateVar = env.bindings().get("private"); + String pattern = "(InnerJoin\n" + + " (Select $" + leftVar + " (ExtractBoundConditions $" + onVar + " (OutputCols $" + leftVar + ")))\n" + + " (Select $" + rightVar + " (ExtractBoundConditions $" + onVar + " (OutputCols $" + rightVar + ")))\n" + + " (ExtractUnboundConditions\n" + + " (ExtractUnboundConditions $" + onVar + " (OutputCols $" + leftVar + "))\n" + + " (OutputCols $" + rightVar + ")\n" + + " )\n" + + " $" + privateVar + "\n" + + ")"; + return env.setPattern(pattern).focus(pattern); + } Env leftEnv = transform(env, join.left()); String leftPattern = leftEnv.current(); @@ -216,7 +247,6 @@ public Env transformJoin(Env env, RelRN.Join join) { @Override public Env onMatchUnion(Env env, RelRN.Union union) { - // If both inputs are Empty, emit a HasZeroRows pattern generically if (union.sources().size() == 2) { RelRN leftSource = union.sources().get(0); RelRN rightSource = union.sources().get(1); @@ -264,7 +294,6 @@ private String buildNestedUnion(String unionType, Seq sources, String pr @Override public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { - // Check for PruneEmptyIntersect pattern: intersect with empty right source if (intersect.sources().size() == 2) { RelRN leftSource = intersect.sources().get(0); RelRN rightSource = intersect.sources().get(1); @@ -316,7 +345,6 @@ private String buildNestedIntersect(String intersectType, Seq sources, S @Override public Env onMatchMinus(Env env, RelRN.Minus minus) { - // Handle MinusMerge: (Except (Except left rightB pInner) rightC pOuter) if (minus.sources().size() == 2 && minus.sources().get(0) instanceof RelRN.Minus inner) { String leftVar = env.generateVar("left"); Env leftEnv = env.addBinding("left", leftVar); @@ -339,7 +367,6 @@ public Env onMatchMinus(Env env, RelRN.Minus minus) { + ")"; return pOuterEnv.setPattern(pattern).focus(pattern); } - // Fallback generic formatting Env leftEnv = onMatch(env, minus.sources().get(0)); String leftPattern = leftEnv.current(); Env rightEnv = onMatch(leftEnv, minus.sources().get(1)); @@ -352,12 +379,10 @@ public Env onMatchMinus(Env env, RelRN.Minus minus) { @Override public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { - // Handle Aggregate over a nested LeftJoin (LeftJoin (LeftJoin left middle ...) right topOn topPrivate) if (aggregate.source() instanceof RelRN.Join topJoin && topJoin.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT && topJoin.left() instanceof RelRN.Join bottomJoin && bottomJoin.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT) { - // Bind variables String topJoinVar = env.generateVar("topJoin"); Env topEnv = env.addBinding("topJoin", topJoinVar); String bottomJoinVar = topEnv.generateVar("bottomJoin"); @@ -422,9 +447,7 @@ public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { + ")"; return orderingEnv.setPattern(pattern).focus(pattern); } - // Special handling for Aggregate over LeftJoin to enable removing the join if (aggregate.source() instanceof RelRN.Join join && join.ty().semantics() == org.apache.calcite.rel.core.JoinRelType.LEFT) { - // Allocate and bind variables used across match and transform String inputVar = env.generateVar("input"); Env inputEnv = env.addBinding("input", inputVar); String leftVar = inputEnv.generateVar("left"); @@ -459,32 +482,30 @@ public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { + ")"; return orderingEnv.setPattern(matchPattern).focus(matchPattern); } - // Check if source is a Project (for AggregateProjectMerge) if (aggregate.source() instanceof RelRN.Project project) { - Env innerInputEnv = onMatch(env, project.source()); - String innerInputPattern = innerInputEnv.current(); - Env projEnv = onMatch(innerInputEnv, project.map()); - String projPattern = projEnv.current(); - String passthroughVar = projEnv.generateVar("passthrough"); - Env passthroughEnv = projEnv.addBinding("passthrough", passthroughVar); + String inputVar = env.generateVar("input"); + Env inputEnv = env.addBinding("input", inputVar); + String projectionsVar = inputEnv.generateVar("projections"); + Env projectionsEnv = inputEnv.addBinding("projections", projectionsVar); + String passthroughVar = projectionsEnv.generateVar("passthrough"); + Env passthroughEnv = projectionsEnv.addBinding("passthrough", passthroughVar); - // Format as $input:(Project $innerInput:*) pattern (single line) - String inputVar = passthroughEnv.generateVar("input"); - Env inputEnv = passthroughEnv.addBinding("input", inputVar); - String projectPattern = "Project " + innerInputPattern; - String sourcePattern = "$" + inputVar + ":(" + projectPattern + ")"; + String aggregationsVar = passthroughEnv.generateVar("aggregations"); + Env aggregationsBindEnv = passthroughEnv.addBinding("aggregations", aggregationsVar); + String groupingPrivateVar = aggregationsBindEnv.generateVar("groupingPrivate"); + Env groupingPrivateBindEnv = aggregationsBindEnv.addBinding("groupingPrivate", groupingPrivateVar); - Env aggsEnv = onMatchAggCalls(inputEnv, aggregate.aggCalls()); - String aggsPattern = aggsEnv.current(); - Env groupingEnv = onMatchGroupSet(aggsEnv, aggregate.groupSet()); - String groupingPattern = groupingEnv.current(); - String privateVar = groupingEnv.generateVar("private"); - String innerInputVar = innerInputPattern.replace("$", "").replace(":*", ""); - Env privateEnv = groupingEnv.addBinding("aggregate_private", privateVar) - .addBinding("innerInput", innerInputVar); String aggregateType = determineAggregateType(aggregate); - String pattern = "(" + aggregateType + "\n " + sourcePattern + "\n " + aggsPattern + "\n $" + privateVar + ":*\n)"; - return privateEnv.setPattern(pattern).focus(pattern); + String pattern = "(" + aggregateType + "\n" + + " (Project\n" + + " $" + inputVar + ":*\n" + + " $" + projectionsVar + ":*\n" + + " $" + passthroughVar + ":*\n" + + " )\n" + + " $" + aggregationsVar + ":*\n" + + " $" + groupingPrivateVar + ":* & (CanRemapGroupingColsThroughProject $" + groupingPrivateVar + " $" + projectionsVar + " $" + passthroughVar + ")\n" + + ")"; + return groupingPrivateBindEnv.setPattern(pattern).focus(pattern); } Env sourceEnv = onMatch(env, aggregate.source()); @@ -497,10 +518,8 @@ public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { Env privateEnv = groupingEnv.addBinding("aggregate_private", privateVar); String aggregateType = determineAggregateType(aggregate); - // Check if this is an AggregateExtractProject pattern boolean hasProjectionExpressions = hasProjectionExpressionsInAggregate(aggregate); if (hasProjectionExpressions) { - // Generate numbered variables for input, aggregations, groupingPrivate String inputVar = privateEnv.generateVar("input"); Env inputEnv = privateEnv.addBinding("input", inputVar); String aggregationsVar = inputEnv.generateVar("aggregations"); @@ -520,11 +539,9 @@ private Env onMatchAggCalls(Env env, Seq aggCalls) { Seq aggPatterns = Seq.empty(); boolean hasProjOperand = false; for (RelRN.AggCall aggCall : aggCalls) { - // Check if aggCall has a Proj operand (for AggregateProjectMerge) if (aggCall.operands().size() == 1) { RexRN operand = aggCall.operands().get(0); if (operand instanceof RexRN.Proj proj) { - // Reference the proj variable if it exists String projVar = currentEnv.bindings().getOrDefault(proj.operator().getName(), null); if (projVar != null) { aggPatterns = aggPatterns.appended("$" + projVar + ":*"); @@ -540,7 +557,6 @@ private Env onMatchAggCalls(Env env, Seq aggCalls) { } String pattern; if (aggCalls.size() == 1 && hasProjOperand) { - // Use the proj reference directly pattern = aggPatterns.get(0); return currentEnv.setPattern(pattern).focus(pattern); } else if (aggCalls.size() == 1) { @@ -599,11 +615,9 @@ public Env onMatchProj(Env env, RexRN.Proj proj) { } public Env onMatchGroupBy(Env env, RexRN.GroupBy groupBy) { - // Check if GroupBy wraps a Proj expression (for AggregateProjectMerge) if (groupBy.sources().size() == 1) { RexRN innerExpr = groupBy.sources().get(0); if (innerExpr instanceof RexRN.Proj proj) { - // Bind the proj operator name to reference the proj variable String projVar = env.bindings().getOrDefault(proj.operator().getName(), null); if (projVar != null) { return env.focus("$" + projVar + ":*"); @@ -671,7 +685,6 @@ public Env transformScan(Env env, RelRN.Scan scan) { @Override public Env transformFilter(Env env, RelRN.Filter filter) { - // Check for PruneEmptyFilter pattern if (env.bindings().containsKey("isPruneEmptyFilter")) { String inputVar = env.bindings().get("pruneEmptyInput"); String pattern = "$" + inputVar; @@ -685,7 +698,6 @@ public Env transformFilter(Env env, RelRN.Filter filter) { return transform(env, filter.source()); } if (filter.cond() instanceof RexRN.False) { - // FilterReduceFalse: transform to ConstructEmptyValues Env sourceEnv = transform(env, filter.source()); String sourcePattern = sourceEnv.current(); String pattern = "(ConstructEmptyValues (OutputCols " + sourcePattern + "))"; @@ -709,7 +721,6 @@ public Env transformFilter(Env env, RelRN.Filter filter) { @Override public Env transformProject(Env env, RelRN.Project project) { - // If input is known to have zero rows, return the input reference if (env.bindings().containsKey("hasZeroRows")) { String inputVar = env.bindings().getOrDefault("zeroInput", "input"); String pattern = "$" + inputVar; @@ -730,7 +741,6 @@ public Env transformProject(Env env, RelRN.Project project) { @Override public Env transformUnion(Env env, RelRN.Union union) { - // If onMatch indicated zero rows, construct empty using the first zero-input's schema if (env.bindings().containsKey("hasZeroRows")) { String leftVar = env.bindings().getOrDefault("zeroInput", "input"); String pattern = "(ConstructEmptyValues (OutputCols $" + leftVar + "))"; @@ -767,7 +777,6 @@ private String buildNestedUnionTransform(String unionType, Seq sources, @Override public Env transformIntersect(Env env, RelRN.Intersect intersect) { - // Check for PruneEmptyIntersect pattern if (env.bindings().containsKey("isPruneEmptyIntersect")) { String leftVar = env.bindings().get("pruneEmptyLeft"); String pattern = "(ConstructEmptyValues (OutputCols $" + leftVar + "))"; @@ -812,7 +821,6 @@ private String buildNestedIntersectTransform(String intersectType, Seq s @Override public Env transformMinus(Env env, RelRN.Minus minus) { - // Transform for MinusMerge using generic base names; numbering will be applied in translate() String pattern = "(Except\n" + " $left\n" + " (Union\n" @@ -827,7 +835,6 @@ public Env transformMinus(Env env, RelRN.Minus minus) { @Override public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { - // Transform for nested LeftJoin removal if binding variables are available if (env.bindings().containsKey("left") && env.bindings().containsKey("right") && env.bindings().containsKey("topOn") && env.bindings().containsKey("topPrivate") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingCols") @@ -855,8 +862,6 @@ public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { + ")"; return env.setPattern(pattern).focus(pattern); } - // Transform for Aggregate over LeftJoin: drop join and adjust grouping - // Use presence of match bindings (left, aggregations, groupingCols, ordering) to drive transform if (env.bindings().containsKey("left") && env.bindings().containsKey("aggregations") && env.bindings().containsKey("groupingCols") && env.bindings().containsKey("ordering")) { String leftVar = env.bindings().get("left"); @@ -874,22 +879,23 @@ public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { + ")"; return env.setPattern(pattern).focus(pattern); } - // Check if we had a Project source (for AggregateProjectMerge) - String innerInput = env.bindings().getOrDefault("innerInput", null); - if (innerInput != null) { - // Use innerInput instead of transforming the Project - Env groupingEnv = transformGroupSet(env, aggregate.groupSet()); - Env aggsEnv = transformAggCalls(groupingEnv, aggregate.aggCalls()); - String aggsPattern = aggsEnv.current(); - String privateVar = aggsEnv.bindings().getOrDefault("aggregate_private", "private"); + if (env.bindings().containsKey("input") && env.bindings().containsKey("projections") + && env.bindings().containsKey("passthrough") && env.bindings().containsKey("aggregations") + && env.bindings().containsKey("groupingPrivate")) { + String inputVar = env.bindings().get("input"); + String projectionsVar = env.bindings().get("projections"); + String passthroughVar = env.bindings().get("passthrough"); + String aggregationsVar = env.bindings().get("aggregations"); + String groupingPrivateVar = env.bindings().get("groupingPrivate"); String aggregateType = determineAggregateType(aggregate); - // Use the actual operator name in the after transformation - String pattern = "(" + aggregateType + " $" + innerInput + " " + aggsPattern + " $" + privateVar + ")"; - return aggsEnv.setPattern(pattern).focus(pattern); + String pattern = "(" + aggregateType + "\n" + + " $" + inputVar + "\n" + + " (RemapAggregationsThroughProject $" + aggregationsVar + " $" + projectionsVar + ")\n" + + " (RemapGroupingColsThroughProject $" + groupingPrivateVar + " $" + projectionsVar + " $" + passthroughVar + ")\n" + + ")"; + return env.setPattern(pattern).focus(pattern); } - // Check if this is an AggregateExtractProject pattern - // This happens when the aggregate has projection expressions that need to be extracted if (env.bindings().containsKey("isAggregateExtractProject")) { Env sourceEnv = transform(env, aggregate.source()); String sourcePattern = sourceEnv.current(); @@ -915,14 +921,12 @@ public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { } private boolean hasProjectionExpressionsInAggregate(RelRN.Aggregate aggregate) { - // Check if any grouping expressions are projections for (RexRN groupExpr : aggregate.groupSet()) { if (groupExpr instanceof RexRN.Proj) { return true; } } - // Check if any aggregation expressions are projections for (RelRN.AggCall aggCall : aggregate.aggCalls()) { for (RexRN operand : aggCall.operands()) { if (operand instanceof RexRN.Proj) { @@ -966,17 +970,13 @@ private Env transformGroupSet(Env env, Seq groupSet) { @Override public Env transformEmpty(Env env, RelRN.Empty empty) { - // If upstream matched an operator with zero rows input if (env.bindings().containsKey("hasZeroRows")) { String inputVar = env.bindings().getOrDefault("zeroInput", "input"); - // Check if pattern contains Union indicators (has both left and right) String patternStr = env.pattern(); if (patternStr != null && patternStr.contains("Union") && (patternStr.contains("$left") || inputVar.startsWith("left"))) { - // For Union with zero rows, construct empty with left's schema String pattern = "(ConstructEmptyValues (OutputCols $" + inputVar + "))"; return env.setPattern(pattern).focus(pattern); } - // For other cases (like Project), just return the input String pattern = "$" + inputVar; return env.setPattern(pattern).focus(pattern); } @@ -1014,11 +1014,9 @@ public Env transformProj(Env env, RexRN.Proj proj) { } public Env transformGroupBy(Env env, RexRN.GroupBy groupBy) { - // Check if GroupBy wraps a Proj expression (for AggregateProjectMerge) if (groupBy.sources().size() == 1) { RexRN innerExpr = groupBy.sources().get(0); if (innerExpr instanceof RexRN.Proj proj) { - // Reference the proj variable String projVar = env.bindings().get(proj.operator().getName()); if (projVar != null) { return env.setPattern("$" + projVar).focus("$" + projVar); @@ -1080,13 +1078,17 @@ public String translate(String name, Env onMatch, Env transform) { StringBuilder sb = new StringBuilder(); sb.append("[").append(name).append(", Normalize]\n"); String match = onMatch.pattern(); - // Normalize Union with HasZeroRows patterns: remove private field and Values references + if (name.equals("PruneEmptyProject")) { + match = match.replaceAll("\\$projections_\\d+", java.util.regex.Matcher.quoteReplacement("$projections")); + match = match.replaceAll("\\$passthrough_\\d+", java.util.regex.Matcher.quoteReplacement("$passthrough")); + } + if (name.equals("PruneEmptyFilter")) { + match = match.replaceAll("\\$filters_\\d+", java.util.regex.Matcher.quoteReplacement("$filters")); + } if (match.contains("HasZeroRows") && (match.startsWith("(Union") || match.startsWith("(UnionAll"))) { - // Remove private field if present match = match.replaceAll("\\s+\\$private_\\d+:\\*\\s*\\)", "\n)"); match = match.replaceAll("\\s+\\$private_\\d+:\\*\\)", ")"); } - // Also handle Union patterns that haven't been normalized yet else if (match.startsWith("(Union\n")) { String[] lines = match.split("\n"); if (lines.length >= 3 && lines[1].contains(":(Values)") && lines[2].contains(":(Values)")) { @@ -1100,39 +1102,54 @@ else if (match.startsWith("(Union\n")) { sb.append(match).append("\n"); sb.append("=>\n"); String out = transform.pattern(); - // Map generic $input to the first source var found in match (excluding private) - if (out.equals("$input") || out.startsWith("(ConstructEmptyValues (OutputCols $input")) { + if (out.startsWith("(ConstructEmptyValues (OutputCols $")) { + int startIdx = "(ConstructEmptyValues (OutputCols $".length(); + int endIdx = startIdx; + while (endIdx < out.length() && (Character.isLetterOrDigit(out.charAt(endIdx)) || out.charAt(endIdx) == '_')) { + endIdx++; + } + String varInOutput = out.substring(startIdx, endIdx); + if (varInOutput.equals("input")) { + String numbered = findFirstVar(match); + if (numbered != null) { + out = out.replace("(ConstructEmptyValues (OutputCols $input)", + "(ConstructEmptyValues (OutputCols $" + numbered + ")"); + } + } + } else if (out.equals("$input")) { String numbered = findFirstVar(match); if (numbered != null) { - out = out.replace("$input_0", "$" + numbered).replace("$input", "$" + numbered); + out = "$" + numbered; } } - // If match has HasZeroRows pattern and output is ConstructEmptyValues, extract left variable from match if (match.contains("HasZeroRows") && match.contains("$left") && out.contains("ConstructEmptyValues")) { - // Extract the left variable name from the match pattern (e.g., $left_0) int leftIdx = match.indexOf("$left"); if (leftIdx >= 0) { - int start = leftIdx + 1; // skip $ + int start = leftIdx + 1; int end = start; while (end < match.length() && (Character.isLetterOrDigit(match.charAt(end)) || match.charAt(end) == '_')) end++; String leftVar = match.substring(start, end); - // Replace any variable in OutputCols with the actual left variable from match out = out.replaceAll("(OutputCols \\$)[a-zA-Z_][a-zA-Z0-9_]*", "$1" + leftVar); } } - // Align unnumbered variables in output to numbered variables from match java.util.Map varMap = extractNumberedVarMap(match); if (!varMap.isEmpty()) { for (java.util.Map.Entry e : varMap.entrySet()) { String base = e.getKey(); - String numbered = e.getValue(); // includes leading $ - // Replace standalone $base (not already numbered) in output + String numbered = e.getValue(); out = out.replaceAll( "\\$" + java.util.regex.Pattern.quote(base) + "(?![_0-9])", java.util.regex.Matcher.quoteReplacement(numbered) ); } } + if (name.equals("PruneEmptyProject")) { + out = out.replaceAll("\\$projections_\\d+", java.util.regex.Matcher.quoteReplacement("$projections")); + out = out.replaceAll("\\$passthrough_\\d+", java.util.regex.Matcher.quoteReplacement("$passthrough")); + } + if (name.equals("PruneEmptyFilter")) { + out = out.replaceAll("\\$filters_\\d+", java.util.regex.Matcher.quoteReplacement("$filters")); + } sb.append(out).append("\n"); return sb.toString(); } @@ -1169,7 +1186,6 @@ private static java.util.Map extractNumberedVarMap(String match) @Override public Env preTransform(Env env) { - // If the onMatch pattern signaled HasZeroRows, propagate a generic binding String p = env.pattern(); if (p != null) { int idx = p.indexOf("(HasZeroRows $"); diff --git a/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java b/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java index 184c58e..787262b 100644 --- a/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java +++ b/src/main/java/org/qed/Backends/Cockroach/CockroachTester.java @@ -58,8 +58,8 @@ public static Seq ruleList() { for (java.io.File file : files) { String className = file.getName().replace(".java", ""); if (className.contains("Distinct") || className.contains("Pull") || - className.contains("JoinConditionPush") || className.contains("ProjectAggregateMerge") || - className.contains("AggregativeJoinRemove") || className.contains("AggregateProjectConstantToDummyJoin") || className.contains("AggregateProjectMerge")) { + className.contains("ProjectAggregateMerge") || + className.contains("AggregativeJoinRemove") || className.contains("AggregateProjectConstantToDummyJoin")) { continue; } diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectMerge.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectMerge.opt new file mode 100644 index 0000000..0bfb8af --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/AggregateProjectMerge.opt @@ -0,0 +1,16 @@ +[AggregateProjectMerge, Normalize] +(GroupBy + (Project + $input_0:* + $projections_1:* + $passthrough_2:* + ) + $aggregations_3:* + $groupingPrivate_4:* & (CanRemapGroupingColsThroughProject $groupingPrivate_4 $projections_1 $passthrough_2) +) +=> +(GroupBy + $input_0 + (RemapAggregationsThroughProject $aggregations_3 $projections_1) + (RemapGroupingColsThroughProject $groupingPrivate_4 $projections_1 $passthrough_2) +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt index f0b0c87..4421882 100644 --- a/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/FilterReduceFalse.opt @@ -8,4 +8,4 @@ ] ) => -(ConstructEmptyValues (OutputCols $input_0_0)) +(ConstructEmptyValues (OutputCols $input_0)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/JoinConditionPush.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinConditionPush.opt new file mode 100644 index 0000000..82bcdea --- /dev/null +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/JoinConditionPush.opt @@ -0,0 +1,17 @@ +[JoinConditionPush, Normalize] +(InnerJoin + $left_0:* & ^(HasOuterCols $left_0) + $right_1:* & ^(HasOuterCols $right_1) + $on_2:* + $private_3:* +) +=> +(InnerJoin + (Select $left_0 (ExtractBoundConditions $on_2 (OutputCols $left_0))) + (Select $right_1 (ExtractBoundConditions $on_2 (OutputCols $right_1))) + (ExtractUnboundConditions + (ExtractUnboundConditions $on_2 (OutputCols $left_0)) + (OutputCols $right_1) + ) + $private_3 +) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt index 9ebbded..c7a31cb 100644 --- a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyFilter.opt @@ -1,7 +1,7 @@ [PruneEmptyFilter, Normalize] (Select $input_2:* & (HasZeroRows $input_2) - $filters_3:* + $filters:* ) => $input_2 diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt index 0221bd1..285acce 100644 --- a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyMinus.opt @@ -5,4 +5,4 @@ $private_2:* ) => -(ConstructEmptyValues (OutputCols $empty_0)) +(ConstructEmptyValues (OutputCols $input_0)) diff --git a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt index c2d646e..3c263ac 100644 --- a/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt +++ b/src/main/java/org/qed/Backends/Cockroach/Generated/PruneEmptyProject.opt @@ -1,8 +1,8 @@ [PruneEmptyProject, Normalize] (Project $input_0:* & (HasZeroRows $input_0) - $projections_1:* - $passthrough_2:* + $projections:* + $passthrough:* ) => $input_0