diff --git a/internal/endtoend/testdata/enum_alter_change/mysql/go/db.go b/internal/endtoend/testdata/enum_alter_change/mysql/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_change/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/enum_alter_change/mysql/go/models.go b/internal/endtoend/testdata/enum_alter_change/mysql/go/models.go new file mode 100644 index 0000000000..3af0810506 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_change/mysql/go/models.go @@ -0,0 +1,60 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql/driver" + "fmt" +) + +type AuthorsStatus string + +const ( + AuthorsStatusInit AuthorsStatus = "init" + AuthorsStatusDone AuthorsStatus = "done" + AuthorsStatusCanceled AuthorsStatus = "canceled" + AuthorsStatusProcessing AuthorsStatus = "processing" + AuthorsStatusWaiting AuthorsStatus = "waiting" +) + +func (e *AuthorsStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AuthorsStatus(s) + case string: + *e = AuthorsStatus(s) + default: + return fmt.Errorf("unsupported scan type for AuthorsStatus: %T", src) + } + return nil +} + +type NullAuthorsStatus struct { + AuthorsStatus AuthorsStatus + Valid bool // Valid is true if AuthorsStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAuthorsStatus) Scan(value interface{}) error { + if value == nil { + ns.AuthorsStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AuthorsStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAuthorsStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AuthorsStatus), nil +} + +type Author struct { + ID int64 + Status AuthorsStatus +} diff --git a/internal/endtoend/testdata/enum_alter_change/mysql/go/query.sql.go b/internal/endtoend/testdata/enum_alter_change/mysql/go/query.sql.go new file mode 100644 index 0000000000..ccac241cd6 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_change/mysql/go/query.sql.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthors = `-- name: ListAuthors :many +select id, status from authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Status); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/enum_alter_change/mysql/query.sql b/internal/endtoend/testdata/enum_alter_change/mysql/query.sql new file mode 100644 index 0000000000..0b16d94be6 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_change/mysql/query.sql @@ -0,0 +1,2 @@ +-- name: ListAuthors :many +select * from authors; diff --git a/internal/endtoend/testdata/enum_alter_change/mysql/schema.sql b/internal/endtoend/testdata/enum_alter_change/mysql/schema.sql new file mode 100644 index 0000000000..67c89a0eff --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_change/mysql/schema.sql @@ -0,0 +1,9 @@ +-- similar to issue https://github.com/sqlc-dev/sqlc/issues/1503 + +CREATE TABLE authors ( + id bigint primary key, + status enum("ok", "init") default "init" not null +); + +-- remove this alter to see the change in models.go +ALTER TABLE authors CHANGE status status enum('init', 'done', 'canceled', 'processing', 'waiting') default "init" not null; \ No newline at end of file diff --git a/internal/endtoend/testdata/enum_alter_change/mysql/sqlc.json b/internal/endtoend/testdata/enum_alter_change/mysql/sqlc.json new file mode 100644 index 0000000000..feb988c2be --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_change/mysql/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql", + "omit_unused_structs": true + } + ] +} diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/go/db.go b/internal/endtoend/testdata/enum_alter_modify/mysql/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/go/models.go b/internal/endtoend/testdata/enum_alter_modify/mysql/go/models.go new file mode 100644 index 0000000000..3af0810506 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/go/models.go @@ -0,0 +1,60 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql/driver" + "fmt" +) + +type AuthorsStatus string + +const ( + AuthorsStatusInit AuthorsStatus = "init" + AuthorsStatusDone AuthorsStatus = "done" + AuthorsStatusCanceled AuthorsStatus = "canceled" + AuthorsStatusProcessing AuthorsStatus = "processing" + AuthorsStatusWaiting AuthorsStatus = "waiting" +) + +func (e *AuthorsStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AuthorsStatus(s) + case string: + *e = AuthorsStatus(s) + default: + return fmt.Errorf("unsupported scan type for AuthorsStatus: %T", src) + } + return nil +} + +type NullAuthorsStatus struct { + AuthorsStatus AuthorsStatus + Valid bool // Valid is true if AuthorsStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAuthorsStatus) Scan(value interface{}) error { + if value == nil { + ns.AuthorsStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AuthorsStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAuthorsStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AuthorsStatus), nil +} + +type Author struct { + ID int64 + Status AuthorsStatus +} diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/go/query.sql.go b/internal/endtoend/testdata/enum_alter_modify/mysql/go/query.sql.go new file mode 100644 index 0000000000..ccac241cd6 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/go/query.sql.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthors = `-- name: ListAuthors :many +select id, status from authors +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Status); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/query.sql b/internal/endtoend/testdata/enum_alter_modify/mysql/query.sql new file mode 100644 index 0000000000..0b16d94be6 --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/query.sql @@ -0,0 +1,2 @@ +-- name: ListAuthors :many +select * from authors; diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/schema.sql b/internal/endtoend/testdata/enum_alter_modify/mysql/schema.sql new file mode 100644 index 0000000000..d199af185f --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/schema.sql @@ -0,0 +1,9 @@ +-- similar to issue https://github.com/sqlc-dev/sqlc/issues/1503 + +CREATE TABLE authors ( + id bigint primary key, + status enum("ok", "init") default "init" not null +); + +-- remove this alter to see the change in models.go +ALTER TABLE authors MODIFY status enum('init', 'done', 'canceled', 'processing', 'waiting') default "init" not null; \ No newline at end of file diff --git a/internal/endtoend/testdata/enum_alter_modify/mysql/sqlc.json b/internal/endtoend/testdata/enum_alter_modify/mysql/sqlc.json new file mode 100644 index 0000000000..feb988c2be --- /dev/null +++ b/internal/endtoend/testdata/enum_alter_modify/mysql/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql", + "omit_unused_structs": true + } + ] +} diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 5cca536976..a167858e4c 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -43,20 +43,10 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { case pcast.AlterTableAddColumns: for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, - IsNotNull: isNotNull(def), - IsUnsigned: isUnsigned(def), - } - if def.Tp.GetFlen() >= 0 { - length := def.Tp.GetFlen() - columnDef.Length = &length - } alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, - Def: &columnDef, + Def: convertColumnDef(def), }) } @@ -77,36 +67,16 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, - IsNotNull: isNotNull(def), - IsUnsigned: isUnsigned(def), - } - if def.Tp.GetFlen() >= 0 { - length := def.Tp.GetFlen() - columnDef.Length = &length - } alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, - Def: &columnDef, + Def: convertColumnDef(def), }) } case pcast.AlterTableModifyColumn: for _, def := range spec.NewColumns { name := def.Name.String() - columnDef := ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, - IsNotNull: isNotNull(def), - IsUnsigned: isUnsigned(def), - } - if def.Tp.GetFlen() >= 0 { - length := def.Tp.GetFlen() - columnDef.Length = &length - } alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_DropColumn, @@ -114,7 +84,7 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, - Def: &columnDef, + Def: convertColumnDef(def), }) } @@ -249,37 +219,7 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { create.ReferTable = parseTableName(n.ReferTable) } for _, def := range n.Cols { - var vals *ast.List - if len(def.Tp.GetElems()) > 0 { - vals = &ast.List{} - for i := range def.Tp.GetElems() { - vals.Items = append(vals.Items, &ast.String{ - Str: def.Tp.GetElems()[i], - }) - } - } - comment := "" - for _, opt := range def.Options { - switch opt.Tp { - case pcast.ColumnOptionComment: - if value, ok := opt.Expr.(*driver.ValueExpr); ok { - comment = value.GetString() - } - } - } - columnDef := ast.ColumnDef{ - Colname: def.Name.String(), - TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, - IsNotNull: isNotNull(def), - IsUnsigned: isUnsigned(def), - Comment: comment, - Vals: vals, - } - if def.Tp.GetFlen() >= 0 { - length := def.Tp.GetFlen() - columnDef.Length = &length - } - create.Cols = append(create.Cols, &columnDef) + create.Cols = append(create.Cols, convertColumnDef(def)) } for _, opt := range n.Options { switch opt.Tp { @@ -290,6 +230,41 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { return create } +func convertColumnDef(def *pcast.ColumnDef) *ast.ColumnDef { + var vals *ast.List + if len(def.Tp.GetElems()) > 0 { + vals = &ast.List{} + for i := range def.Tp.GetElems() { + vals.Items = append(vals.Items, &ast.String{ + Str: def.Tp.GetElems()[i], + }) + } + } + comment := "" + for _, opt := range def.Options { + switch opt.Tp { + case pcast.ColumnOptionComment: + if value, ok := opt.Expr.(*driver.ValueExpr); ok { + comment = value.GetString() + } + } + } + columnDef := ast.ColumnDef{ + Colname: def.Name.String(), + TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())}, + IsNotNull: isNotNull(def), + IsUnsigned: isUnsigned(def), + Comment: comment, + Vals: vals, + } + if def.Tp.GetFlen() >= 0 { + length := def.Tp.GetFlen() + columnDef.Length = &length + } + + return &columnDef +} + func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef { var items []ast.Node if schema := n.Name.Schema.String(); schema != "" { diff --git a/internal/sql/catalog/catalog.go b/internal/sql/catalog/catalog.go index 278ea8797d..9a4a29e880 100644 --- a/internal/sql/catalog/catalog.go +++ b/internal/sql/catalog/catalog.go @@ -83,7 +83,7 @@ func (c *Catalog) Update(stmt ast.Statement, colGen columnGenerator) error { err = c.createCompositeType(n) case *ast.CreateEnumStmt: - err = c.createEnum(n) + err = c.createEnum(n, false) case *ast.CreateExtensionStmt: err = c.createExtension(n) diff --git a/internal/sql/catalog/table.go b/internal/sql/catalog/table.go index 5598da4df5..1c9bacf994 100644 --- a/internal/sql/catalog/table.go +++ b/internal/sql/catalog/table.go @@ -41,7 +41,7 @@ func (table *Table) isExistColumn(cmd *ast.AlterTableCmd) (int, error) { return -1, nil } -func (table *Table) addColumn(cmd *ast.AlterTableCmd) error { +func (table *Table) addColumn(c *Catalog, cmd *ast.AlterTableCmd) error { for _, c := range table.Columns { if c.Name == cmd.Def.Colname { if !cmd.MissingOk { @@ -51,15 +51,12 @@ func (table *Table) addColumn(cmd *ast.AlterTableCmd) error { } } - table.Columns = append(table.Columns, &Column{ - Name: cmd.Def.Colname, - Type: *cmd.Def.TypeName, - IsNotNull: cmd.Def.IsNotNull, - IsUnsigned: cmd.Def.IsUnsigned, - IsArray: cmd.Def.IsArray, - ArrayDims: cmd.Def.ArrayDims, - Length: cmd.Def.Length, - }) + tc, err := c.defToColumn(table.Rel, cmd.Def) + if err != nil { + return err + } + + table.Columns = append(table.Columns, tc) return nil } @@ -187,7 +184,7 @@ func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error { case *ast.AlterTableCmd: switch cmd.Subtype { case ast.AT_AddColumn: - if err := table.addColumn(cmd); err != nil { + if err := table.addColumn(c, cmd); err != nil { return err } case ast.AT_AlterColumnType: @@ -305,25 +302,9 @@ func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error { continue } - tc := &Column{ - Name: col.Colname, - Type: *col.TypeName, - IsNotNull: col.IsNotNull, - IsUnsigned: col.IsUnsigned, - IsArray: col.IsArray, - ArrayDims: col.ArrayDims, - Comment: col.Comment, - Length: col.Length, - } - if col.Vals != nil { - typeName := ast.TypeName{ - Name: fmt.Sprintf("%s_%s", stmt.Name.Name, col.Colname), - } - s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals} - if err := c.createEnum(s); err != nil { - return err - } - tc.Type = typeName + tc, err := c.defToColumn(stmt.Name, col) + if err != nil { + return err } tbl.Columns = append(tbl.Columns, tc) } @@ -340,6 +321,31 @@ func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error { return nil } +func (c *Catalog) defToColumn(table *ast.TableName, col *ast.ColumnDef) (*Column, error) { + tc := &Column{ + Name: col.Colname, + Type: *col.TypeName, + IsNotNull: col.IsNotNull, + IsUnsigned: col.IsUnsigned, + IsArray: col.IsArray, + ArrayDims: col.ArrayDims, + Comment: col.Comment, + Length: col.Length, + } + if col.Vals != nil { + typeName := ast.TypeName{ + Name: fmt.Sprintf("%s_%s", table.Name, col.Colname), + } + s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals} + if err := c.createEnum(s, true); err != nil { + return nil, err + } + tc.Type = typeName + } + + return tc, nil +} + func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error { for _, name := range stmt.Tables { ns := name.Schema diff --git a/internal/sql/catalog/types.go b/internal/sql/catalog/types.go index e92a3a219e..2bb1033b3a 100644 --- a/internal/sql/catalog/types.go +++ b/internal/sql/catalog/types.go @@ -3,6 +3,7 @@ package catalog import ( "errors" "fmt" + "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) @@ -61,7 +62,7 @@ func sameType(a, b *ast.TypeName) bool { return true } -func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error { +func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt, overwrite bool) error { ns := stmt.TypeName.Schema if ns == "" { ns = c.DefaultSchema @@ -80,8 +81,18 @@ func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error { if _, _, err := schema.getTable(tbl); err == nil { return sqlerr.RelationExists(tbl.Name) } - if _, _, err := schema.getType(stmt.TypeName); err == nil { - return sqlerr.TypeExists(tbl.Name) + if typ, _, err := schema.getType(stmt.TypeName); err == nil { + if !overwrite { + return sqlerr.TypeExists(tbl.Name) + } + + enum, ok := typ.(*Enum) + if !ok { + return fmt.Errorf("type is not an enum: %s", stmt.TypeName.Name) + } + enum.Vals = stringSlice(stmt.Vals) + + return nil } schema.Types = append(schema.Types, &Enum{ Name: stmt.TypeName.Name,