Skip to content

Commit a2c9ec8

Browse files
committed
fix
1 parent 8566374 commit a2c9ec8

File tree

3 files changed

+144
-50
lines changed

3 files changed

+144
-50
lines changed

pgdog/src/frontend/router/rewrite/unique_id/explain.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use pg_query::NodeEnum;
44

55
use super::{
66
super::{Context, Error, RewriteModule},
7-
InsertUniqueIdRewrite, SelectUniqueIdRewrite, UpdateUniqueIdRewrite,
7+
max_param_number, InsertUniqueIdRewrite, SelectUniqueIdRewrite, UpdateUniqueIdRewrite,
88
};
99

1010
#[derive(Default)]
@@ -81,7 +81,7 @@ impl ExplainUniqueIdRewrite {
8181

8282
let mut bind = input.bind_take();
8383
let extended = input.extended();
84-
let mut parameter_counter = 0;
84+
let mut parameter_counter = max_param_number(input.parse_result());
8585

8686
if let Some(NodeEnum::ExplainStmt(stmt)) = input
8787
.stmt_mut()?
@@ -130,6 +130,7 @@ impl ExplainUniqueIdRewrite {
130130

131131
let mut bind = input.bind_take();
132132
let extended = input.extended();
133+
let mut param_counter = max_param_number(input.parse_result());
133134

134135
if let Some(NodeEnum::ExplainStmt(stmt)) = input
135136
.stmt_mut()?
@@ -140,7 +141,12 @@ impl ExplainUniqueIdRewrite {
140141
if let Some(NodeEnum::InsertStmt(insert)) =
141142
stmt.query.as_mut().and_then(|q| q.node.as_mut())
142143
{
143-
InsertUniqueIdRewrite::rewrite_insert(insert, &mut bind, extended)?;
144+
InsertUniqueIdRewrite::rewrite_insert(
145+
insert,
146+
&mut bind,
147+
extended,
148+
&mut param_counter,
149+
)?;
144150
}
145151
}
146152

pgdog/src/frontend/router/rewrite/unique_id/insert.rs

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use pg_query::{protobuf::InsertStmt, NodeEnum};
22

33
use super::{
44
super::{Context, Error, RewriteModule},
5-
bigint_const, bigint_param,
5+
bigint_const, bigint_param, max_param_number,
66
};
77
use crate::{
88
frontend::router::parser::{Insert, Value},
@@ -34,8 +34,8 @@ impl InsertUniqueIdRewrite {
3434
stmt: &mut InsertStmt,
3535
bind: &mut Option<crate::net::Bind>,
3636
extended: bool,
37+
param_counter: &mut i32,
3738
) -> Result<(), Error> {
38-
let mut param_counter = Self::param_count(stmt)?;
3939
let select = stmt
4040
.select_stmt
4141
.as_mut()
@@ -53,15 +53,15 @@ impl InsertUniqueIdRewrite {
5353
let id = unique_id::UniqueId::generator()?.next_id();
5454

5555
let node = if extended {
56-
param_counter += 1;
56+
*param_counter += 1;
5757
if let Some(ref mut bind) = bind {
5858
let count = bind.add_parameter(Datum::Bigint(id))?;
5959
// The number of parameters in the query doesn't match what's in the bind message.
60-
if count != param_counter {
60+
if count != *param_counter {
6161
return Err(Error::ParameterCountMismatch);
6262
}
6363
}
64-
bigint_param(param_counter)
64+
bigint_param(*param_counter)
6565
} else {
6666
bigint_const(id)
6767
};
@@ -76,34 +76,6 @@ impl InsertUniqueIdRewrite {
7676

7777
Ok(())
7878
}
79-
80-
fn param_count(stmt: &InsertStmt) -> Result<i32, Error> {
81-
let mut max = 0;
82-
83-
let select = stmt
84-
.select_stmt
85-
.as_ref()
86-
.ok_or(Error::ParserError)?
87-
.node
88-
.as_ref()
89-
.ok_or(Error::ParserError)?;
90-
91-
if let NodeEnum::SelectStmt(stmt) = select {
92-
for tuple in stmt.values_lists.iter() {
93-
if let Some(NodeEnum::List(ref tuple)) = tuple.node {
94-
for column in tuple.items.iter() {
95-
if let Some(NodeEnum::ParamRef(ref param)) = column.node {
96-
if param.number > max {
97-
max = param.number;
98-
}
99-
}
100-
}
101-
}
102-
}
103-
}
104-
105-
Ok(max)
106-
}
10779
}
10880

10981
impl RewriteModule for InsertUniqueIdRewrite {
@@ -125,14 +97,15 @@ impl RewriteModule for InsertUniqueIdRewrite {
12597

12698
let mut bind = input.bind_take();
12799
let extended = input.extended();
100+
let mut param_counter = max_param_number(input.parse_result());
128101

129102
if let Some(NodeEnum::InsertStmt(stmt)) = input
130103
.stmt_mut()?
131104
.stmt
132105
.as_mut()
133106
.and_then(|stmt| stmt.node.as_mut())
134107
{
135-
Self::rewrite_insert(stmt, &mut bind, extended)?;
108+
Self::rewrite_insert(stmt, &mut bind, extended, &mut param_counter)?;
136109
}
137110

138111
input.bind_put(bind);

pgdog/src/frontend/router/rewrite/unique_id/mod.rs

Lines changed: 128 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use pg_query::{
44
protobuf::{a_const::Val, AConst, Node, ParamRef, ParseResult, TypeCast, TypeName},
5-
NodeEnum, NodeRef,
5+
NodeEnum,
66
};
77

88
pub mod explain;
@@ -77,16 +77,131 @@ fn bigint_const(id: i64) -> NodeEnum {
7777

7878
/// Find the maximum parameter number ($N) in a parse result.
7979
pub fn max_param_number(result: &ParseResult) -> i32 {
80-
result
81-
.nodes()
82-
.iter()
83-
.filter_map(|(node, _, _, _)| {
84-
if let NodeRef::ParamRef(p) = node {
85-
Some(p.number)
86-
} else {
87-
None
88-
}
89-
})
90-
.max()
91-
.unwrap_or(0)
80+
let mut max = 0;
81+
for stmt in &result.stmts {
82+
if let Some(ref stmt) = stmt.stmt {
83+
find_max_param(&stmt.node, &mut max);
84+
}
85+
}
86+
max
87+
}
88+
89+
fn find_max_param(node: &Option<NodeEnum>, max: &mut i32) {
90+
let Some(node) = node else {
91+
return;
92+
};
93+
94+
match node {
95+
NodeEnum::ParamRef(param) => {
96+
if param.number > *max {
97+
*max = param.number;
98+
}
99+
}
100+
NodeEnum::TypeCast(cast) => {
101+
if let Some(ref arg) = cast.arg {
102+
find_max_param(&arg.node, max);
103+
}
104+
}
105+
NodeEnum::FuncCall(func) => {
106+
for arg in &func.args {
107+
find_max_param(&arg.node, max);
108+
}
109+
}
110+
NodeEnum::AExpr(expr) => {
111+
if let Some(ref lexpr) = expr.lexpr {
112+
find_max_param(&lexpr.node, max);
113+
}
114+
if let Some(ref rexpr) = expr.rexpr {
115+
find_max_param(&rexpr.node, max);
116+
}
117+
}
118+
NodeEnum::SelectStmt(stmt) => {
119+
for item in &stmt.target_list {
120+
find_max_param(&item.node, max);
121+
}
122+
for item in &stmt.values_lists {
123+
find_max_param(&item.node, max);
124+
}
125+
for item in &stmt.from_clause {
126+
find_max_param(&item.node, max);
127+
}
128+
if let Some(ref clause) = stmt.where_clause {
129+
find_max_param(&clause.node, max);
130+
}
131+
if let Some(ref limit) = stmt.limit_count {
132+
find_max_param(&limit.node, max);
133+
}
134+
if let Some(ref offset) = stmt.limit_offset {
135+
find_max_param(&offset.node, max);
136+
}
137+
}
138+
NodeEnum::InsertStmt(stmt) => {
139+
if let Some(ref select) = stmt.select_stmt {
140+
find_max_param(&select.node, max);
141+
}
142+
}
143+
NodeEnum::UpdateStmt(stmt) => {
144+
for item in &stmt.target_list {
145+
find_max_param(&item.node, max);
146+
}
147+
if let Some(ref clause) = stmt.where_clause {
148+
find_max_param(&clause.node, max);
149+
}
150+
}
151+
NodeEnum::DeleteStmt(stmt) => {
152+
if let Some(ref clause) = stmt.where_clause {
153+
find_max_param(&clause.node, max);
154+
}
155+
}
156+
NodeEnum::ResTarget(res) => {
157+
if let Some(ref val) = res.val {
158+
find_max_param(&val.node, max);
159+
}
160+
}
161+
NodeEnum::List(list) => {
162+
for item in &list.items {
163+
find_max_param(&item.node, max);
164+
}
165+
}
166+
NodeEnum::CoalesceExpr(coalesce) => {
167+
for arg in &coalesce.args {
168+
find_max_param(&arg.node, max);
169+
}
170+
}
171+
NodeEnum::CaseExpr(case) => {
172+
if let Some(ref arg) = case.arg {
173+
find_max_param(&arg.node, max);
174+
}
175+
for when in &case.args {
176+
find_max_param(&when.node, max);
177+
}
178+
if let Some(ref defresult) = case.defresult {
179+
find_max_param(&defresult.node, max);
180+
}
181+
}
182+
NodeEnum::CaseWhen(when) => {
183+
if let Some(ref expr) = when.expr {
184+
find_max_param(&expr.node, max);
185+
}
186+
if let Some(ref result) = when.result {
187+
find_max_param(&result.node, max);
188+
}
189+
}
190+
NodeEnum::BoolExpr(expr) => {
191+
for arg in &expr.args {
192+
find_max_param(&arg.node, max);
193+
}
194+
}
195+
NodeEnum::NullTest(test) => {
196+
if let Some(ref arg) = test.arg {
197+
find_max_param(&arg.node, max);
198+
}
199+
}
200+
NodeEnum::ExplainStmt(stmt) => {
201+
if let Some(ref query) = stmt.query {
202+
find_max_param(&query.node, max);
203+
}
204+
}
205+
_ => {}
206+
}
92207
}

0 commit comments

Comments
 (0)