From 9475610ef7415aaba09b87dd782c38d8fab26a49 Mon Sep 17 00:00:00 2001 From: Antoine GIRARD Date: Thu, 31 Aug 2023 00:23:27 +0200 Subject: [PATCH 1/6] setup expected to pass test --- .../testdata/enum_alter/mysql/go/db.go | 31 ++++++++++ .../testdata/enum_alter/mysql/go/models.go | 57 +++++++++++++++++++ .../testdata/enum_alter/mysql/go/query.sql.go | 37 ++++++++++++ .../testdata/enum_alter/mysql/query.sql | 2 + .../testdata/enum_alter/mysql/schema.sql | 9 +++ .../testdata/enum_alter/mysql/sqlc.json | 13 +++++ 6 files changed, 149 insertions(+) create mode 100644 internal/endtoend/testdata/enum_alter/mysql/go/db.go create mode 100644 internal/endtoend/testdata/enum_alter/mysql/go/models.go create mode 100644 internal/endtoend/testdata/enum_alter/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/enum_alter/mysql/query.sql create mode 100644 internal/endtoend/testdata/enum_alter/mysql/schema.sql create mode 100644 internal/endtoend/testdata/enum_alter/mysql/sqlc.json diff --git a/internal/endtoend/testdata/enum_alter/mysql/go/db.go b/internal/endtoend/testdata/enum_alter/mysql/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/enum_alter/mysql/go/models.go b/internal/endtoend/testdata/enum_alter/mysql/go/models.go new file mode 100644 index 0000000000..a329fe220a --- /dev/null +++ b/internal/endtoend/testdata/enum_alter/mysql/go/models.go @@ -0,0 +1,57 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql/driver" + "fmt" +) + +type AuthorsStatus string + +const ( + AuthorsStatusOk AuthorsStatus = "ok" + AuthorsStatusInit AuthorsStatus = "init" +) + +func (e *AuthorsStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AuthorsStatus(s) + case string: + *e = AuthorsStatus(s) + default: + return fmt.Errorf("unsupported scan type for AuthorsStatus: %T", src) + } + return nil +} + +type NullAuthorsStatus struct { + AuthorsStatus AuthorsStatus + Valid bool // Valid is true if AuthorsStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAuthorsStatus) Scan(value interface{}) error { + if value == nil { + ns.AuthorsStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AuthorsStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAuthorsStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AuthorsStatus), nil +} + +type Author struct { + ID int64 + Status AuthorsStatus +} diff --git a/internal/endtoend/testdata/enum_alter/mysql/go/query.sql.go b/internal/endtoend/testdata/enum_alter/mysql/go/query.sql.go new file mode 100644 index 0000000000..ccac241cd6 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter/mysql/go/query.sql.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthors = `-- name: ListAuthors :many +select id, status from authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Status); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/enum_alter/mysql/query.sql b/internal/endtoend/testdata/enum_alter/mysql/query.sql new file mode 100644 index 0000000000..0b16d94be6 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter/mysql/query.sql @@ -0,0 +1,2 @@ +-- name: ListAuthors :many +select * from authors; diff --git a/internal/endtoend/testdata/enum_alter/mysql/schema.sql b/internal/endtoend/testdata/enum_alter/mysql/schema.sql new file mode 100644 index 0000000000..cc649e55d3 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter/mysql/schema.sql @@ -0,0 +1,9 @@ +-- similar to issue https://github.com/sqlc-dev/sqlc/issues/1503 + +CREATE TABLE authors ( + id bigint primary key, + status enum("ok", "init") default "init" not null +); + +-- remove this alter to see the change in models.go +-- ALTER TABLE authors MODIFY status enum('init', 'done', 'canceled', 'processing', 'waiting') default "init" not null; \ No newline at end of file diff --git a/internal/endtoend/testdata/enum_alter/mysql/sqlc.json b/internal/endtoend/testdata/enum_alter/mysql/sqlc.json new file mode 100644 index 0000000000..feb988c2be --- /dev/null +++ b/internal/endtoend/testdata/enum_alter/mysql/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql", + "omit_unused_structs": true + } + ] +} From c4616a94eec6210ee784addec7e26828c302c92c Mon Sep 17 00:00:00 2001 From: Antoine GIRARD Date: Thu, 31 Aug 2023 00:25:10 +0200 Subject: [PATCH 2/6] with alter result --- .../testdata/enum_alter/mysql/go/models.go | 49 +------------------ .../testdata/enum_alter/mysql/schema.sql | 2 +- 2 files changed, 3 insertions(+), 48 deletions(-) diff --git a/internal/endtoend/testdata/enum_alter/mysql/go/models.go b/internal/endtoend/testdata/enum_alter/mysql/go/models.go index a329fe220a..6d04b43339 100644 --- a/internal/endtoend/testdata/enum_alter/mysql/go/models.go +++ b/internal/endtoend/testdata/enum_alter/mysql/go/models.go @@ -4,54 +4,9 @@ package querytest -import ( - "database/sql/driver" - "fmt" -) - -type AuthorsStatus string - -const ( - AuthorsStatusOk AuthorsStatus = "ok" - AuthorsStatusInit AuthorsStatus = "init" -) - -func (e *AuthorsStatus) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - *e = AuthorsStatus(s) - case string: - *e = AuthorsStatus(s) - default: - return fmt.Errorf("unsupported scan type for AuthorsStatus: %T", src) - } - return nil -} - -type NullAuthorsStatus struct { - AuthorsStatus AuthorsStatus - Valid bool // Valid is true if AuthorsStatus is not NULL -} - -// Scan implements the Scanner interface. -func (ns *NullAuthorsStatus) Scan(value interface{}) error { - if value == nil { - ns.AuthorsStatus, ns.Valid = "", false - return nil - } - ns.Valid = true - return ns.AuthorsStatus.Scan(value) -} - -// Value implements the driver Valuer interface. -func (ns NullAuthorsStatus) Value() (driver.Value, error) { - if !ns.Valid { - return nil, nil - } - return string(ns.AuthorsStatus), nil -} +import () type Author struct { ID int64 - Status AuthorsStatus + Status string } diff --git a/internal/endtoend/testdata/enum_alter/mysql/schema.sql b/internal/endtoend/testdata/enum_alter/mysql/schema.sql index cc649e55d3..d199af185f 100644 --- a/internal/endtoend/testdata/enum_alter/mysql/schema.sql +++ b/internal/endtoend/testdata/enum_alter/mysql/schema.sql @@ -6,4 +6,4 @@ CREATE TABLE authors ( ); -- remove this alter to see the change in models.go --- ALTER TABLE authors MODIFY status enum('init', 'done', 'canceled', 'processing', 'waiting') default "init" not null; \ No newline at end of file +ALTER TABLE authors MODIFY status enum('init', 'done', 'canceled', 'processing', 'waiting') default "init" not null; \ No newline at end of file From f0b46ea61dc2fba888616f160832afa35ce0b067 Mon Sep 17 00:00:00 2001 From: Antoine GIRARD Date: Thu, 31 Aug 2023 02:01:16 +0200 Subject: [PATCH 3/6] POC --- internal/compiler/compile.go | 3 + .../testdata/enum_alter/mysql/go/models.go | 52 ++++++++- internal/engine/dolphin/convert.go | 105 ++++++++---------- internal/sql/catalog/table.go | 69 +++++++----- internal/sql/catalog/types.go | 36 ++++++ 5 files changed, 173 insertions(+), 92 deletions(-) diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 5cbfab674a..efd8cf8086 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "log" "os" "path/filepath" "strings" @@ -43,7 +44,9 @@ func (c *Compiler) parseCatalog(schemas []string) error { merr.Add(filename, contents, 0, err) continue } + for i := range stmts { + log.Printf("stmts[%d]: %#v", i, stmts[i].Raw.Stmt) if err := c.catalog.Update(stmts[i], c); err != nil { merr.Add(filename, contents, stmts[i].Pos(), err) continue diff --git a/internal/endtoend/testdata/enum_alter/mysql/go/models.go b/internal/endtoend/testdata/enum_alter/mysql/go/models.go index 6d04b43339..3af0810506 100644 --- a/internal/endtoend/testdata/enum_alter/mysql/go/models.go +++ b/internal/endtoend/testdata/enum_alter/mysql/go/models.go @@ -4,9 +4,57 @@ package querytest -import () +import ( + "database/sql/driver" + "fmt" +) + +type AuthorsStatus string + +const ( + AuthorsStatusInit AuthorsStatus = "init" + AuthorsStatusDone AuthorsStatus = "done" + AuthorsStatusCanceled AuthorsStatus = "canceled" + AuthorsStatusProcessing AuthorsStatus = "processing" + AuthorsStatusWaiting AuthorsStatus = "waiting" +) + +func (e *AuthorsStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AuthorsStatus(s) + case string: + *e = AuthorsStatus(s) + default: + return fmt.Errorf("unsupported scan type for AuthorsStatus: %T", src) + } + return nil +} + +type NullAuthorsStatus struct { + AuthorsStatus AuthorsStatus + Valid bool // Valid is true if AuthorsStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAuthorsStatus) Scan(value interface{}) error { + if value == nil { + ns.AuthorsStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AuthorsStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAuthorsStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AuthorsStatus), nil +} type Author struct { ID int64 - Status string + Status AuthorsStatus } diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 5cca536976..0915fa12c0 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -43,16 +43,7 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { case pcast.AlterTableAddColumns: for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, - IsNotNull: isNotNull(def), - IsUnsigned: isUnsigned(def), - } - if def.Tp.GetFlen() >= 0 { - length := def.Tp.GetFlen() - columnDef.Length = &length - } + columnDef := convertColumnDef(def) alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, @@ -77,36 +68,20 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, - IsNotNull: isNotNull(def), - IsUnsigned: isUnsigned(def), - } - if def.Tp.GetFlen() >= 0 { - length := def.Tp.GetFlen() - columnDef.Length = &length - } + columnDef := convertColumnDef(def) alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, Def: &columnDef, }) + + log.Printf("CHANGE COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals) } case pcast.AlterTableModifyColumn: for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, - IsNotNull: isNotNull(def), - IsUnsigned: isUnsigned(def), - } - if def.Tp.GetFlen() >= 0 { - length := def.Tp.GetFlen() - columnDef.Length = &length - } + columnDef := convertColumnDef(def) alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_DropColumn, @@ -116,6 +91,8 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { Subtype: ast.AT_AddColumn, Def: &columnDef, }) + + log.Printf("MODIFY COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals) } case pcast.AlterTableAlterColumn: @@ -249,36 +226,9 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { create.ReferTable = parseTableName(n.ReferTable) } for _, def := range n.Cols { - var vals *ast.List - if len(def.Tp.GetElems()) > 0 { - vals = &ast.List{} - for i := range def.Tp.GetElems() { - vals.Items = append(vals.Items, &ast.String{ - Str: def.Tp.GetElems()[i], - }) - } - } - comment := "" - for _, opt := range def.Options { - switch opt.Tp { - case pcast.ColumnOptionComment: - if value, ok := opt.Expr.(*driver.ValueExpr); ok { - comment = value.GetString() - } - } - } - columnDef := ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, - IsNotNull: isNotNull(def), - IsUnsigned: isUnsigned(def), - Comment: comment, - Vals: vals, - } - if def.Tp.GetFlen() >= 0 { - length := def.Tp.GetFlen() - columnDef.Length = &length - } + columnDef := convertColumnDef(def) + + log.Printf("CREATE COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals) create.Cols = append(create.Cols, &columnDef) } for _, opt := range n.Options { @@ -290,6 +240,41 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { return create } +func convertColumnDef(def *pcast.ColumnDef) ast.ColumnDef { + var vals *ast.List + if len(def.Tp.GetElems()) > 0 { + vals = &ast.List{} + for i := range def.Tp.GetElems() { + vals.Items = append(vals.Items, &ast.String{ + Str: def.Tp.GetElems()[i], + }) + } + } + comment := "" + for _, opt := range def.Options { + switch opt.Tp { + case pcast.ColumnOptionComment: + if value, ok := opt.Expr.(*driver.ValueExpr); ok { + comment = value.GetString() + } + } + } + columnDef := ast.ColumnDef{ + Colname: def.Name.String(), + TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, + IsNotNull: isNotNull(def), + IsUnsigned: isUnsigned(def), + Comment: comment, + Vals: vals, + } + if def.Tp.GetFlen() >= 0 { + length := def.Tp.GetFlen() + columnDef.Length = &length + } + + return columnDef +} + func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef { var items []ast.Node if schema := n.Name.Schema.String(); schema != "" { diff --git a/internal/sql/catalog/table.go b/internal/sql/catalog/table.go index 5598da4df5..cb5741a98f 100644 --- a/internal/sql/catalog/table.go +++ b/internal/sql/catalog/table.go @@ -3,6 +3,7 @@ package catalog import ( "errors" "fmt" + "log" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" @@ -41,7 +42,7 @@ func (table *Table) isExistColumn(cmd *ast.AlterTableCmd) (int, error) { return -1, nil } -func (table *Table) addColumn(cmd *ast.AlterTableCmd) error { +func (table *Table) addColumn(c *Catalog, cmd *ast.AlterTableCmd) error { for _, c := range table.Columns { if c.Name == cmd.Def.Colname { if !cmd.MissingOk { @@ -51,15 +52,13 @@ func (table *Table) addColumn(cmd *ast.AlterTableCmd) error { } } - table.Columns = append(table.Columns, &Column{ - Name: cmd.Def.Colname, - Type: *cmd.Def.TypeName, - IsNotNull: cmd.Def.IsNotNull, - IsUnsigned: cmd.Def.IsUnsigned, - IsArray: cmd.Def.IsArray, - ArrayDims: cmd.Def.ArrayDims, - Length: cmd.Def.Length, - }) + tc, err := c.defToColumn(table.Rel, cmd.Def) + if err != nil { + return err + } + log.Printf("addColumn COLUMN: %#v\n%#v\n%#v", tc, tc.Type, cmd.Def.Vals) + + table.Columns = append(table.Columns, tc) return nil } @@ -187,7 +186,7 @@ func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error { case *ast.AlterTableCmd: switch cmd.Subtype { case ast.AT_AddColumn: - if err := table.addColumn(cmd); err != nil { + if err := table.addColumn(c, cmd); err != nil { return err } case ast.AT_AlterColumnType: @@ -305,26 +304,11 @@ func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error { continue } - tc := &Column{ - Name: col.Colname, - Type: *col.TypeName, - IsNotNull: col.IsNotNull, - IsUnsigned: col.IsUnsigned, - IsArray: col.IsArray, - ArrayDims: col.ArrayDims, - Comment: col.Comment, - Length: col.Length, - } - if col.Vals != nil { - typeName := ast.TypeName{ - Name: fmt.Sprintf("%s_%s", stmt.Name.Name, col.Colname), - } - s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals} - if err := c.createEnum(s); err != nil { - return err - } - tc.Type = typeName + tc, err := c.defToColumn(stmt.Name, col) + if err != nil { + return err } + log.Printf("createTable COLUMN: %#v\n%#v\n%#v", tc, tc.Type, col.Vals) tbl.Columns = append(tbl.Columns, tc) } } @@ -340,6 +324,31 @@ func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error { return nil } +func (c *Catalog) defToColumn(table *ast.TableName, col *ast.ColumnDef) (*Column, error) { + tc := &Column{ + Name: col.Colname, + Type: *col.TypeName, + IsNotNull: col.IsNotNull, + IsUnsigned: col.IsUnsigned, + IsArray: col.IsArray, + ArrayDims: col.ArrayDims, + Comment: col.Comment, + Length: col.Length, + } + if col.Vals != nil { + typeName := ast.TypeName{ + Name: fmt.Sprintf("%s_%s", table.Name, col.Colname), + } + s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals} + if err := c.createOrSetEnum(s); err != nil { + return nil, err + } + tc.Type = typeName + } + + return tc, nil +} + func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error { for _, name := range stmt.Tables { ns := name.Schema diff --git a/internal/sql/catalog/types.go b/internal/sql/catalog/types.go index e92a3a219e..9bef5ef7d1 100644 --- a/internal/sql/catalog/types.go +++ b/internal/sql/catalog/types.go @@ -3,6 +3,7 @@ package catalog import ( "errors" "fmt" + "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) @@ -90,6 +91,41 @@ func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error { return nil } +func (c *Catalog) createOrSetEnum(stmt *ast.CreateEnumStmt) error { + ns := stmt.TypeName.Schema + if ns == "" { + ns = c.DefaultSchema + } + schema, err := c.getSchema(ns) + if err != nil { + return err + } + // Because tables have associated data types, the type name must also + // be distinct from the name of any existing table in the same + // schema. + // https://www.postgresql.org/docs/current/sql-createtype.html + tbl := &ast.TableName{ + Name: stmt.TypeName.Name, + } + if _, _, err := schema.getTable(tbl); err == nil { + return sqlerr.RelationExists(tbl.Name) + } + if typ, _, err := schema.getType(stmt.TypeName); err == nil { + enum, ok := typ.(*Enum) + if !ok { + return fmt.Errorf("type is not an enum: %s", stmt.TypeName.Name) + } + enum.Vals = stringSlice(stmt.Vals) + + return nil + } + schema.Types = append(schema.Types, &Enum{ + Name: stmt.TypeName.Name, + Vals: stringSlice(stmt.Vals), + }) + return nil +} + func stringSlice(list *ast.List) []string { items := []string{} for _, item := range list.Items { From 80a339895b44432eaba87de722806b26c1ffcc84 Mon Sep 17 00:00:00 2001 From: Antoine GIRARD Date: Thu, 31 Aug 2023 02:06:48 +0200 Subject: [PATCH 4/6] cleanup --- internal/compiler/compile.go | 2 -- internal/engine/dolphin/convert.go | 22 +++++-------------- internal/sql/catalog/catalog.go | 2 +- internal/sql/catalog/table.go | 5 +---- internal/sql/catalog/types.go | 35 +++++------------------------- 5 files changed, 13 insertions(+), 53 deletions(-) diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index efd8cf8086..7cc4274b7b 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "io" - "log" "os" "path/filepath" "strings" @@ -46,7 +45,6 @@ func (c *Compiler) parseCatalog(schemas []string) error { } for i := range stmts { - log.Printf("stmts[%d]: %#v", i, stmts[i].Raw.Stmt) if err := c.catalog.Update(stmts[i], c); err != nil { merr.Add(filename, contents, stmts[i].Pos(), err) continue diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 0915fa12c0..a167858e4c 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -43,11 +43,10 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { case pcast.AlterTableAddColumns: for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := convertColumnDef(def) alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, - Def: &columnDef, + Def: convertColumnDef(def), }) } @@ -68,20 +67,16 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := convertColumnDef(def) alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, - Def: &columnDef, + Def: convertColumnDef(def), }) - - log.Printf("CHANGE COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals) } case pcast.AlterTableModifyColumn: for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := convertColumnDef(def) alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_DropColumn, @@ -89,10 +84,8 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, - Def: &columnDef, + Def: convertColumnDef(def), }) - - log.Printf("MODIFY COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals) } case pcast.AlterTableAlterColumn: @@ -226,10 +219,7 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { create.ReferTable = parseTableName(n.ReferTable) } for _, def := range n.Cols { - columnDef := convertColumnDef(def) - - log.Printf("CREATE COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals) - create.Cols = append(create.Cols, &columnDef) + create.Cols = append(create.Cols, convertColumnDef(def)) } for _, opt := range n.Options { switch opt.Tp { @@ -240,7 +230,7 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { return create } -func convertColumnDef(def *pcast.ColumnDef) ast.ColumnDef { +func convertColumnDef(def *pcast.ColumnDef) *ast.ColumnDef { var vals *ast.List if len(def.Tp.GetElems()) > 0 { vals = &ast.List{} @@ -272,7 +262,7 @@ func convertColumnDef(def *pcast.ColumnDef) ast.ColumnDef { columnDef.Length = &length } - return columnDef + return &columnDef } func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef { diff --git a/internal/sql/catalog/catalog.go b/internal/sql/catalog/catalog.go index 278ea8797d..9a4a29e880 100644 --- a/internal/sql/catalog/catalog.go +++ b/internal/sql/catalog/catalog.go @@ -83,7 +83,7 @@ func (c *Catalog) Update(stmt ast.Statement, colGen columnGenerator) error { err = c.createCompositeType(n) case *ast.CreateEnumStmt: - err = c.createEnum(n) + err = c.createEnum(n, false) case *ast.CreateExtensionStmt: err = c.createExtension(n) diff --git a/internal/sql/catalog/table.go b/internal/sql/catalog/table.go index cb5741a98f..1c9bacf994 100644 --- a/internal/sql/catalog/table.go +++ b/internal/sql/catalog/table.go @@ -3,7 +3,6 @@ package catalog import ( "errors" "fmt" - "log" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" @@ -56,7 +55,6 @@ func (table *Table) addColumn(c *Catalog, cmd *ast.AlterTableCmd) error { if err != nil { return err } - log.Printf("addColumn COLUMN: %#v\n%#v\n%#v", tc, tc.Type, cmd.Def.Vals) table.Columns = append(table.Columns, tc) return nil @@ -308,7 +306,6 @@ func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error { if err != nil { return err } - log.Printf("createTable COLUMN: %#v\n%#v\n%#v", tc, tc.Type, col.Vals) tbl.Columns = append(tbl.Columns, tc) } } @@ -340,7 +337,7 @@ func (c *Catalog) defToColumn(table *ast.TableName, col *ast.ColumnDef) (*Column Name: fmt.Sprintf("%s_%s", table.Name, col.Colname), } s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals} - if err := c.createOrSetEnum(s); err != nil { + if err := c.createEnum(s, true); err != nil { return nil, err } tc.Type = typeName diff --git a/internal/sql/catalog/types.go b/internal/sql/catalog/types.go index 9bef5ef7d1..2bb1033b3a 100644 --- a/internal/sql/catalog/types.go +++ b/internal/sql/catalog/types.go @@ -62,36 +62,7 @@ func sameType(a, b *ast.TypeName) bool { return true } -func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error { - ns := stmt.TypeName.Schema - if ns == "" { - ns = c.DefaultSchema - } - schema, err := c.getSchema(ns) - if err != nil { - return err - } - // Because tables have associated data types, the type name must also - // be distinct from the name of any existing table in the same - // schema. - // https://www.postgresql.org/docs/current/sql-createtype.html - tbl := &ast.TableName{ - Name: stmt.TypeName.Name, - } - if _, _, err := schema.getTable(tbl); err == nil { - return sqlerr.RelationExists(tbl.Name) - } - if _, _, err := schema.getType(stmt.TypeName); err == nil { - return sqlerr.TypeExists(tbl.Name) - } - schema.Types = append(schema.Types, &Enum{ - Name: stmt.TypeName.Name, - Vals: stringSlice(stmt.Vals), - }) - return nil -} - -func (c *Catalog) createOrSetEnum(stmt *ast.CreateEnumStmt) error { +func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt, overwrite bool) error { ns := stmt.TypeName.Schema if ns == "" { ns = c.DefaultSchema @@ -111,6 +82,10 @@ func (c *Catalog) createOrSetEnum(stmt *ast.CreateEnumStmt) error { return sqlerr.RelationExists(tbl.Name) } if typ, _, err := schema.getType(stmt.TypeName); err == nil { + if !overwrite { + return sqlerr.TypeExists(tbl.Name) + } + enum, ok := typ.(*Enum) if !ok { return fmt.Errorf("type is not an enum: %s", stmt.TypeName.Name) From 036f75edbdcc4868b9754e9c61a158aa4ea9dec3 Mon Sep 17 00:00:00 2001 From: Antoine GIRARD Date: Thu, 31 Aug 2023 02:09:56 +0200 Subject: [PATCH 5/6] add alter change test case --- .../mysql/go/db.go | 0 .../mysql/go/models.go | 0 .../mysql/go/query.sql.go | 0 .../mysql/query.sql | 0 .../enum_alter_change/mysql/schema.sql | 9 +++ .../mysql/sqlc.json | 0 .../testdata/enum_alter_modify/mysql/go/db.go | 31 ++++++++++ .../enum_alter_modify/mysql/go/models.go | 60 +++++++++++++++++++ .../enum_alter_modify/mysql/go/query.sql.go | 37 ++++++++++++ .../enum_alter_modify/mysql/query.sql | 2 + .../mysql/schema.sql | 0 .../enum_alter_modify/mysql/sqlc.json | 13 ++++ 12 files changed, 152 insertions(+) rename internal/endtoend/testdata/{enum_alter => enum_alter_change}/mysql/go/db.go (100%) rename internal/endtoend/testdata/{enum_alter => enum_alter_change}/mysql/go/models.go (100%) rename internal/endtoend/testdata/{enum_alter => enum_alter_change}/mysql/go/query.sql.go (100%) rename internal/endtoend/testdata/{enum_alter => enum_alter_change}/mysql/query.sql (100%) create mode 100644 internal/endtoend/testdata/enum_alter_change/mysql/schema.sql rename internal/endtoend/testdata/{enum_alter => enum_alter_change}/mysql/sqlc.json (100%) create mode 100644 internal/endtoend/testdata/enum_alter_modify/mysql/go/db.go create mode 100644 internal/endtoend/testdata/enum_alter_modify/mysql/go/models.go create mode 100644 internal/endtoend/testdata/enum_alter_modify/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/enum_alter_modify/mysql/query.sql rename internal/endtoend/testdata/{enum_alter => enum_alter_modify}/mysql/schema.sql (100%) create mode 100644 internal/endtoend/testdata/enum_alter_modify/mysql/sqlc.json diff --git a/internal/endtoend/testdata/enum_alter/mysql/go/db.go b/internal/endtoend/testdata/enum_alter_change/mysql/go/db.go similarity index 100% rename from internal/endtoend/testdata/enum_alter/mysql/go/db.go rename to internal/endtoend/testdata/enum_alter_change/mysql/go/db.go diff --git a/internal/endtoend/testdata/enum_alter/mysql/go/models.go b/internal/endtoend/testdata/enum_alter_change/mysql/go/models.go similarity index 100% rename from internal/endtoend/testdata/enum_alter/mysql/go/models.go rename to internal/endtoend/testdata/enum_alter_change/mysql/go/models.go diff --git a/internal/endtoend/testdata/enum_alter/mysql/go/query.sql.go b/internal/endtoend/testdata/enum_alter_change/mysql/go/query.sql.go similarity index 100% rename from internal/endtoend/testdata/enum_alter/mysql/go/query.sql.go rename to internal/endtoend/testdata/enum_alter_change/mysql/go/query.sql.go diff --git a/internal/endtoend/testdata/enum_alter/mysql/query.sql b/internal/endtoend/testdata/enum_alter_change/mysql/query.sql similarity index 100% rename from internal/endtoend/testdata/enum_alter/mysql/query.sql rename to internal/endtoend/testdata/enum_alter_change/mysql/query.sql diff --git a/internal/endtoend/testdata/enum_alter_change/mysql/schema.sql b/internal/endtoend/testdata/enum_alter_change/mysql/schema.sql new file mode 100644 index 0000000000..67c89a0eff --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_change/mysql/schema.sql @@ -0,0 +1,9 @@ +-- similar to issue https://github.com/sqlc-dev/sqlc/issues/1503 + +CREATE TABLE authors ( + id bigint primary key, + status enum("ok", "init") default "init" not null +); + +-- remove this alter to see the change in models.go +ALTER TABLE authors CHANGE status status enum('init', 'done', 'canceled', 'processing', 'waiting') default "init" not null; \ No newline at end of file diff --git a/internal/endtoend/testdata/enum_alter/mysql/sqlc.json b/internal/endtoend/testdata/enum_alter_change/mysql/sqlc.json similarity index 100% rename from internal/endtoend/testdata/enum_alter/mysql/sqlc.json rename to internal/endtoend/testdata/enum_alter_change/mysql/sqlc.json diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/go/db.go b/internal/endtoend/testdata/enum_alter_modify/mysql/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/go/models.go b/internal/endtoend/testdata/enum_alter_modify/mysql/go/models.go new file mode 100644 index 0000000000..3af0810506 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/go/models.go @@ -0,0 +1,60 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql/driver" + "fmt" +) + +type AuthorsStatus string + +const ( + AuthorsStatusInit AuthorsStatus = "init" + AuthorsStatusDone AuthorsStatus = "done" + AuthorsStatusCanceled AuthorsStatus = "canceled" + AuthorsStatusProcessing AuthorsStatus = "processing" + AuthorsStatusWaiting AuthorsStatus = "waiting" +) + +func (e *AuthorsStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AuthorsStatus(s) + case string: + *e = AuthorsStatus(s) + default: + return fmt.Errorf("unsupported scan type for AuthorsStatus: %T", src) + } + return nil +} + +type NullAuthorsStatus struct { + AuthorsStatus AuthorsStatus + Valid bool // Valid is true if AuthorsStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAuthorsStatus) Scan(value interface{}) error { + if value == nil { + ns.AuthorsStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AuthorsStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAuthorsStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AuthorsStatus), nil +} + +type Author struct { + ID int64 + Status AuthorsStatus +} diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/go/query.sql.go b/internal/endtoend/testdata/enum_alter_modify/mysql/go/query.sql.go new file mode 100644 index 0000000000..ccac241cd6 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/go/query.sql.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthors = `-- name: ListAuthors :many +select id, status from authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Status); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/query.sql b/internal/endtoend/testdata/enum_alter_modify/mysql/query.sql new file mode 100644 index 0000000000..0b16d94be6 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/query.sql @@ -0,0 +1,2 @@ +-- name: ListAuthors :many +select * from authors; diff --git a/internal/endtoend/testdata/enum_alter/mysql/schema.sql b/internal/endtoend/testdata/enum_alter_modify/mysql/schema.sql similarity index 100% rename from internal/endtoend/testdata/enum_alter/mysql/schema.sql rename to internal/endtoend/testdata/enum_alter_modify/mysql/schema.sql diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/sqlc.json b/internal/endtoend/testdata/enum_alter_modify/mysql/sqlc.json new file mode 100644 index 0000000000..feb988c2be --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql", + "omit_unused_structs": true + } + ] +} From 6265431cc3e764bcbf2aa661896389f93e523061 Mon Sep 17 00:00:00 2001 From: Antoine GIRARD Date: Thu, 31 Aug 2023 02:14:54 +0200 Subject: [PATCH 6/6] cleanup --- internal/compiler/compile.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 7cc4274b7b..5cbfab674a 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -43,7 +43,6 @@ func (c *Compiler) parseCatalog(schemas []string) error { merr.Add(filename, contents, 0, err) continue } - for i := range stmts { if err := c.catalog.Update(stmts[i], c); err != nil { merr.Add(filename, contents, stmts[i].Pos(), err)