@@ -229,7 +229,8 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) {
229
229
continue
230
230
}
231
231
for _ , stmt := range tree .Statements {
232
- query , err := parseQuery (c , stmt , source )
232
+ rewriteParameters := pkg .rewriteParams
233
+ query , err := parseQuery (c , stmt , source , rewriteParameters )
233
234
if err == errUnsupportedStatementType {
234
235
continue
235
236
}
@@ -407,7 +408,7 @@ func validateCmd(n nodes.Node, name, cmd string) error {
407
408
408
409
var errUnsupportedStatementType = errors .New ("parseQuery: unsupported statement type" )
409
410
410
- func parseQuery (c core.Catalog , stmt nodes.Node , source string ) (* Query , error ) {
411
+ func parseQuery (c core.Catalog , stmt nodes.Node , source string , rewriteParameters bool ) (* Query , error ) {
411
412
if err := validateParamRef (stmt ); err != nil {
412
413
return nil , err
413
414
}
@@ -443,6 +444,16 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
443
444
}
444
445
rvs := rangeVars (raw .Stmt )
445
446
refs := findParameters (raw .Stmt )
447
+ var edits []edit
448
+ if rewriteParameters {
449
+ edits , err = rewriteNumberedParameters (refs , raw , rawSQL )
450
+ if err != nil {
451
+ return nil , err
452
+ }
453
+ } else {
454
+ refs = uniqueParamRefs (refs )
455
+ sort .Slice (refs , func (i , j int ) bool { return refs [i ].ref .Number < refs [j ].ref .Number })
456
+ }
446
457
params , err := resolveCatalogRefs (c , rvs , refs )
447
458
if err != nil {
448
459
return nil , err
@@ -452,7 +463,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
452
463
if err != nil {
453
464
return nil , err
454
465
}
455
- expanded , err := expand (c , raw , rawSQL )
466
+ expandEdits , err := expand (c , raw , rawSQL )
467
+ if err != nil {
468
+ return nil , err
469
+ }
470
+ edits = append (edits , expandEdits ... )
471
+
472
+ expanded , err := editQuery (rawSQL , edits )
456
473
if err != nil {
457
474
return nil , err
458
475
}
@@ -472,6 +489,18 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
472
489
}, nil
473
490
}
474
491
492
+ func rewriteNumberedParameters (refs []paramRef , raw nodes.RawStmt , sql string ) ([]edit , error ) {
493
+ edits := make ([]edit , len (refs ))
494
+ for i , ref := range refs {
495
+ edits [i ] = edit {
496
+ Location : ref .ref .Location - raw .StmtLocation ,
497
+ Old : fmt .Sprintf ("$%d" , ref .ref .Number ),
498
+ New : "?" ,
499
+ }
500
+ }
501
+ return edits , nil
502
+ }
503
+
475
504
func stripComments (sql string ) (string , []string , error ) {
476
505
s := bufio .NewScanner (strings .NewReader (sql ))
477
506
var lines , comments []string
@@ -494,7 +523,7 @@ type edit struct {
494
523
New string
495
524
}
496
525
497
- func expand (c core.Catalog , raw nodes.RawStmt , sql string ) (string , error ) {
526
+ func expand (c core.Catalog , raw nodes.RawStmt , sql string ) ([] edit , error ) {
498
527
list := search (raw , func (node nodes.Node ) bool {
499
528
switch node .(type ) {
500
529
case nodes.DeleteStmt :
@@ -507,17 +536,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
507
536
return true
508
537
})
509
538
if len (list .Items ) == 0 {
510
- return sql , nil
539
+ return nil , nil
511
540
}
512
541
var edits []edit
513
542
for _ , item := range list .Items {
514
543
edit , err := expandStmt (c , raw , item )
515
544
if err != nil {
516
- return "" , err
545
+ return nil , err
517
546
}
518
547
edits = append (edits , edit ... )
519
548
}
520
- return editQuery ( sql , edits )
549
+ return edits , nil
521
550
}
522
551
523
552
func expandStmt (c core.Catalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
@@ -958,7 +987,8 @@ type paramRef struct {
958
987
type paramSearch struct {
959
988
parent nodes.Node
960
989
rangeVar * nodes.RangeVar
961
- refs map [int ]paramRef
990
+ refs * []paramRef
991
+ seen map [int ]struct {}
962
992
963
993
// XXX: Gross state hack for limit
964
994
limitCount nodes.Node
@@ -1005,7 +1035,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
1005
1035
continue
1006
1036
}
1007
1037
// TODO: Out-of-bounds panic
1008
- p .refs [ref .Number ] = paramRef {parent : n .Cols .Items [i ], ref : ref , rv : p .rangeVar }
1038
+ * p .refs = append (* p .refs , paramRef {parent : n .Cols .Items [i ], ref : ref , rv : p .rangeVar })
1039
+ p .seen [ref .Location ] = struct {}{}
1009
1040
}
1010
1041
for _ , vl := range s .ValuesLists {
1011
1042
for i , v := range vl {
@@ -1014,7 +1045,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
1014
1045
continue
1015
1046
}
1016
1047
// TODO: Out-of-bounds panic
1017
- p .refs [ref .Number ] = paramRef {parent : n .Cols .Items [i ], ref : ref , rv : p .rangeVar }
1048
+ * p .refs = append (* p .refs , paramRef {parent : n .Cols .Items [i ], ref : ref , rv : p .rangeVar })
1049
+ p .seen [ref .Location ] = struct {}{}
1018
1050
}
1019
1051
}
1020
1052
}
@@ -1050,7 +1082,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
1050
1082
parent = limitOffset {}
1051
1083
}
1052
1084
}
1053
- if _ , found := p .refs [n .Number ]; found {
1085
+ if _ , found := p .seen [n .Location ]; found {
1054
1086
break
1055
1087
}
1056
1088
@@ -1072,21 +1104,18 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
1072
1104
}
1073
1105
1074
1106
if set {
1075
- p .refs [n .Number ] = paramRef {parent : parent , ref : n , rv : p .rangeVar }
1107
+ * p .refs = append (* p .refs , paramRef {parent : parent , ref : n , rv : p .rangeVar })
1108
+ p .seen [n .Location ] = struct {}{}
1076
1109
}
1077
1110
return nil
1078
1111
}
1079
1112
return p
1080
1113
}
1081
1114
1082
1115
func findParameters (root nodes.Node ) []paramRef {
1083
- v := paramSearch {refs : map [int ]paramRef {}}
1084
- Walk (v , root )
1085
1116
refs := make ([]paramRef , 0 )
1086
- for _ , r := range v .refs {
1087
- refs = append (refs , r )
1088
- }
1089
- sort .Slice (refs , func (i , j int ) bool { return refs [i ].ref .Number < refs [j ].ref .Number })
1117
+ v := paramSearch {seen : make (map [int ]struct {}), refs : & refs }
1118
+ Walk (v , root )
1090
1119
return refs
1091
1120
}
1092
1121
@@ -1348,3 +1377,15 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
1348
1377
}
1349
1378
return a , nil
1350
1379
}
1380
+
1381
+ func uniqueParamRefs (in []paramRef ) []paramRef {
1382
+ m := make (map [int ]struct {}, len (in ))
1383
+ o := make ([]paramRef , 0 , len (in ))
1384
+ for _ , v := range in {
1385
+ if _ , ok := m [v .ref .Number ]; ! ok {
1386
+ m [v .ref .Number ] = struct {}{}
1387
+ o = append (o , v )
1388
+ }
1389
+ }
1390
+ return o
1391
+ }
0 commit comments