@@ -86,6 +86,7 @@ public JoinInfo dealJoinNode(SqlJoin joinNode,
8686 Queue <Object > queueInfo ,
8787 SqlNode parentWhere ,
8888 SqlNodeList parentSelectList ,
89+ SqlNodeList parentGroupByList ,
8990 Set <Tuple2 <String , String >> joinFieldSet ,
9091 Map <String , String > tableRef ,
9192 Map <String , String > fieldRef ) {
@@ -105,20 +106,20 @@ public JoinInfo dealJoinNode(SqlJoin joinNode,
105106 if (leftNode .getKind () == JOIN ) {
106107 //处理连续join
107108 dealNestJoin (joinNode , sideTableSet ,
108- queueInfo , parentWhere , parentSelectList , joinFieldSet , tableRef , fieldRef , parentSelectList );
109+ queueInfo , parentWhere , parentSelectList , parentGroupByList , joinFieldSet , tableRef , fieldRef );
109110 leftNode = joinNode .getLeft ();
110111 }
111112
112113 if (leftNode .getKind () == AS ) {
113- AliasInfo aliasInfo = (AliasInfo ) sideSQLParser .parseSql (leftNode , sideTableSet , queueInfo , parentWhere , parentSelectList );
114+ AliasInfo aliasInfo = (AliasInfo ) sideSQLParser .parseSql (leftNode , sideTableSet , queueInfo , parentWhere , parentSelectList , parentGroupByList );
114115 leftTbName = aliasInfo .getName ();
115116 leftTbAlias = aliasInfo .getAlias ();
116117 }
117118
118119 boolean leftIsSide = checkIsSideTable (leftTbName , sideTableSet );
119120 Preconditions .checkState (!leftIsSide , "side-table must be at the right of join operator" );
120121
121- Tuple2 <String , String > rightTableNameAndAlias = parseRightNode (rightNode , sideTableSet , queueInfo , parentWhere , parentSelectList );
122+ Tuple2 <String , String > rightTableNameAndAlias = parseRightNode (rightNode , sideTableSet , queueInfo , parentWhere , parentSelectList , parentGroupByList );
122123 rightTableName = rightTableNameAndAlias .f0 ;
123124 rightTableAlias = rightTableNameAndAlias .f1 ;
124125
@@ -145,7 +146,7 @@ public JoinInfo dealJoinNode(SqlJoin joinNode,
145146
146147 //extract 需要查询的字段信息
147148 if (rightIsSide ){
148- extractJoinNeedSelectField (leftNode , rightNode , parentWhere , parentSelectList , tableRef , joinFieldSet , fieldRef , tableInfo );
149+ extractJoinNeedSelectField (leftNode , rightNode , parentWhere , parentSelectList , parentGroupByList , tableRef , joinFieldSet , fieldRef , tableInfo );
149150 }
150151
151152 if (tableInfo .getLeftNode ().getKind () != AS ){
@@ -168,13 +169,14 @@ public void extractJoinNeedSelectField(SqlNode leftNode,
168169 SqlNode rightNode ,
169170 SqlNode parentWhere ,
170171 SqlNodeList parentSelectList ,
172+ SqlNodeList parentGroupByList ,
171173 Map <String , String > tableRef ,
172174 Set <Tuple2 <String , String >> joinFieldSet ,
173175 Map <String , String > fieldRef ,
174176 JoinInfo tableInfo ){
175177
176- Set <String > extractSelectField = extractField (leftNode , parentWhere , parentSelectList , tableRef , joinFieldSet );
177- Set <String > rightExtractSelectField = extractField (rightNode , parentWhere , parentSelectList , tableRef , joinFieldSet );
178+ Set <String > extractSelectField = extractField (leftNode , parentWhere , parentSelectList , parentGroupByList , tableRef , joinFieldSet );
179+ Set <String > rightExtractSelectField = extractField (rightNode , parentWhere , parentSelectList , parentGroupByList , tableRef , joinFieldSet );
178180
179181 //重命名right 中和 left 重名的
180182 Map <String , String > leftTbSelectField = Maps .newHashMap ();
@@ -208,13 +210,15 @@ public void extractJoinNeedSelectField(SqlNode leftNode,
208210 * @param sqlNode
209211 * @param parentWhere
210212 * @param parentSelectList
213+ * @param parentGroupByList
211214 * @param tableRef
212215 * @param joinFieldSet
213216 * @return
214217 */
215218 public Set <String > extractField (SqlNode sqlNode ,
216219 SqlNode parentWhere ,
217220 SqlNodeList parentSelectList ,
221+ SqlNodeList parentGroupByList ,
218222 Map <String , String > tableRef ,
219223 Set <Tuple2 <String , String >> joinFieldSet ){
220224 Set <String > fromTableNameSet = Sets .newHashSet ();
@@ -225,8 +229,11 @@ public Set<String> extractField(SqlNode sqlNode,
225229 Set <String > extractSelectField = extractSelectFields (parentSelectList , fromTableNameSet , tableRef );
226230 Set <String > fieldFromJoinCondition = extractSelectFieldFromJoinCondition (joinFieldSet , fromTableNameSet , tableRef );
227231
232+ Set <String > extractGroupByField = extractFieldFromGroupByList (parentGroupByList , fromTableNameSet , tableRef );
233+
228234 extractSelectField .addAll (extractCondition );
229235 extractSelectField .addAll (fieldFromJoinCondition );
236+ extractSelectField .addAll (extractGroupByField );
230237
231238 return extractSelectField ;
232239 }
@@ -242,27 +249,27 @@ private JoinInfo dealNestJoin(SqlJoin joinNode,
242249 Set <String > sideTableSet ,
243250 Queue <Object > queueInfo ,
244251 SqlNode parentWhere ,
245- SqlNodeList selectList ,
252+ SqlNodeList parentSelectList ,
253+ SqlNodeList parentGroupByList ,
246254 Set <Tuple2 <String , String >> joinFieldSet ,
247255 Map <String , String > tableRef ,
248- Map <String , String > fieldRef ,
249- SqlNodeList parentSelectList ){
256+ Map <String , String > fieldRef ){
250257
251258 SqlJoin leftJoinNode = (SqlJoin ) joinNode .getLeft ();
252259 SqlNode parentRightJoinNode = joinNode .getRight ();
253260 SqlNode rightNode = leftJoinNode .getRight ();
254- Tuple2 <String , String > rightTableNameAndAlias = parseRightNode (rightNode , sideTableSet , queueInfo , parentWhere , selectList );
255- Tuple2 <String , String > parentRightJoinInfo = parseRightNode (parentRightJoinNode , sideTableSet , queueInfo , parentWhere , selectList );
261+ Tuple2 <String , String > rightTableNameAndAlias = parseRightNode (rightNode , sideTableSet , queueInfo , parentWhere , parentSelectList , parentGroupByList );
262+ Tuple2 <String , String > parentRightJoinInfo = parseRightNode (parentRightJoinNode , sideTableSet , queueInfo , parentWhere , parentSelectList , parentGroupByList );
256263 boolean parentRightIsSide = checkIsSideTable (parentRightJoinInfo .f0 , sideTableSet );
257264
258- JoinInfo joinInfo = dealJoinNode (leftJoinNode , sideTableSet , queueInfo , parentWhere , selectList , joinFieldSet , tableRef , fieldRef );
265+ JoinInfo joinInfo = dealJoinNode (leftJoinNode , sideTableSet , queueInfo , parentWhere , parentSelectList , parentGroupByList , joinFieldSet , tableRef , fieldRef );
259266
260267 String rightTableName = rightTableNameAndAlias .f0 ;
261268 boolean rightIsSide = checkIsSideTable (rightTableName , sideTableSet );
262269 SqlBasicCall buildAs = TableUtils .buildAsNodeByJoinInfo (joinInfo , null , null );
263270
264271 if (rightIsSide ){
265- addSideInfoToExeQueue (queueInfo , joinInfo , joinNode , parentSelectList , parentWhere , tableRef );
272+ addSideInfoToExeQueue (queueInfo , joinInfo , joinNode , parentSelectList , parentGroupByList , parentWhere , tableRef );
266273 }
267274
268275 SqlNode newLeftNode = joinNode .getLeft ();
@@ -275,7 +282,7 @@ private JoinInfo dealNestJoin(SqlJoin joinNode,
275282
276283 //替换leftNode 为新的查询
277284 joinNode .setLeft (buildAs );
278- replaceSelectAndWhereField (buildAs , leftJoinNode , tableRef , parentSelectList , parentWhere );
285+ replaceSelectAndWhereField (buildAs , leftJoinNode , tableRef , parentSelectList , parentGroupByList , parentWhere );
279286 }
280287
281288 return joinInfo ;
@@ -288,13 +295,15 @@ private JoinInfo dealNestJoin(SqlJoin joinNode,
288295 * @param joinInfo
289296 * @param joinNode
290297 * @param parentSelectList
298+ * @param parentGroupByList
291299 * @param parentWhere
292300 * @param tableRef
293301 */
294302 public void addSideInfoToExeQueue (Queue <Object > queueInfo ,
295303 JoinInfo joinInfo ,
296304 SqlJoin joinNode ,
297305 SqlNodeList parentSelectList ,
306+ SqlNodeList parentGroupByList ,
298307 SqlNode parentWhere ,
299308 Map <String , String > tableRef ){
300309 //只处理维表
@@ -308,7 +317,7 @@ public void addSideInfoToExeQueue(Queue<Object> queueInfo,
308317 //替换左表为新的表名称
309318 joinNode .setLeft (buildAs );
310319
311- replaceSelectAndWhereField (buildAs , leftJoinNode , tableRef , parentSelectList , parentWhere );
320+ replaceSelectAndWhereField (buildAs , leftJoinNode , tableRef , parentSelectList , parentGroupByList , parentWhere );
312321 }
313322
314323 /**
@@ -317,12 +326,14 @@ public void addSideInfoToExeQueue(Queue<Object> queueInfo,
317326 * @param leftJoinNode
318327 * @param tableRef
319328 * @param parentSelectList
329+ * @param parentGroupByList
320330 * @param parentWhere
321331 */
322332 public void replaceSelectAndWhereField (SqlBasicCall buildAs ,
323333 SqlNode leftJoinNode ,
324334 Map <String , String > tableRef ,
325335 SqlNodeList parentSelectList ,
336+ SqlNodeList parentGroupByList ,
326337 SqlNode parentWhere ){
327338
328339 String newLeftTableName = buildAs .getOperands ()[1 ].toString ();
@@ -341,10 +352,20 @@ public void replaceSelectAndWhereField(SqlBasicCall buildAs,
341352 }
342353 }
343354
355+ //TODO 应该根据上面的查询字段的关联关系来替换
344356 //替换where 中的条件相关
345357 for (String tbTmp : fromTableNameSet ){
346- TableUtils .replaceWhereCondition (parentWhere , tbTmp , newLeftTableName );
358+ TableUtils .replaceWhereCondition (parentWhere , tbTmp , newLeftTableName , fieldReplaceRef );
347359 }
360+
361+ if (parentGroupByList != null ){
362+ for (SqlNode sqlNode : parentGroupByList .getList ()){
363+ for (String tbTmp : fromTableNameSet ) {
364+ TableUtils .replaceSelectFieldTable (sqlNode , tbTmp , newLeftTableName , fieldReplaceRef );
365+ }
366+ }
367+ }
368+
348369 }
349370
350371 /**
@@ -407,7 +428,7 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias,
407428
408429 //替换where 中的条件相关
409430 for (String tbTmp : fromTableNameSet ){
410- TableUtils .replaceWhereCondition (parentWhere , tbTmp , tableAlias );
431+ TableUtils .replaceWhereCondition (parentWhere , tbTmp , tableAlias , fieldReplaceRef );
411432 }
412433
413434 for (String tbTmp : fromTableNameSet ){
@@ -426,7 +447,6 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias,
426447
427448 /**
428449 * 抽取上层需用使用到的字段
429- * 由于where字段已经抽取到上一层了所以不用查询出来
430450 * @param parentSelectList
431451 * @param fromTableNameSet
432452 * @return
@@ -451,7 +471,6 @@ private Set<String> extractSelectFieldFromJoinCondition(Set<Tuple2<String, Strin
451471 extractFieldList .add (field .f0 + "." + field .f1 );
452472 }
453473
454- //TODO
455474 if (tableRef .containsKey (field .f0 )){
456475 if (fromTableNameSet .contains (tableRef .get (field .f0 ))){
457476 extractFieldList .add (tableRef .get (field .f0 ) + "." + field .f1 );
@@ -462,6 +481,22 @@ private Set<String> extractSelectFieldFromJoinCondition(Set<Tuple2<String, Strin
462481 return extractFieldList ;
463482 }
464483
484+ private Set <String > extractFieldFromGroupByList (SqlNodeList parentGroupByList ,
485+ Set <String > fromTableNameSet ,
486+ Map <String , String > tableRef ){
487+
488+ if (parentGroupByList == null ){
489+ return Sets .newHashSet ();
490+ }
491+
492+ Set <String > extractFieldList = Sets .newHashSet ();
493+ for (SqlNode selectNode : parentGroupByList .getList ()){
494+ extractSelectField (selectNode , extractFieldList , fromTableNameSet , tableRef );
495+ }
496+
497+ return extractFieldList ;
498+ }
499+
465500 /**
466501 * 从join的条件中获取字段信息
467502 * @param condition
@@ -573,12 +608,12 @@ private void extractSelectField(SqlNode selectNode,
573608
574609
575610 private Tuple2 <String , String > parseRightNode (SqlNode sqlNode , Set <String > sideTableSet , Queue <Object > queueInfo ,
576- SqlNode parentWhere , SqlNodeList selectList ) {
611+ SqlNode parentWhere , SqlNodeList selectList , SqlNodeList parentGroupByList ) {
577612 Tuple2 <String , String > tabName = new Tuple2 <>("" , "" );
578613 if (sqlNode .getKind () == IDENTIFIER ){
579614 tabName .f0 = sqlNode .toString ();
580615 }else {
581- AliasInfo aliasInfo = (AliasInfo )sideSQLParser .parseSql (sqlNode , sideTableSet , queueInfo , parentWhere , selectList );
616+ AliasInfo aliasInfo = (AliasInfo )sideSQLParser .parseSql (sqlNode , sideTableSet , queueInfo , parentWhere , selectList , parentGroupByList );
582617 tabName .f0 = aliasInfo .getName ();
583618 tabName .f1 = aliasInfo .getAlias ();
584619 }
0 commit comments