diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index e81348a596..19deb8bddc 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -64,6 +64,20 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { targets = n.ReturningList case *ast.SelectStmt: targets = n.TargetList + + if n.GroupClause != nil { + for _, item := range n.GroupClause.Items { + ref, ok := item.(*ast.ColumnRef) + if !ok { + continue + } + + if err := findColumnForRef(ref, tables); err != nil { + return nil, err + } + } + } + // For UNION queries, targets is empty and we need to look for the // columns in Largs. if len(targets.Items) == 0 && n.Larg != nil { @@ -470,3 +484,43 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef) } return cols, nil } + +func findColumnForRef(ref *ast.ColumnRef, tables []*Table) error { + parts := stringSlice(ref.Fields) + var alias, name string + if len(parts) == 1 { + name = parts[0] + } else if len(parts) == 2 { + alias = parts[0] + name = parts[1] + } + + var found int + for _, t := range tables { + if alias != "" && t.Rel.Name != alias { + continue + } + for _, c := range t.Columns { + if c.Name == name { + found++ + } + } + } + + if found == 0 { + return &sqlerr.Error{ + Code: "42703", + Message: fmt.Sprintf("column reference \"%s\" not found", name), + Location: ref.Location, + } + } + if found > 1 { + return &sqlerr.Error{ + Code: "42703", + Message: fmt.Sprintf("column reference \"%s\" is ambiguous", name), + Location: ref.Location, + } + } + + return nil +} diff --git a/internal/endtoend/testdata/invalid_group_by_reference/mysql/query.sql b/internal/endtoend/testdata/invalid_group_by_reference/mysql/query.sql new file mode 100644 index 0000000000..75a8e210fd --- /dev/null +++ b/internal/endtoend/testdata/invalid_group_by_reference/mysql/query.sql @@ -0,0 +1,11 @@ +CREATE TABLE authors ( + id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name text NOT NULL, + bio text, + UNIQUE(name) +); + +-- name: ListAuthors :many +SELECT * +FROM authors +GROUP BY invalid_reference; diff --git a/internal/endtoend/testdata/invalid_group_by_reference/mysql/sqlc.json b/internal/endtoend/testdata/invalid_group_by_reference/mysql/sqlc.json new file mode 100644 index 0000000000..534b7e24e9 --- /dev/null +++ b/internal/endtoend/testdata/invalid_group_by_reference/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "mysql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/invalid_group_by_reference/mysql/stderr.txt b/internal/endtoend/testdata/invalid_group_by_reference/mysql/stderr.txt new file mode 100644 index 0000000000..18a7d6bd0f --- /dev/null +++ b/internal/endtoend/testdata/invalid_group_by_reference/mysql/stderr.txt @@ -0,0 +1,2 @@ +# package querytest +query.sql:9:1: column reference "invalid_reference" not found diff --git a/internal/endtoend/testdata/invalid_group_by_reference/postgresql/query.sql b/internal/endtoend/testdata/invalid_group_by_reference/postgresql/query.sql new file mode 100644 index 0000000000..fdfa7e4e05 --- /dev/null +++ b/internal/endtoend/testdata/invalid_group_by_reference/postgresql/query.sql @@ -0,0 +1,10 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + bio text +); + +-- name: ListAuthors :many +SELECT * +FROM authors +GROUP BY invalid_reference; diff --git a/internal/endtoend/testdata/invalid_group_by_reference/postgresql/sqlc.json b/internal/endtoend/testdata/invalid_group_by_reference/postgresql/sqlc.json new file mode 100644 index 0000000000..af57681f66 --- /dev/null +++ b/internal/endtoend/testdata/invalid_group_by_reference/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/invalid_group_by_reference/postgresql/stderr.txt b/internal/endtoend/testdata/invalid_group_by_reference/postgresql/stderr.txt new file mode 100644 index 0000000000..e9dd12bc6f --- /dev/null +++ b/internal/endtoend/testdata/invalid_group_by_reference/postgresql/stderr.txt @@ -0,0 +1,2 @@ +# package querytest +query.sql:10:10: column reference "invalid_reference" not found diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 8e1247c048..053b1f549c 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -474,6 +474,7 @@ func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt { stmt := &ast.SelectStmt{ TargetList: c.convertFieldList(n.Fields), FromClause: c.convertTableRefsClause(n.From), + GroupClause: c.convertGroupByClause(n.GroupBy), WhereClause: c.convert(n.Where), WithClause: c.convertWithClause(n.With), WindowClause: windowClause, @@ -677,7 +678,14 @@ func (c *cc) convertBinlogStmt(n *pcast.BinlogStmt) ast.Node { } func (c *cc) convertByItem(n *pcast.ByItem) ast.Node { - return todo(n) + switch n.Expr.(type) { + case *pcast.PositionExpr: + return c.convertPositionExpr(n.Expr.(*pcast.PositionExpr)) + case *pcast.ColumnNameExpr: + return c.convertColumnNameExpr(n.Expr.(*pcast.ColumnNameExpr)) + default: + return todo(n) + } } func (c *cc) convertCaseExpr(n *pcast.CaseExpr) ast.Node { @@ -858,8 +866,19 @@ func (c *cc) convertGrantStmt(n *pcast.GrantStmt) ast.Node { return todo(n) } -func (c *cc) convertGroupByClause(n *pcast.GroupByClause) ast.Node { - return todo(n) +func (c *cc) convertGroupByClause(n *pcast.GroupByClause) *ast.List { + if n == nil { + return &ast.List{} + } + + var items []ast.Node + for _, item := range n.Items { + items = append(items, c.convertByItem(item)) + } + + return &ast.List{ + Items: items, + } } func (c *cc) convertHavingClause(n *pcast.HavingClause) ast.Node {