diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 199dc8ecfd..112d684764 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -102,6 +102,24 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { } switch n := res.Val.(type) { + case *ast.A_Const: + name := "" + if res.Name != nil { + name = *res.Name + } + switch n.Val.(type) { + case *ast.String: + cols = append(cols, &Column{Name: name, DataType: "text", NotNull: true}) + case *ast.Integer: + cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true}) + case *ast.Float: + cols = append(cols, &Column{Name: name, DataType: "float", NotNull: true}) + case *ast.Boolean: + cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true}) + default: + cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false}) + } + case *ast.A_Expr: name := "" if res.Name != nil { diff --git a/internal/endtoend/testdata/diff_no_output/go/query.sql.go b/internal/endtoend/testdata/diff_no_output/go/query.sql.go index 7b9e9fe300..e1a78b741d 100644 --- a/internal/endtoend/testdata/diff_no_output/go/query.sql.go +++ b/internal/endtoend/testdata/diff_no_output/go/query.sql.go @@ -75,9 +75,9 @@ const selectOne = `-- name: SelectOne :one SELECT 1 ` -func (q *Queries) SelectOne(ctx context.Context) (interface{}, error) { +func (q *Queries) SelectOne(ctx context.Context) (int32, error) { row := q.db.QueryRowContext(ctx, selectOne) - var column_1 interface{} + var column_1 int32 err := row.Scan(&column_1) return column_1, err } diff --git a/internal/endtoend/testdata/diff_output/stderr.txt b/internal/endtoend/testdata/diff_output/stderr.txt index 52dcd38708..4db48d8d44 100644 --- a/internal/endtoend/testdata/diff_output/stderr.txt +++ b/internal/endtoend/testdata/diff_output/stderr.txt @@ -46,9 +46,9 @@ +SELECT 1 +` + -+func (q *Queries) SelectOne(ctx context.Context) (interface{}, error) { ++func (q *Queries) SelectOne(ctx context.Context) (int32, error) { + row := q.db.QueryRowContext(ctx, selectOne) -+ var column_1 interface{} ++ var column_1 int32 + err := row.Scan(&column_1) + return column_1, err +} diff --git a/internal/endtoend/testdata/selectstatic/mysql/go/db.go b/internal/endtoend/testdata/selectstatic/mysql/go/db.go new file mode 100644 index 0000000000..02974bda59 --- /dev/null +++ b/internal/endtoend/testdata/selectstatic/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/selectstatic/mysql/go/models.go b/internal/endtoend/testdata/selectstatic/mysql/go/models.go new file mode 100644 index 0000000000..259281c7b7 --- /dev/null +++ b/internal/endtoend/testdata/selectstatic/mysql/go/models.go @@ -0,0 +1,7 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package querytest + +import () diff --git a/internal/endtoend/testdata/selectstatic/mysql/go/query.sql.go b/internal/endtoend/testdata/selectstatic/mysql/go/query.sql.go new file mode 100644 index 0000000000..7225a40ea7 --- /dev/null +++ b/internal/endtoend/testdata/selectstatic/mysql/go/query.sql.go @@ -0,0 +1,35 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 +// source: query.sql + +package querytest + +import ( + "context" +) + +const selectStatic = `-- name: SelectStatic :one +SELECT 'a', 'b' AS b, 1 AS num, true AS truefield, 1.0 AS floater +` + +type SelectStaticRow struct { + Column1 string + B string + Num int32 + Truefield int32 + Floater float64 +} + +func (q *Queries) SelectStatic(ctx context.Context) (SelectStaticRow, error) { + row := q.db.QueryRowContext(ctx, selectStatic) + var i SelectStaticRow + err := row.Scan( + &i.Column1, + &i.B, + &i.Num, + &i.Truefield, + &i.Floater, + ) + return i, err +} diff --git a/internal/endtoend/testdata/selectstatic/mysql/query.sql b/internal/endtoend/testdata/selectstatic/mysql/query.sql new file mode 100644 index 0000000000..3a184df07c --- /dev/null +++ b/internal/endtoend/testdata/selectstatic/mysql/query.sql @@ -0,0 +1,2 @@ +-- name: SelectStatic :one +SELECT 'a', 'b' AS b, 1 AS num, true AS truefield, 1.0 AS floater diff --git a/internal/endtoend/testdata/selectstatic/mysql/sqlc.json b/internal/endtoend/testdata/selectstatic/mysql/sqlc.json new file mode 100644 index 0000000000..3d928ae137 --- /dev/null +++ b/internal/endtoend/testdata/selectstatic/mysql/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "sql_package": "database/sql", + "engine": "mysql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/show_warnings/mysql/go/db.go b/internal/endtoend/testdata/show_warnings/mysql/go/db.go new file mode 100644 index 0000000000..02974bda59 --- /dev/null +++ b/internal/endtoend/testdata/show_warnings/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/show_warnings/mysql/go/models.go b/internal/endtoend/testdata/show_warnings/mysql/go/models.go new file mode 100644 index 0000000000..259281c7b7 --- /dev/null +++ b/internal/endtoend/testdata/show_warnings/mysql/go/models.go @@ -0,0 +1,7 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package querytest + +import () diff --git a/internal/endtoend/testdata/show_warnings/mysql/go/query.sql.go b/internal/endtoend/testdata/show_warnings/mysql/go/query.sql.go new file mode 100644 index 0000000000..4248802b92 --- /dev/null +++ b/internal/endtoend/testdata/show_warnings/mysql/go/query.sql.go @@ -0,0 +1,43 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 +// source: query.sql + +package querytest + +import ( + "context" +) + +const showWarnings = `-- name: ShowWarnings :many +SHOW WARNINGS +` + +type ShowWarningsRow struct { + Level string + Code int32 + Message string +} + +func (q *Queries) ShowWarnings(ctx context.Context) ([]ShowWarningsRow, error) { + rows, err := q.db.QueryContext(ctx, showWarnings) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ShowWarningsRow + for rows.Next() { + var i ShowWarningsRow + if err := rows.Scan(&i.Level, &i.Code, &i.Message); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/show_warnings/mysql/query.sql b/internal/endtoend/testdata/show_warnings/mysql/query.sql new file mode 100644 index 0000000000..fa67110549 --- /dev/null +++ b/internal/endtoend/testdata/show_warnings/mysql/query.sql @@ -0,0 +1,2 @@ +-- name: ShowWarnings :many +SHOW WARNINGS; diff --git a/internal/endtoend/testdata/show_warnings/mysql/sqlc.json b/internal/endtoend/testdata/show_warnings/mysql/sqlc.json new file mode 100644 index 0000000000..445bbd1589 --- /dev/null +++ b/internal/endtoend/testdata/show_warnings/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "mysql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 48a2714a5b..56fad07b53 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -6,6 +6,7 @@ import ( "strings" pcast "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/opcode" driver "github.com/pingcap/tidb/parser/test_driver" "github.com/pingcap/tidb/parser/types" @@ -580,6 +581,41 @@ func (c *cc) convertUpdateStmt(n *pcast.UpdateStmt) *ast.UpdateStmt { } func (c *cc) convertValueExpr(n *driver.ValueExpr) *ast.A_Const { + switch n.TexprNode.Type.GetType() { + case mysql.TypeBit: + case mysql.TypeDate: + case mysql.TypeDatetime: + case mysql.TypeGeometry: + case mysql.TypeJSON: + case mysql.TypeNull: + case mysql.TypeSet: + case mysql.TypeShort: + case mysql.TypeDuration: + case mysql.TypeTimestamp: + // TODO: Create an AST type for these? + + case mysql.TypeTiny, + mysql.TypeInt24, + mysql.TypeYear, + mysql.TypeLong, + mysql.TypeLonglong: + return &ast.A_Const{ + Val: &ast.Integer{ + Ival: n.Datum.GetInt64(), + }, + } + + case mysql.TypeDouble, + mysql.TypeFloat, + mysql.TypeNewDecimal: + return &ast.A_Const{ + Val: &ast.Float{ + // TODO: Extract the value from n.TexprNode + }, + } + + case mysql.TypeBlob, mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeLongBlob, mysql.TypeMediumBlob, mysql.TypeTinyBlob, mysql.TypeEnum: + } return &ast.A_Const{ Val: &ast.String{ Str: n.Datum.GetString(), @@ -1219,7 +1255,32 @@ func (c *cc) convertSetStmt(n *pcast.SetStmt) ast.Node { } func (c *cc) convertShowStmt(n *pcast.ShowStmt) ast.Node { - return todo(n) + if n.Tp != pcast.ShowWarnings { + return todo(n) + } + level := "level" + code := "code" + message := "message" + stmt := &ast.SelectStmt{ + FromClause: &ast.List{}, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Name: &level, + Val: &ast.A_Const{Val: &ast.String{}}, + }, + &ast.ResTarget{ + Name: &code, + Val: &ast.A_Const{Val: &ast.Integer{}}, + }, + &ast.ResTarget{ + Name: &message, + Val: &ast.A_Const{Val: &ast.String{}}, + }, + }, + }, + } + return stmt } func (c *cc) convertShutdownStmt(n *pcast.ShutdownStmt) ast.Node {