Skip to content

Commit 0f5afd3

Browse files
committed
support rewriting numbered params to positional params
1 parent f44fe82 commit 0f5afd3

File tree

3 files changed

+62
-19
lines changed

3 files changed

+62
-19
lines changed

internal/dinosql/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ type PackageSettings struct {
5959
EmitJSONTags bool `json:"emit_json_tags"`
6060
EmitPreparedQueries bool `json:"emit_prepared_queries"`
6161
Overrides []Override `json:"overrides"`
62+
// HACK: this is only set in tests, only here till Kotlin support can be merged.
63+
rewriteParams bool
6264
}
6365

6466
type Override struct {

internal/dinosql/parser.go

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,8 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) {
229229
continue
230230
}
231231
for _, stmt := range tree.Statements {
232-
query, err := parseQuery(c, stmt, source)
232+
rewriteParameters := pkg.rewriteParams
233+
query, err := parseQuery(c, stmt, source, rewriteParameters)
233234
if err == errUnsupportedStatementType {
234235
continue
235236
}
@@ -407,7 +408,7 @@ func validateCmd(n nodes.Node, name, cmd string) error {
407408

408409
var errUnsupportedStatementType = errors.New("parseQuery: unsupported statement type")
409410

410-
func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) {
411+
func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameters bool) (*Query, error) {
411412
if err := validateParamRef(stmt); err != nil {
412413
return nil, err
413414
}
@@ -443,6 +444,16 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
443444
}
444445
rvs := rangeVars(raw.Stmt)
445446
refs := findParameters(raw.Stmt)
447+
var edits []edit
448+
if rewriteParameters {
449+
edits, err = rewriteNumberedParameters(refs, raw, rawSQL)
450+
if err != nil {
451+
return nil, err
452+
}
453+
} else {
454+
refs = uniqueParamRefs(refs)
455+
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
456+
}
446457
params, err := resolveCatalogRefs(c, rvs, refs)
447458
if err != nil {
448459
return nil, err
@@ -452,7 +463,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
452463
if err != nil {
453464
return nil, err
454465
}
455-
expanded, err := expand(c, raw, rawSQL)
466+
expandEdits, err := expand(c, raw, rawSQL)
467+
if err != nil {
468+
return nil, err
469+
}
470+
edits = append(edits, expandEdits...)
471+
472+
expanded, err := editQuery(rawSQL, edits)
456473
if err != nil {
457474
return nil, err
458475
}
@@ -472,6 +489,18 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
472489
}, nil
473490
}
474491

492+
func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) ([]edit, error) {
493+
edits := make([]edit, len(refs))
494+
for i, ref := range refs {
495+
edits[i] = edit{
496+
Location: ref.ref.Location - raw.StmtLocation,
497+
Old: fmt.Sprintf("$%d", ref.ref.Number),
498+
New: "?",
499+
}
500+
}
501+
return edits, nil
502+
}
503+
475504
func stripComments(sql string) (string, []string, error) {
476505
s := bufio.NewScanner(strings.NewReader(sql))
477506
var lines, comments []string
@@ -494,7 +523,7 @@ type edit struct {
494523
New string
495524
}
496525

497-
func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
526+
func expand(c core.Catalog, raw nodes.RawStmt, sql string) ([]edit, error) {
498527
list := search(raw, func(node nodes.Node) bool {
499528
switch node.(type) {
500529
case nodes.DeleteStmt:
@@ -507,17 +536,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
507536
return true
508537
})
509538
if len(list.Items) == 0 {
510-
return sql, nil
539+
return nil, nil
511540
}
512541
var edits []edit
513542
for _, item := range list.Items {
514543
edit, err := expandStmt(c, raw, item)
515544
if err != nil {
516-
return "", err
545+
return nil, err
517546
}
518547
edits = append(edits, edit...)
519548
}
520-
return editQuery(sql, edits)
549+
return edits, nil
521550
}
522551

523552
func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) {
@@ -958,7 +987,8 @@ type paramRef struct {
958987
type paramSearch struct {
959988
parent nodes.Node
960989
rangeVar *nodes.RangeVar
961-
refs map[int]paramRef
990+
refs *[]paramRef
991+
seen map[int]struct{}
962992

963993
// XXX: Gross state hack for limit
964994
limitCount nodes.Node
@@ -1005,7 +1035,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
10051035
continue
10061036
}
10071037
// TODO: Out-of-bounds panic
1008-
p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}
1038+
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar})
1039+
p.seen[ref.Location] = struct{}{}
10091040
}
10101041
for _, vl := range s.ValuesLists {
10111042
for i, v := range vl {
@@ -1014,7 +1045,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
10141045
continue
10151046
}
10161047
// TODO: Out-of-bounds panic
1017-
p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}
1048+
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar})
1049+
p.seen[ref.Location] = struct{}{}
10181050
}
10191051
}
10201052
}
@@ -1050,7 +1082,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
10501082
parent = limitOffset{}
10511083
}
10521084
}
1053-
if _, found := p.refs[n.Number]; found {
1085+
if _, found := p.seen[n.Location]; found {
10541086
break
10551087
}
10561088

@@ -1072,21 +1104,18 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
10721104
}
10731105

10741106
if set {
1075-
p.refs[n.Number] = paramRef{parent: parent, ref: n, rv: p.rangeVar}
1107+
*p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar})
1108+
p.seen[n.Location] = struct{}{}
10761109
}
10771110
return nil
10781111
}
10791112
return p
10801113
}
10811114

10821115
func findParameters(root nodes.Node) []paramRef {
1083-
v := paramSearch{refs: map[int]paramRef{}}
1084-
Walk(v, root)
10851116
refs := make([]paramRef, 0)
1086-
for _, r := range v.refs {
1087-
refs = append(refs, r)
1088-
}
1089-
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
1117+
v := paramSearch{seen: make(map[int]struct{}), refs: &refs}
1118+
Walk(v, root)
10901119
return refs
10911120
}
10921121

@@ -1348,3 +1377,15 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13481377
}
13491378
return a, nil
13501379
}
1380+
1381+
func uniqueParamRefs(in []paramRef) []paramRef {
1382+
m := make(map[int]struct{}, len(in))
1383+
o := make([]paramRef, 0, len(in))
1384+
for _, v := range in {
1385+
if _, ok := m[v.ref.Number]; !ok {
1386+
m[v.ref.Number] = struct{}{}
1387+
o = append(o, v)
1388+
}
1389+
}
1390+
return o
1391+
}

internal/dinosql/query_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func parseSQL(in string) (Query, error) {
2121
return Query{}, err
2222
}
2323

24-
q, err := parseQuery(c, tree.Statements[len(tree.Statements)-1], in)
24+
q, err := parseQuery(c, tree.Statements[len(tree.Statements)-1], in, false)
2525
if q == nil {
2626
return Query{}, err
2727
}

0 commit comments

Comments
 (0)