diff --git a/internal/endtoend/testdata/mysql_enums/go/db.go b/internal/endtoend/testdata/mysql_enums/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mysql_enums/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mysql_enums/go/models.go b/internal/endtoend/testdata/mysql_enums/go/models.go new file mode 100644 index 0000000000..254255896b --- /dev/null +++ b/internal/endtoend/testdata/mysql_enums/go/models.go @@ -0,0 +1,47 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type FirstNameType string + +const ( + john FirstNameType = "john" + albert FirstNameType = "albert" +) + +func (e *FirstNameType) Scan(src interface{}) error { + *e = FirstNameType(src.([]byte)) + return nil +} + +type UserIDType string + +const ( + one UserIDType = "one" + two UserIDType = "two" +) + +func (e *UserIDType) Scan(src interface{}) error { + *e = UserIDType(src.([]byte)) + return nil +} + +type LastNameType string + +const ( + smith LastNameType = "smith" + frank LastNameType = "frank" +) + +func (e *LastNameType) Scan(src interface{}) error { + *e = LastNameType(src.([]byte)) + return nil +} + +type Example struct { + FirstName FirstNameType + UserID UserIDType + LastName LastNameType +} diff --git a/internal/endtoend/testdata/mysql_enums/query.sql b/internal/endtoend/testdata/mysql_enums/query.sql new file mode 100644 index 0000000000..e0ac49d1ec --- /dev/null +++ b/internal/endtoend/testdata/mysql_enums/query.sql @@ -0,0 +1 @@ +SELECT 1; diff --git a/internal/endtoend/testdata/mysql_enums/schema.sql b/internal/endtoend/testdata/mysql_enums/schema.sql new file mode 100644 index 0000000000..84724ed2db --- /dev/null +++ b/internal/endtoend/testdata/mysql_enums/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE examples ( + first_name ENUM('john', 'albert') NOT NULL, + user_id ENUM('one', 'two') NOT NULL, + last_name ENUM('smith', 'frank') NOT NULL +) ENGINE=InnoDB; diff --git a/internal/endtoend/testdata/mysql_enums/sqlc.json b/internal/endtoend/testdata/mysql_enums/sqlc.json new file mode 100644 index 0000000000..968e183e2e --- /dev/null +++ b/internal/endtoend/testdata/mysql_enums/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "schema.sql", + "engine": "mysql" + } + ] +} diff --git a/internal/endtoend/testdata/mysql_errors/query.sql b/internal/endtoend/testdata/mysql_errors/query.sql new file mode 100644 index 0000000000..462761beac --- /dev/null +++ b/internal/endtoend/testdata/mysql_errors/query.sql @@ -0,0 +1,26 @@ +/* name: WrongFunc :one */ +select id, first_name from users where id = sqlc.argh(target_id); + +/* name: InvalidName :one */ +select id, first_name from users where id = sqlc.arg(sqlc.arg(target_id)); + +/* name: InvalidVaue :one */ +select id, first_name from users where id = sqlc.arg(?); + +/* name: TooManyFroms :one */ +select id, first_name from users from where id = ?; + +/* name: MisspelledSelect :one */ +selectt id, first_name from users; + +/* name: ExtraSelect :one */ +select id from users where select id; + +-- stderr +-- # package querytest +-- query.sql:1:1: invalid function call "sqlc.argh", did you mean "sqlc.arg"? +-- query.sql:4:1: invalid custom argument value "sqlc.arg(sqlc.arg(target_id))" +-- query.sql:7:1: invalid custom argument value "sqlc.arg(?)" +-- query.sql:11:39: syntax error at or near "from" +-- query.sql:14:9: syntax error at or near "selectt" +-- query.sql:17:35: syntax error at or near "select" diff --git a/internal/endtoend/testdata/mysql_errors/schema.sql b/internal/endtoend/testdata/mysql_errors/schema.sql new file mode 100644 index 0000000000..191528716f --- /dev/null +++ b/internal/endtoend/testdata/mysql_errors/schema.sql @@ -0,0 +1,15 @@ +CREATE TABLE users ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL, + job_status ENUM('APPLIED', 'PENDING', 'ACCEPTED', 'REJECTED') NOT NULL +) ENGINE=InnoDB; + +CREATE TABLE orders ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + price DECIMAL(13, 4) NOT NULL, + user_id integer NOT NULL +) ENGINE=InnoDB; + + diff --git a/internal/endtoend/testdata/mysql_errors/sqlc.json b/internal/endtoend/testdata/mysql_errors/sqlc.json new file mode 100644 index 0000000000..a9e7b055a4 --- /dev/null +++ b/internal/endtoend/testdata/mysql_errors/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql" + } + ] +} diff --git a/internal/endtoend/testdata/mysql_overrides/go/db.go b/internal/endtoend/testdata/mysql_overrides/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mysql_overrides/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mysql_overrides/go/models.go b/internal/endtoend/testdata/mysql_overrides/go/models.go new file mode 100644 index 0000000000..26cfbef5a8 --- /dev/null +++ b/internal/endtoend/testdata/mysql_overrides/go/models.go @@ -0,0 +1,38 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" + + "example.com/mysql" +) + +type JobStatusType string + +const ( + APPLIED JobStatusType = "APPLIED" + PENDING JobStatusType = "PENDING" + ACCEPTED JobStatusType = "ACCEPTED" + REJECTED JobStatusType = "REJECTED" +) + +func (e *JobStatusType) Scan(src interface{}) error { + *e = JobStatusType(src.([]byte)) + return nil +} + +type Order struct { + ID mysql.ID + Price float64 + UserID int +} + +type User struct { + ID mysql.ID + FirstName string + LastName sql.NullString + Age int + JobStatus JobStatusType + Created mysql.Timestamp +} diff --git a/internal/endtoend/testdata/mysql_overrides/go/query.sql.go b/internal/endtoend/testdata/mysql_overrides/go/query.sql.go new file mode 100644 index 0000000000..67e97c297c --- /dev/null +++ b/internal/endtoend/testdata/mysql_overrides/go/query.sql.go @@ -0,0 +1,169 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" + + "example.com/mysql" +) + +const getAll = `-- name: GetAll :many +select id, first_name, last_name, age, job_status, created from users +` + +func (q *Queries) GetAll(ctx context.Context) ([]User, error) { + rows, err := q.db.QueryContext(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Age, + &i.JobStatus, + &i.Created, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getAllUsersOrders = `-- name: GetAllUsersOrders :many +select u.id as user_id, u.first_name, o.price, o.id as order_id from orders as o left join users as u on u.id = o.user_id +` + +type GetAllUsersOrdersRow struct { + UserID sql.NullInt64 + FirstName sql.NullString + Price float64 + OrderID int +} + +func (q *Queries) GetAllUsersOrders(ctx context.Context) ([]GetAllUsersOrdersRow, error) { + rows, err := q.db.QueryContext(ctx, getAllUsersOrders) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAllUsersOrdersRow + for rows.Next() { + var i GetAllUsersOrdersRow + if err := rows.Scan( + &i.UserID, + &i.FirstName, + &i.Price, + &i.OrderID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getCount = `-- name: GetCount :one +select id as my_id, COUNT(id) as id_count from users where id > 4 +` + +type GetCountRow struct { + MyID int + IDCount int +} + +func (q *Queries) GetCount(ctx context.Context) (GetCountRow, error) { + row := q.db.QueryRowContext(ctx, getCount) + var i GetCountRow + err := row.Scan(&i.MyID, &i.IDCount) + return i, err +} + +const getNameByID = `-- name: GetNameByID :one +select first_name, last_name from users where id = ? +` + +type GetNameByIDRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) GetNameByID(ctx context.Context, id mysql.ID) (GetNameByIDRow, error) { + row := q.db.QueryRowContext(ctx, getNameByID, id) + var i GetNameByIDRow + err := row.Scan(&i.FirstName, &i.LastName) + return i, err +} + +const insertNewUser = `-- name: InsertNewUser :exec +insert into users(first_name, last_name) values (?, ?) +` + +type InsertNewUserParams struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) InsertNewUser(ctx context.Context, arg InsertNewUserParams) error { + _, err := q.db.ExecContext(ctx, insertNewUser, arg.FirstName, arg.LastName) + return err +} + +const insertUsersFromOrders = `-- name: InsertUsersFromOrders :exec +insert into users(first_name) select user_id from orders where id = ? +` + +func (q *Queries) InsertUsersFromOrders(ctx context.Context, id mysql.ID) error { + _, err := q.db.ExecContext(ctx, insertUsersFromOrders, id) + return err +} + +const updateAllUsers = `-- name: UpdateAllUsers :exec +update users set first_name = 'Bob' +` + +func (q *Queries) UpdateAllUsers(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, updateAllUsers) + return err +} + +const updateUserAt = `-- name: UpdateUserAt :exec +update users set first_name = ?, last_name = ? where id > ? and first_name = ? limit 3 +` + +type UpdateUserAtParams struct { + FirstName string + LastName sql.NullString + ID mysql.ID + FirstName_2 string +} + +func (q *Queries) UpdateUserAt(ctx context.Context, arg UpdateUserAtParams) error { + _, err := q.db.ExecContext(ctx, updateUserAt, + arg.FirstName, + arg.LastName, + arg.ID, + arg.FirstName_2, + ) + return err +} diff --git a/internal/endtoend/testdata/mysql_overrides/query.sql b/internal/endtoend/testdata/mysql_overrides/query.sql new file mode 100644 index 0000000000..6f7eca2e92 --- /dev/null +++ b/internal/endtoend/testdata/mysql_overrides/query.sql @@ -0,0 +1,24 @@ +/* name: GetCount :one */ +SELECT id my_id, COUNT(id) id_count FROM users WHERE id > 4; + +/* name: GetNameByID :one */ +SELECT first_name, last_name FROM users WHERE id = ?; + +/* name: GetAll :many */ +SELECT * FROM users; + +/* name: GetAllUsersOrders :many */ +SELECT u.id user_id, u.first_name, o.price, o.id order_id +FROM orders o LEFT JOIN users u ON u.id = o.user_id; + +/* name: InsertNewUser :exec */ +INSERT INTO users (first_name, last_name) VALUES (?, ?); + +/* name: UpdateAllUsers :exec */ +update users set first_name = 'Bob'; + +/* name: UpdateUserAt :exec */ +UPDATE users SET first_name = ?, last_name = ? WHERE id > ? AND first_name = ? LIMIT 3; + +/* name: InsertUsersFromOrders :exec */ +insert into users ( first_name ) select user_id from orders where id = ?; diff --git a/internal/endtoend/testdata/mysql_overrides/schema.sql b/internal/endtoend/testdata/mysql_overrides/schema.sql new file mode 100644 index 0000000000..925637b4c0 --- /dev/null +++ b/internal/endtoend/testdata/mysql_overrides/schema.sql @@ -0,0 +1,16 @@ +CREATE TABLE users ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL, + job_status ENUM('APPLIED', 'PENDING', 'ACCEPTED', 'REJECTED') NOT NULL, + created TIMESTAMP NOT NULL +) ENGINE=InnoDB; + +CREATE TABLE orders ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + price DECIMAL(13, 4) NOT NULL, + user_id integer NOT NULL +) ENGINE=InnoDB; + + diff --git a/internal/endtoend/testdata/mysql_overrides/sqlc.json b/internal/endtoend/testdata/mysql_overrides/sqlc.json new file mode 100644 index 0000000000..b7cfad6101 --- /dev/null +++ b/internal/endtoend/testdata/mysql_overrides/sqlc.json @@ -0,0 +1,23 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql", + "overrides": [{ + "go_type": "example.com/mysql.ID", + "column": "users.id" + }, { + "go_type": "example.com/mysql.ID", + "column": "orders.id" + }] + } + ], + "overrides": [{ + "go_type": "example.com/mysql.Timestamp", + "db_type": "timestamp" + }] +} diff --git a/internal/endtoend/testdata/mysql_param/go/db.go b/internal/endtoend/testdata/mysql_param/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mysql_param/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mysql_param/go/models.go b/internal/endtoend/testdata/mysql_param/go/models.go new file mode 100644 index 0000000000..f06cf96e3f --- /dev/null +++ b/internal/endtoend/testdata/mysql_param/go/models.go @@ -0,0 +1,35 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type JobStatusType string + +const ( + APPLIED JobStatusType = "APPLIED" + PENDING JobStatusType = "PENDING" + ACCEPTED JobStatusType = "ACCEPTED" + REJECTED JobStatusType = "REJECTED" +) + +func (e *JobStatusType) Scan(src interface{}) error { + *e = JobStatusType(src.([]byte)) + return nil +} + +type Order struct { + ID int + Price float64 + UserID int +} + +type User struct { + ID int + FirstName string + LastName sql.NullString + Age int + JobStatus JobStatusType +} diff --git a/internal/endtoend/testdata/mysql_param/go/query.sql.go b/internal/endtoend/testdata/mysql_param/go/query.sql.go new file mode 100644 index 0000000000..532361e99b --- /dev/null +++ b/internal/endtoend/testdata/mysql_param/go/query.sql.go @@ -0,0 +1,207 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const getUserByID = `-- name: GetUserByID :one +select first_name, id, last_name from users where id = ? +` + +type GetUserByIDRow struct { + FirstName string + ID int + LastName sql.NullString +} + +func (q *Queries) GetUserByID(ctx context.Context, targetID int) (GetUserByIDRow, error) { + row := q.db.QueryRowContext(ctx, getUserByID, targetID) + var i GetUserByIDRow + err := row.Scan(&i.FirstName, &i.ID, &i.LastName) + return i, err +} + +const insertNewUser = `-- name: InsertNewUser :exec +insert into users(first_name, last_name) values (?, ?) +` + +type InsertNewUserParams struct { + FirstName string + UserLastName sql.NullString +} + +func (q *Queries) InsertNewUser(ctx context.Context, arg InsertNewUserParams) error { + _, err := q.db.ExecContext(ctx, insertNewUser, arg.FirstName, arg.UserLastName) + return err +} + +const limitSQLCArg = `-- name: LimitSQLCArg :many +select first_name, id from users limit ? +` + +type LimitSQLCArgRow struct { + FirstName string + ID int +} + +func (q *Queries) LimitSQLCArg(ctx context.Context, UsersLimit uint32) ([]LimitSQLCArgRow, error) { + rows, err := q.db.QueryContext(ctx, limitSQLCArg, UsersLimit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []LimitSQLCArgRow + for rows.Next() { + var i LimitSQLCArgRow + if err := rows.Scan(&i.FirstName, &i.ID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listUserOrders = `-- name: ListUserOrders :many +select users.id, users.first_name, orders.price from orders left join users on orders.user_id = users.id where orders.price > ? +` + +type ListUserOrdersRow struct { + ID sql.NullInt64 + FirstName sql.NullString + Price float64 +} + +func (q *Queries) ListUserOrders(ctx context.Context, minPrice float64) ([]ListUserOrdersRow, error) { + rows, err := q.db.QueryContext(ctx, listUserOrders, minPrice) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListUserOrdersRow + for rows.Next() { + var i ListUserOrdersRow + if err := rows.Scan(&i.ID, &i.FirstName, &i.Price); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listUsersByFamily = `-- name: ListUsersByFamily :many +select first_name, last_name from users where age < ? and last_name = ? +` + +type ListUsersByFamilyParams struct { + MaxAge int + InFamily sql.NullString +} + +type ListUsersByFamilyRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) ListUsersByFamily(ctx context.Context, arg ListUsersByFamilyParams) ([]ListUsersByFamilyRow, error) { + rows, err := q.db.QueryContext(ctx, listUsersByFamily, arg.MaxAge, arg.InFamily) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListUsersByFamilyRow + for rows.Next() { + var i ListUsersByFamilyRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listUsersByID = `-- name: ListUsersByID :many +select first_name, id, last_name from users where id < ? +` + +type ListUsersByIDRow struct { + FirstName string + ID int + LastName sql.NullString +} + +func (q *Queries) ListUsersByID(ctx context.Context, id int) ([]ListUsersByIDRow, error) { + rows, err := q.db.QueryContext(ctx, listUsersByID, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListUsersByIDRow + for rows.Next() { + var i ListUsersByIDRow + if err := rows.Scan(&i.FirstName, &i.ID, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listUsersWithLimit = `-- name: ListUsersWithLimit :many +select first_name, last_name from users limit ? +` + +type ListUsersWithLimitRow struct { + FirstName string + LastName sql.NullString +} + +func (q *Queries) ListUsersWithLimit(ctx context.Context, limit uint32) ([]ListUsersWithLimitRow, error) { + rows, err := q.db.QueryContext(ctx, listUsersWithLimit, limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListUsersWithLimitRow + for rows.Next() { + var i ListUsersWithLimitRow + if err := rows.Scan(&i.FirstName, &i.LastName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/mysql_param/query.sql b/internal/endtoend/testdata/mysql_param/query.sql new file mode 100644 index 0000000000..48c01d83bb --- /dev/null +++ b/internal/endtoend/testdata/mysql_param/query.sql @@ -0,0 +1,27 @@ +/* name: ListUsersByID :many */ +SELECT first_name, id, last_name FROM users WHERE id < ?; + +/* name: ListUserOrders :many */ +SELECT + users.id, + users.first_name, + orders.price +FROM + orders +LEFT JOIN users ON orders.user_id = users.id +WHERE orders.price > :minPrice; + +/* name: GetUserByID :one */ +SELECT first_name, id, last_name FROM users WHERE id = :targetID; + +/* name: ListUsersByFamily :many */ +SELECT first_name, last_name FROM users WHERE age < :maxAge AND last_name = :inFamily; + +/* name: ListUsersWithLimit :many */ +SELECT first_name, last_name FROM users LIMIT ?; + +/* name: LimitSQLCArg :many */ +select first_name, id FROM users LIMIT sqlc.arg(UsersLimit); + +/* name: InsertNewUser :exec */ +INSERT INTO users (first_name, last_name) VALUES (?, sqlc.arg(user_last_name)); diff --git a/internal/endtoend/testdata/mysql_param/schema.sql b/internal/endtoend/testdata/mysql_param/schema.sql new file mode 100644 index 0000000000..191528716f --- /dev/null +++ b/internal/endtoend/testdata/mysql_param/schema.sql @@ -0,0 +1,15 @@ +CREATE TABLE users ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL, + job_status ENUM('APPLIED', 'PENDING', 'ACCEPTED', 'REJECTED') NOT NULL +) ENGINE=InnoDB; + +CREATE TABLE orders ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + price DECIMAL(13, 4) NOT NULL, + user_id integer NOT NULL +) ENGINE=InnoDB; + + diff --git a/internal/endtoend/testdata/mysql_param/sqlc.json b/internal/endtoend/testdata/mysql_param/sqlc.json new file mode 100644 index 0000000000..a9e7b055a4 --- /dev/null +++ b/internal/endtoend/testdata/mysql_param/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql" + } + ] +} diff --git a/internal/endtoend/testdata/mysql_parse/go/db.go b/internal/endtoend/testdata/mysql_parse/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/mysql_parse/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mysql_parse/go/models.go b/internal/endtoend/testdata/mysql_parse/go/models.go new file mode 100644 index 0000000000..f2e2656298 --- /dev/null +++ b/internal/endtoend/testdata/mysql_parse/go/models.go @@ -0,0 +1,35 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type JobStatusType string + +const ( + APPLIED JobStatusType = "APPLIED" + PENDING JobStatusType = "PENDING" + ACCEPTED JobStatusType = "ACCEPTED" + REJECTED JobStatusType = "REJECTED" +) + +func (e *JobStatusType) Scan(src interface{}) error { + *e = JobStatusType(src.([]byte)) + return nil +} + +type Order struct { + ID int `json:"id"` + Price float64 `json:"price"` + UserID int `json:"user_id"` +} + +type User struct { + ID int `json:"id"` + FirstName string `json:"first_name"` + LastName sql.NullString `json:"last_name"` + Age int `json:"age"` + JobStatus JobStatusType `json:"job_status"` +} diff --git a/internal/endtoend/testdata/mysql_parse/go/query.sql.go b/internal/endtoend/testdata/mysql_parse/go/query.sql.go new file mode 100644 index 0000000000..fba91b668c --- /dev/null +++ b/internal/endtoend/testdata/mysql_parse/go/query.sql.go @@ -0,0 +1,166 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const getAll = `-- name: GetAll :many +select id, first_name, last_name, age, job_status from users +` + +func (q *Queries) GetAll(ctx context.Context) ([]User, error) { + rows, err := q.db.QueryContext(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Age, + &i.JobStatus, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getAllUsersOrders = `-- name: GetAllUsersOrders :many +select u.id as user_id, u.first_name, o.price, o.id as order_id from orders as o left join users as u on u.id = o.user_id +` + +type GetAllUsersOrdersRow struct { + UserID sql.NullInt64 `json:"user_id"` + FirstName sql.NullString `json:"first_name"` + Price float64 `json:"price"` + OrderID int `json:"order_id"` +} + +func (q *Queries) GetAllUsersOrders(ctx context.Context) ([]GetAllUsersOrdersRow, error) { + rows, err := q.db.QueryContext(ctx, getAllUsersOrders) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAllUsersOrdersRow + for rows.Next() { + var i GetAllUsersOrdersRow + if err := rows.Scan( + &i.UserID, + &i.FirstName, + &i.Price, + &i.OrderID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getCount = `-- name: GetCount :one +select id as my_id, COUNT(id) as id_count from users where id > 4 +` + +type GetCountRow struct { + MyID int `json:"my_id"` + IDCount int `json:"id_count"` +} + +func (q *Queries) GetCount(ctx context.Context) (GetCountRow, error) { + row := q.db.QueryRowContext(ctx, getCount) + var i GetCountRow + err := row.Scan(&i.MyID, &i.IDCount) + return i, err +} + +const getNameByID = `-- name: GetNameByID :one +select first_name, last_name from users where id = ? +` + +type GetNameByIDRow struct { + FirstName string `json:"first_name"` + LastName sql.NullString `json:"last_name"` +} + +func (q *Queries) GetNameByID(ctx context.Context, id int) (GetNameByIDRow, error) { + row := q.db.QueryRowContext(ctx, getNameByID, id) + var i GetNameByIDRow + err := row.Scan(&i.FirstName, &i.LastName) + return i, err +} + +const insertNewUser = `-- name: InsertNewUser :exec +insert into users(first_name, last_name) values (?, ?) +` + +type InsertNewUserParams struct { + FirstName string `json:"first_name"` + LastName sql.NullString `json:"last_name"` +} + +func (q *Queries) InsertNewUser(ctx context.Context, arg InsertNewUserParams) error { + _, err := q.db.ExecContext(ctx, insertNewUser, arg.FirstName, arg.LastName) + return err +} + +const insertUsersFromOrders = `-- name: InsertUsersFromOrders :exec +insert into users(first_name) select user_id from orders where id = ? +` + +func (q *Queries) InsertUsersFromOrders(ctx context.Context, id int) error { + _, err := q.db.ExecContext(ctx, insertUsersFromOrders, id) + return err +} + +const updateAllUsers = `-- name: UpdateAllUsers :exec +update users set first_name = 'Bob' +` + +func (q *Queries) UpdateAllUsers(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, updateAllUsers) + return err +} + +const updateUserAt = `-- name: UpdateUserAt :exec +update users set first_name = ?, last_name = ? where id > ? and first_name = ? limit 3 +` + +type UpdateUserAtParams struct { + FirstName string `json:"first_name"` + LastName sql.NullString `json:"last_name"` + ID int `json:"id"` + FirstName_2 string `json:"first_name_2"` +} + +func (q *Queries) UpdateUserAt(ctx context.Context, arg UpdateUserAtParams) error { + _, err := q.db.ExecContext(ctx, updateUserAt, + arg.FirstName, + arg.LastName, + arg.ID, + arg.FirstName_2, + ) + return err +} diff --git a/internal/endtoend/testdata/mysql_parse/query.sql b/internal/endtoend/testdata/mysql_parse/query.sql new file mode 100644 index 0000000000..6f7eca2e92 --- /dev/null +++ b/internal/endtoend/testdata/mysql_parse/query.sql @@ -0,0 +1,24 @@ +/* name: GetCount :one */ +SELECT id my_id, COUNT(id) id_count FROM users WHERE id > 4; + +/* name: GetNameByID :one */ +SELECT first_name, last_name FROM users WHERE id = ?; + +/* name: GetAll :many */ +SELECT * FROM users; + +/* name: GetAllUsersOrders :many */ +SELECT u.id user_id, u.first_name, o.price, o.id order_id +FROM orders o LEFT JOIN users u ON u.id = o.user_id; + +/* name: InsertNewUser :exec */ +INSERT INTO users (first_name, last_name) VALUES (?, ?); + +/* name: UpdateAllUsers :exec */ +update users set first_name = 'Bob'; + +/* name: UpdateUserAt :exec */ +UPDATE users SET first_name = ?, last_name = ? WHERE id > ? AND first_name = ? LIMIT 3; + +/* name: InsertUsersFromOrders :exec */ +insert into users ( first_name ) select user_id from orders where id = ?; diff --git a/internal/endtoend/testdata/mysql_parse/schema.sql b/internal/endtoend/testdata/mysql_parse/schema.sql new file mode 100644 index 0000000000..191528716f --- /dev/null +++ b/internal/endtoend/testdata/mysql_parse/schema.sql @@ -0,0 +1,15 @@ +CREATE TABLE users ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL, + job_status ENUM('APPLIED', 'PENDING', 'ACCEPTED', 'REJECTED') NOT NULL +) ENGINE=InnoDB; + +CREATE TABLE orders ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + price DECIMAL(13, 4) NOT NULL, + user_id integer NOT NULL +) ENGINE=InnoDB; + + diff --git a/internal/endtoend/testdata/mysql_parse/sqlc.json b/internal/endtoend/testdata/mysql_parse/sqlc.json new file mode 100644 index 0000000000..4c2bb8a336 --- /dev/null +++ b/internal/endtoend/testdata/mysql_parse/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql", + "emit_json_tags": true + } + ] +} diff --git a/internal/mysql/errors_test.go b/internal/mysql/errors_test.go deleted file mode 100644 index 588f19f401..0000000000 --- a/internal/mysql/errors_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package mysql - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "vitess.io/vitess/go/vt/sqlparser" - - "github.com/kyleconroy/sqlc/internal/config" -) - -func TestCustomArgErr(t *testing.T) { - tests := [...]struct { - input string - output sqlparser.PositionedErr - }{ - { - input: "/* name: GetUser :one */\nselect id, first_name from users where id = sqlc.argh(target_id)", - output: sqlparser.PositionedErr{ - Err: `invalid function call "sqlc.argh", did you mean "sqlc.arg"?`, - Pos: 0, - Near: nil, - }, - }, - { - input: "/* name: GetUser :one */\nselect id, first_name from users where id = sqlc.arg(sqlc.arg(target_id))", - output: sqlparser.PositionedErr{ - Err: `invalid custom argument value "sqlc.arg(sqlc.arg(target_id))"`, - Pos: 0, - Near: nil, - }, - }, - { - input: "/* name: GetUser :one */\nselect id, first_name from users where id = sqlc.arg(?)", - output: sqlparser.PositionedErr{ - Err: `invalid custom argument value "sqlc.arg(?)"`, - Pos: 0, - Near: nil, - }, - }, - } - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{}) - generator := PackageGenerator{mockSchema, settings, "db"} - for _, tcase := range tests { - q, err := generator.parseContents("queries.sql", tcase.input) - if err == nil && len(q) > 0 { - t.Errorf("parse contents succeeded on an invalid query") - } - if diff := cmp.Diff(tcase.output, err); diff != "" { - t.Errorf(diff) - } - } -} - -func TestPositionedErr(t *testing.T) { - tests := [...]struct { - input string - output sqlparser.PositionedErr - }{ - { - input: "/* name: GetUser :one */\nselect id, first_name from users from where id = ?", - output: sqlparser.PositionedErr{ - Err: `syntax error`, - Pos: 63, - Near: []byte("from"), - }, - }, - { - input: "/* name: GetUser :one */\nselectt id, first_name from users", - output: sqlparser.PositionedErr{ - Err: `syntax error`, - Pos: 33, - Near: []byte("selectt"), - }, - }, - { - input: "/* name: GetUser :one */\nselect id from users where select id", - output: sqlparser.PositionedErr{ - Err: `syntax error`, - Pos: 59, - Near: []byte("select"), - }, - }, - } - - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{}) - for _, tcase := range tests { - generator := PackageGenerator{mockSchema, settings, "db"} - q, err := generator.parseContents("queries.sql", tcase.input) - if err == nil && len(q) > 0 { - t.Errorf("parse contents succeeded on an invalid query") - } - if diff := cmp.Diff(tcase.output, err); diff != "" { - t.Errorf(diff) - } - } -} diff --git a/internal/mysql/gen_test.go b/internal/mysql/gen_test.go index 8856c24357..1d00b4c155 100644 --- a/internal/mysql/gen_test.go +++ b/internal/mysql/gen_test.go @@ -4,11 +4,6 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "vitess.io/vitess/go/vt/sqlparser" - - "github.com/kyleconroy/sqlc/internal/config" - "github.com/kyleconroy/sqlc/internal/dinosql" - "github.com/kyleconroy/sqlc/internal/pg" ) func TestArgName(t *testing.T) { @@ -37,155 +32,3 @@ func TestArgName(t *testing.T) { } } } -func TestEnumName(t *testing.T) { - tcase := [...]struct { - input sqlparser.ColumnDefinition - output string - }{ - { - input: sqlparser.ColumnDefinition{Name: sqlparser.NewColIdent("first_name")}, - output: "FirstNameType", - }, - { - input: sqlparser.ColumnDefinition{Name: sqlparser.NewColIdent("user_id")}, - output: "UserIDType", - }, - { - input: sqlparser.ColumnDefinition{Name: sqlparser.NewColIdent("last_name")}, - output: "LastNameType", - }, - } - - generator := PackageGenerator{mockSchema, config.CombinedSettings{}, ""} - for _, tc := range tcase { - enumName := generator.enumNameFromColDef(&tc.input) - if diff := cmp.Diff(enumName, tc.output); diff != "" { - t.Errorf(diff) - } - } -} - -func TestEnums(t *testing.T) { - generator := PackageGenerator{mockSchema, config.CombinedSettings{}, ""} - tcase := [...]struct { - input Result - output []dinosql.GoEnum - }{ - { - input: Result{PackageGenerator: generator}, - output: []dinosql.GoEnum{ - { - Name: "JobStatusType", - Constants: []dinosql.GoConstant{ - {Name: "applied", Type: "JobStatusType", Value: "applied"}, - {Name: "pending", Type: "JobStatusType", Value: "pending"}, - {Name: "accepted", Type: "JobStatusType", Value: "accepted"}, - {Name: "rejected", Type: "JobStatusType", Value: "rejected"}, - }, - }, - }, - }, - } - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{}) - for _, tc := range tcase { - enums := tc.input.Enums(settings) - if diff := cmp.Diff(enums, tc.output); diff != "" { - t.Errorf(diff) - } - } -} - -func TestStructs(t *testing.T) { - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{}) - generator := PackageGenerator{mockSchema, settings, "db"} - tcase := [...]struct { - input Result - output []dinosql.GoStruct - }{ - { - input: Result{PackageGenerator: generator}, - output: []dinosql.GoStruct{ - { - Table: pg.FQN{Catalog: "orders"}, - Name: "Order", - Fields: []dinosql.GoField{ - {Name: "ID", Type: "int", Tags: map[string]string{"json:": "id"}}, - {Name: "Price", Type: "float64", Tags: map[string]string{"json:": "price"}}, - {Name: "UserID", Type: "int", Tags: map[string]string{"json:": "user_id"}}, - }, - }, - { - Table: pg.FQN{Catalog: "users"}, - Name: "User", - Fields: []dinosql.GoField{ - {Name: "FirstName", Type: "string", Tags: map[string]string{"json:": "first_name"}}, - {Name: "LastName", Type: "sql.NullString", Tags: map[string]string{"json:": "last_name"}}, - {Name: "ID", Type: "int", Tags: map[string]string{"json:": "id"}}, - {Name: "Age", Type: "int", Tags: map[string]string{"json:": "age"}}, - {Name: "JobStatus", Type: "JobStatusType", Tags: map[string]string{"json:": "job_status"}}, - }}, - }, - }, - } - - for _, tc := range tcase { - structs := tc.input.Structs(settings) - if diff := cmp.Diff(structs, tc.output); diff != "" { - t.Errorf(diff) - } - } -} - -func TestTypeOverride(t *testing.T) { - tests := [...]struct { - overrides []config.Override - col Column - expectedGoType string - }{ - { - overrides: []config.Override{ - { - DBType: "uuid", - GoTypeName: "KSUID", // this is populated by the dinosql.Parse - }, - }, - col: Column{ - ColumnDefinition: &sqlparser.ColumnDefinition{ - Type: sqlparser.ColumnType{ - Type: "uuid", - NotNull: true, - }, - }, - }, - expectedGoType: "KSUID", - }, - { - overrides: []config.Override{ - { - ColumnName: "user_id", // this is populated by dinosql.Parse - GoTypeName: "uuid", // this is populated by dinosql.Parse - }, - }, - col: Column{ - ColumnDefinition: &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("user_id"), - Type: sqlparser.ColumnType{ - Type: "varchar", - NotNull: true, - }, - }, - }, - expectedGoType: "uuid", - }, - } - - for _, tcase := range tests { - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{Overrides: tcase.overrides}) - gen := PackageGenerator{mockSchema, settings, "db"} - goType := gen.goTypeCol(tcase.col) - - if diff := cmp.Diff(tcase.expectedGoType, goType); diff != "" { - t.Errorf(diff) - } - } -} diff --git a/internal/mysql/param_test.go b/internal/mysql/param_test.go deleted file mode 100644 index de0c39a672..0000000000 --- a/internal/mysql/param_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package mysql - -import ( - "reflect" - "testing" - - "github.com/davecgh/go-spew/spew" - "vitess.io/vitess/go/vt/sqlparser" - - "github.com/kyleconroy/sqlc/internal/config" -) - -func TestSelectParamSearcher(t *testing.T) { - type testCase struct { - input string - output []*Param - } - - tests := []testCase{ - testCase{ - input: "SELECT first_name, id, last_name FROM users WHERE id < ?", - output: []*Param{&Param{ - OriginalName: ":v1", - Name: "id", - Typ: "int", - }, - }, - }, - testCase{ - input: `SELECT - users.id, - users.first_name, - orders.price - FROM - orders - LEFT JOIN users ON orders.user_id = users.id - WHERE orders.price > :minPrice`, - output: []*Param{ - &Param{ - OriginalName: ":minPrice", - Name: "minPrice", - Typ: "float64", - }, - }, - }, - testCase{ - input: "SELECT first_name, id, last_name FROM users WHERE id = :targetID", - output: []*Param{&Param{ - OriginalName: ":targetID", - Name: "targetID", - Typ: "int", - }, - }, - }, - testCase{ - input: "SELECT first_name, last_name FROM users WHERE age < :maxAge AND last_name = :inFamily", - output: []*Param{ - &Param{ - OriginalName: ":maxAge", - Name: "maxAge", - Typ: "int", - }, - &Param{ - OriginalName: ":inFamily", - Name: "inFamily", - Typ: "sql.NullString", - }, - }, - }, - testCase{ - input: "SELECT first_name, last_name FROM users LIMIT ?", - output: []*Param{ - &Param{ - OriginalName: ":v1", - Name: "limit", - Typ: "uint32", - }, - }, - }, - { - input: "select first_name, id FROM users LIMIT sqlc.arg(UsersLimit)", - output: []*Param{ - &Param{ - OriginalName: "sqlc.arg(UsersLimit)", - Name: "UsersLimit", - Typ: "uint32", - }, - }, - }, - } - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{}) - for _, tCase := range tests { - generator := PackageGenerator{ - Schema: mockSchema, - CombinedSettings: settings, - packageName: "db", - } - tree, err := sqlparser.Parse(tCase.input) - if err != nil { - t.Errorf("Failed to parse input query") - } - selectStm, ok := tree.(*sqlparser.Select) - - tableAliasMap, _, err := parseFrom(selectStm.From, false) - if err != nil { - t.Errorf("Failed to parse table name alias's: %v", err) - } - - limitParams, err := generator.paramsInLimitExpr(selectStm.Limit, tableAliasMap) - if err != nil { - t.Errorf("Failed to parse limit expression params: %v", err) - } - whereParams, err := generator.paramsInWhereExpr(selectStm.Where, tableAliasMap, "users") - if err != nil { - t.Errorf("Failed to parse where expression params: %v", err) - } - - params := append(limitParams, whereParams...) - if !ok { - t.Errorf("Test case is not SELECT statement as expected") - } - - if !reflect.DeepEqual(params, tCase.output) { - t.Errorf("Param searcher returned unexpected result\nResult: %v\nExpected: %v", - spew.Sdump(params), spew.Sdump(tCase.output)) - } - } -} - -func TestInsertParamSearcher(t *testing.T) { - type testCase struct { - input string - output []*Param - expectedNames []string - } - - tests := []testCase{ - testCase{ - input: "/* name: InsertNewUser :exec */\nINSERT INTO users (first_name, last_name) VALUES (?, sqlc.arg(user_last_name))", - output: []*Param{ - &Param{ - OriginalName: ":v1", - Name: "first_name", - Typ: "string", - }, - &Param{ - OriginalName: "sqlc.arg(user_last_name)", - Name: "user_last_name", - Typ: "sql.NullString", - }, - }, - expectedNames: []string{"first_name", "user_last_name"}, - }, - } - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{}) - for _, tCase := range tests { - generator := PackageGenerator{ - Schema: mockSchema, - CombinedSettings: settings, - packageName: "db", - } - tree, err := sqlparser.Parse(tCase.input) - if err != nil { - t.Errorf("Failed to parse input query") - } - insertStm, ok := tree.(*sqlparser.Insert) - if !ok { - t.Errorf("Test case is not SELECT statement as expected") - } - result, err := generator.parseInsert(insertStm, tCase.input) - - if err != nil { - t.Errorf("Failed to parse insert statement.") - } - - if !reflect.DeepEqual(result.Params, tCase.output) { - t.Errorf("Param searcher returned unexpected result\nResult: %v\nExpected: %v\nQuery: %s", - spew.Sdump(result.Params), spew.Sdump(tCase.output), tCase.input) - } - if len(result.Params) != len(tCase.expectedNames) { - t.Errorf("Insufficient test cases. Mismatch in length of expected param names and parsed params") - } - for ix, p := range result.Params { - if p.Name != tCase.expectedNames[ix] { - t.Errorf("Derived param does not match expected output.\nResult: %v\nExpected: %v", - p.Name, tCase.expectedNames[ix]) - } - } - } -} diff --git a/internal/mysql/parse.go b/internal/mysql/parse.go index f1ddf6fdf6..ee1404905d 100644 --- a/internal/mysql/parse.go +++ b/internal/mysql/parse.go @@ -49,20 +49,39 @@ func parsePath(sqlPath string, generator PackageGenerator) (*Result, error) { parseErrors.Add(filename, "", 0, err) continue } - queries, err := generator.parseContents(filename, contents) - if err != nil { - if posErr, ok := err.(sqlparser.PositionedErr); ok { - message := fmt.Errorf(posErr.Err) - if posErr.Near != nil { - message = fmt.Errorf("%s at or near \"%s\"", posErr.Err, posErr.Near) + + t := sqlparser.NewStringTokenizer(contents) + var start int + for { + q, err := sqlparser.ParseNextStrictDDL(t) + if err == io.EOF { + break + } else if err != nil { + if posErr, ok := err.(sqlparser.PositionedErr); ok { + message := fmt.Errorf(posErr.Err) + if posErr.Near != nil { + message = fmt.Errorf("%s at or near \"%s\"", posErr.Err, posErr.Near) + } + parseErrors.Add(filename, contents, posErr.Pos, message) + } else { + parseErrors.Add(filename, contents, start, err) } - parseErrors.Add(filename, contents, posErr.Pos, message) - } else { - parseErrors.Add(filename, contents, 0, err) + continue } - continue + query := contents[start : t.Position-1] + result, err := generator.parseQueryString(q, query) + if err != nil { + parseErrors.Add(filename, contents, start, err) + start = t.Position + continue + } + start = t.Position + if result == nil { + continue + } + result.Filename = filepath.Base(filename) + parsedQueries = append(parsedQueries, result) } - parsedQueries = append(parsedQueries, queries...) } if len(parseErrors.Errs) > 0 { @@ -75,32 +94,6 @@ func parsePath(sqlPath string, generator PackageGenerator) (*Result, error) { }, nil } -func (pGen *PackageGenerator) parseContents(filename, contents string) ([]*Query, error) { - t := sqlparser.NewStringTokenizer(contents) - var queries []*Query - var start int - for { - q, err := sqlparser.ParseNextStrictDDL(t) - if err == io.EOF { - break - } else if err != nil { - return nil, err - } - query := contents[start : t.Position-1] - result, err := pGen.parseQueryString(q, query) - if err != nil { - return nil, sqlparser.PositionedErr{Err: err.Error(), Pos: start, Near: nil} - } - start = t.Position - if result == nil { - continue - } - result.Filename = filepath.Base(filename) - queries = append(queries, result) - } - return queries, nil -} - func (pGen PackageGenerator) parseQueryString(tree sqlparser.Statement, query string) (*Query, error) { var parsedQuery *Query switch tree := tree.(type) { diff --git a/internal/mysql/parse_test.go b/internal/mysql/parse_test.go deleted file mode 100644 index 6486da2a95..0000000000 --- a/internal/mysql/parse_test.go +++ /dev/null @@ -1,450 +0,0 @@ -package mysql - -import ( - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/kyleconroy/sqlc/internal/config" - "github.com/kyleconroy/sqlc/internal/dinosql" - "vitess.io/vitess/go/vt/sqlparser" -) - -func init() { - initMockSchema() -} - -var mockSchema *Schema - -func initMockSchema() { - var schemaMap = make(map[string][]*sqlparser.ColumnDefinition) - mockSchema = &Schema{ - tables: schemaMap, - } - schemaMap["users"] = []*sqlparser.ColumnDefinition{ - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("first_name"), - Type: sqlparser.ColumnType{ - Type: "varchar", - NotNull: true, - }, - }, - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("last_name"), - Type: sqlparser.ColumnType{ - Type: "varchar", - NotNull: false, - }, - }, - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("id"), - Type: sqlparser.ColumnType{ - Type: "int", - NotNull: true, - Autoincrement: true, - }, - }, - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("age"), - Type: sqlparser.ColumnType{ - Type: "int", - NotNull: true, - }, - }, - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("job_status"), - Type: sqlparser.ColumnType{ - Type: "enum", - NotNull: true, - EnumValues: []string{"applied", "pending", "accepted", "rejected"}, - }, - }, - } - schemaMap["orders"] = []*sqlparser.ColumnDefinition{ - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("id"), - Type: sqlparser.ColumnType{ - Type: "int", - NotNull: true, - Autoincrement: true, - }, - }, - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("price"), - Type: sqlparser.ColumnType{ - Type: "DECIMAL(13, 4)", - NotNull: true, - Autoincrement: true, - }, - }, - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("user_id"), - Type: sqlparser.ColumnType{ - Type: "int", - NotNull: true, - }, - }, - } -} - -func filterCols(allCols []*sqlparser.ColumnDefinition, colNames map[string]string) []Column { - cols := []Column{} - for _, col := range allCols { - if table, ok := colNames[col.Name.String()]; ok { - cols = append(cols, Column{ - col, - table, - }) - } - } - return cols -} - -func TestParseSelect(t *testing.T) { - type expected struct { - query string - schema *Schema - } - type testCase struct { - name string - input expected - output *Query - } - tests := []testCase{ - testCase{ - name: "get_count", - input: expected{ - query: `/* name: GetCount :one */ - SELECT id my_id, COUNT(id) id_count FROM users WHERE id > 4`, - schema: mockSchema, - }, - output: &Query{ - SQL: "select id as my_id, COUNT(id) as id_count from users where id > 4", - Columns: []Column{ - Column{ - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("my_id"), - Type: sqlparser.ColumnType{ - Type: "int", - NotNull: true, - Autoincrement: true, - }, - }, - "users", - }, - Column{ - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("id_count"), - Type: sqlparser.ColumnType{ - Type: "int", - NotNull: true, - }, - }, - "", - }, - }, - Params: []*Param{}, - Name: "GetCount", - Cmd: ":one", - DefaultTableName: "users", - }, - }, - testCase{ - name: "get_name_by_id", - input: expected{ - query: `/* name: GetNameByID :one */ - SELECT first_name, last_name FROM users WHERE id = ?`, - schema: mockSchema, - }, - output: &Query{ - SQL: `select first_name, last_name from users where id = ?`, - Columns: filterCols(mockSchema.tables["users"], map[string]string{"first_name": "users", "last_name": "users"}), - Params: []*Param{ - &Param{ - OriginalName: ":v1", - Name: "id", - Typ: "int", - }}, - Name: "GetNameByID", - Cmd: ":one", - DefaultTableName: "users", - }, - }, - testCase{ - name: "get_all", - input: expected{ - query: `/* name: GetAll :many */ - SELECT * FROM users;`, - schema: mockSchema, - }, - output: &Query{ - SQL: "select first_name, last_name, id, age, job_status from users", - Columns: filterCols(mockSchema.tables["users"], map[string]string{"first_name": "users", "last_name": "users", "id": "users", "age": "users", "job_status": "users"}), - Params: []*Param{}, - Name: "GetAll", - Cmd: ":many", - DefaultTableName: "users", - }, - }, - testCase{ - name: "get_all_users_orders", - input: expected{ - query: `/* name: GetAllUsersOrders :many */ - SELECT u.id user_id, u.first_name, o.price, o.id order_id - FROM orders o LEFT JOIN users u ON u.id = o.user_id`, - schema: mockSchema, - }, - output: &Query{ - SQL: "select u.id as user_id, u.first_name, o.price, o.id as order_id from orders as o left join users as u on u.id = o.user_id", - Columns: []Column{ - Column{ - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("user_id"), - Type: sqlparser.ColumnType{ - Type: "int", - Autoincrement: true, - NotNull: false, // beause of the left join - }, - }, - "users", - }, - Column{ - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("first_name"), - Type: sqlparser.ColumnType{ - Type: "varchar", - NotNull: false, // because of left join - }, - }, - "users", - }, - Column{ - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("price"), - Type: sqlparser.ColumnType{ - Type: "DECIMAL(13, 4)", - Autoincrement: true, - NotNull: true, - }, - }, - "orders", - }, - Column{ - &sqlparser.ColumnDefinition{ - Name: sqlparser.NewColIdent("order_id"), - Type: sqlparser.ColumnType{ - Type: "int", - Autoincrement: true, - NotNull: true, - }, - }, - "orders", - }, - }, - Params: []*Param{}, - Name: "GetAllUsersOrders", - Cmd: ":many", - DefaultTableName: "orders", - }, - }, - } - - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{}) - for _, tt := range tests { - testCase := tt - generator := PackageGenerator{ - Schema: testCase.input.schema, - CombinedSettings: settings, - packageName: "db", - } - t.Run(tt.name, func(t *testing.T) { - qs, err := generator.parseContents("example.sql", testCase.input.query) - if err != nil { - t.Fatalf("Parsing failed with query: [%v]\n", err) - } - if len(qs) != 1 { - t.Fatalf("Expected one query, not %d", len(qs)) - } - q := qs[0] - q.Filename = "" - if diff := cmp.Diff(testCase.output, q); diff != "" { - t.Errorf("parsed query differs: \n%s", diff) - } - }) - } -} - -func TestParseLeadingComment(t *testing.T) { - type output struct { - Name string - Cmd string - } - tests := []struct { - input string - output output - }{ - { - input: "/* name: GetPeopleByID :many */", - output: output{Name: "GetPeopleByID", Cmd: ":many"}, - }, - } - - for _, tCase := range tests { - name, cmd, err := dinosql.ParseMetadata(tCase.input, dinosql.CommentSyntaxStar) - result := output{name, cmd} - if err != nil { - t.Errorf("failed to parse leading comment: %w", err) - } else if diff := cmp.Diff(tCase.output, result); diff != "" { - t.Errorf("unexpectd result of query metadata parse: %s", diff) - } - } -} - -func TestSchemaLookup(t *testing.T) { - firstNameColDfn, err := mockSchema.schemaLookup("users", "first_name") - if err != nil { - t.Errorf("Failed to get column schema from mock schema: %v", err) - } - - expected := filterCols(mockSchema.tables["users"], map[string]string{"first_name": "users"}) - if !reflect.DeepEqual(*firstNameColDfn, expected[0]) { - t.Errorf("Table schema lookup returned unexpected result") - } -} - -func TestParseInsertUpdate(t *testing.T) { - type expected struct { - query string - schema *Schema - } - type testCase struct { - name string - input expected - output *Query - } - - tests := []testCase{ - testCase{ - name: "insert_users", - input: expected{ - query: "/* name: InsertNewUser :exec */\nINSERT INTO users (first_name, last_name) VALUES (?, ?)", - schema: mockSchema, - }, - output: &Query{ - SQL: "insert into users(first_name, last_name) values (?, ?)", - Columns: nil, - Params: []*Param{ - &Param{ - OriginalName: ":v1", - Name: "first_name", - Typ: "string", - }, - &Param{ - OriginalName: ":v2", - Name: "last_name", - Typ: "sql.NullString", - }, - }, - Name: "InsertNewUser", - Cmd: ":exec", - DefaultTableName: "users", - }, - }, - testCase{ - name: "update_without_where", - input: expected{ - query: "/* name: UpdateAllUsers :exec */ update users set first_name = 'Bob'", - schema: mockSchema, - }, - output: &Query{ - SQL: "update users set first_name = 'Bob'", - Columns: nil, - Params: []*Param{}, - Name: "UpdateAllUsers", - Cmd: ":exec", - DefaultTableName: "users", - }, - }, - testCase{ - name: "update_users", - input: expected{ - query: "/* name: UpdateUserAt :exec */\nUPDATE users SET first_name = ?, last_name = ? WHERE id > ? AND first_name = ? LIMIT 3", - schema: mockSchema, - }, - output: &Query{ - SQL: "update users set first_name = ?, last_name = ? where id > ? and first_name = ? limit 3", - Columns: nil, - Params: []*Param{ - &Param{ - OriginalName: ":v1", - Name: "first_name", - Typ: "string", - }, - &Param{ - OriginalName: ":v2", - Name: "last_name", - Typ: "sql.NullString", - }, - &Param{ - OriginalName: ":v3", - Name: "id", - Typ: "int", - }, - &Param{ - OriginalName: ":v4", - Name: "first_name", - Typ: "string", - }, - }, - Name: "UpdateUserAt", - Cmd: ":exec", - DefaultTableName: "users", - }, - }, - testCase{ - name: "insert_users_from_orders", - input: expected{ - query: "/* name: InsertUsersFromOrders :exec */\ninsert into users ( first_name ) select user_id from orders where id = ?;", - schema: mockSchema, - }, - output: &Query{ - SQL: "insert into users(first_name) select user_id from orders where id = ?", - Columns: nil, - Params: []*Param{ - &Param{ - OriginalName: ":v1", - Name: "id", - Typ: "int", - }, - }, - Name: "InsertUsersFromOrders", - Cmd: ":exec", - DefaultTableName: "users", - }, - }, - } - - settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{}) - for _, tt := range tests { - testCase := tt - t.Run(tt.name, func(t *testing.T) { - generator := PackageGenerator{ - Schema: testCase.input.schema, - CombinedSettings: settings, - packageName: "db", - } - qs, err := generator.parseContents("example.sql", testCase.input.query) - if err != nil { - t.Fatalf("Parsing failed with query: [%v]\n", err) - } - if len(qs) != 1 { - t.Fatalf("Expected one query, not %d", len(qs)) - } - q := qs[0] - q.Filename = "" - if diff := cmp.Diff(testCase.output, q); diff != "" { - t.Errorf("parsed query differs: \n%s", diff) - } - }) - } -}