diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 6ced89ae23..0ee3b3e18a 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -550,12 +550,15 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro if err != nil { continue } - table, err := qc.GetTable(&ast.TableName{ - Catalog: fn.ReturnType.Catalog, - Schema: fn.ReturnType.Schema, - Name: fn.ReturnType.Name, - }) - if err != nil { + var table *Table + if fn.ReturnType != nil { + table, err = qc.GetTable(&ast.TableName{ + Catalog: fn.ReturnType.Catalog, + Schema: fn.ReturnType.Schema, + Name: fn.ReturnType.Name, + }) + } + if table == nil || err != nil { if n.Alias != nil && len(n.Alias.Colnames.Items) > 0 { table = &Table{} for _, colName := range n.Alias.Colnames.Items { @@ -575,12 +578,22 @@ func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, erro Schema: fn.Rel.Schema, Name: fn.Rel.Name, }, - Columns: []*Column{ + } + if len(fn.Outs) > 0 { + for _, arg := range fn.Outs { + table.Columns = append(table.Columns, &Column{ + Name: arg.Name, + DataType: arg.Type.Name, + }) + } + } + if fn.ReturnType != nil { + table.Columns = []*Column{ { Name: colName, DataType: fn.ReturnType.Name, }, - }, + } } } } diff --git a/internal/compiler/query.go b/internal/compiler/query.go index 117cf44813..e3768de251 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -2,11 +2,13 @@ package compiler import ( "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" ) type Function struct { Rel *ast.FuncName ReturnType *ast.TypeName + Outs []*catalog.Argument } type Table struct { diff --git a/internal/compiler/query_catalog.go b/internal/compiler/query_catalog.go index 42088446d2..80b59d876c 100644 --- a/internal/compiler/query_catalog.go +++ b/internal/compiler/query_catalog.go @@ -103,6 +103,7 @@ func (qc QueryCatalog) GetFunc(rel *ast.FuncName) (*Function, error) { } return &Function{ Rel: rel, + Outs: funcs[0].OutArgs(), ReturnType: funcs[0].ReturnType, }, nil } diff --git a/internal/endtoend/testdata/func_out_param/issue.md b/internal/endtoend/testdata/func_out_param/issue.md new file mode 100644 index 0000000000..1b8779f2ff --- /dev/null +++ b/internal/endtoend/testdata/func_out_param/issue.md @@ -0,0 +1 @@ +https://github.com/sqlc-dev/sqlc/issues/1654 diff --git a/internal/endtoend/testdata/func_out_param/pgx/go/db.go b/internal/endtoend/testdata/func_out_param/pgx/go/db.go new file mode 100644 index 0000000000..8a010ccc48 --- /dev/null +++ b/internal/endtoend/testdata/func_out_param/pgx/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/func_out_param/pgx/go/models.go b/internal/endtoend/testdata/func_out_param/pgx/go/models.go new file mode 100644 index 0000000000..b320134bd5 --- /dev/null +++ b/internal/endtoend/testdata/func_out_param/pgx/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package querytest + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type Author struct { + ID int64 + Name string + Bio pgtype.Text +} diff --git a/internal/endtoend/testdata/func_out_param/pgx/go/query.sql.go b/internal/endtoend/testdata/func_out_param/pgx/go/query.sql.go new file mode 100644 index 0000000000..b19b4a0704 --- /dev/null +++ b/internal/endtoend/testdata/func_out_param/pgx/go/query.sql.go @@ -0,0 +1,30 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: query.sql + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createAuthor = `-- name: CreateAuthor :one +SELECT id FROM add_author ( + $1, $2 +) +` + +type CreateAuthorParams struct { + Name string + Bio string +} + +func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (pgtype.Int4, error) { + row := q.db.QueryRow(ctx, createAuthor, arg.Name, arg.Bio) + var id pgtype.Int4 + err := row.Scan(&id) + return id, err +} diff --git a/internal/endtoend/testdata/func_out_param/pgx/query.sql b/internal/endtoend/testdata/func_out_param/pgx/query.sql new file mode 100644 index 0000000000..b23c7dca49 --- /dev/null +++ b/internal/endtoend/testdata/func_out_param/pgx/query.sql @@ -0,0 +1,4 @@ +-- name: CreateAuthor :one +SELECT * FROM add_author ( + sqlc.arg(name), sqlc.arg(bio) +); diff --git a/internal/endtoend/testdata/func_out_param/pgx/schema.sql b/internal/endtoend/testdata/func_out_param/pgx/schema.sql new file mode 100644 index 0000000000..97c1022d14 --- /dev/null +++ b/internal/endtoend/testdata/func_out_param/pgx/schema.sql @@ -0,0 +1,14 @@ +-- Example queries for sqlc +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + bio text +); + +CREATE OR REPLACE FUNCTION add_author (name text, bio text, out id int) +AS $$ +DECLARE +BEGIN + id = 123; +END; +$$ LANGUAGE plpgsql; diff --git a/internal/endtoend/testdata/func_out_param/pgx/sqlc.yaml b/internal/endtoend/testdata/func_out_param/pgx/sqlc.yaml new file mode 100644 index 0000000000..5dc63e3f91 --- /dev/null +++ b/internal/endtoend/testdata/func_out_param/pgx/sqlc.yaml @@ -0,0 +1,10 @@ +version: "2" +sql: + - engine: "postgresql" + schema: "schema.sql" + queries: "query.sql" + gen: + go: + package: "querytest" + out: "go" + sql_package: "pgx/v5" diff --git a/internal/sql/catalog/func.go b/internal/sql/catalog/func.go index 7cc712492d..e170777311 100644 --- a/internal/sql/catalog/func.go +++ b/internal/sql/catalog/func.go @@ -39,6 +39,17 @@ func (f *Function) InArgs() []*Argument { return args } +func (f *Function) OutArgs() []*Argument { + var args []*Argument + for _, a := range f.Args { + switch a.Mode { + case ast.FuncParamOut: + args = append(args, a) + } + } + return args +} + func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error { ns := stmt.Func.Schema if ns == "" {