From d13e1c51bef0f6295f4738f113dcae8be5d1b8a9 Mon Sep 17 00:00:00 2001 From: Kirill Dubovikov Date: Thu, 19 May 2022 20:33:19 +0300 Subject: [PATCH] fix: nullable return types for groupby functions --- internal/compiler/output_columns.go | 26 ++++++++++++++++++++------ internal/tools/sqlc-pg-gen/main.go | 26 ++++++++++++++++---------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index a01675645b..2c309a61bc 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -213,12 +213,26 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { } fun, err := qc.catalog.ResolveFuncCall(n) if err == nil { - cols = append(cols, &Column{ - Name: name, - DataType: dataType(fun.ReturnType), - NotNull: !fun.ReturnTypeNullable, - IsFuncCall: true, - }) + if len(fun.Args) > 0 { + ref := n.Args.Items[0].(*ast.ColumnRef) + columns, err := outputColumnRefs(res, tables, ref) + if err != nil { + return nil, err + } + cols = append(cols, &Column{ + Name: name, + DataType: columns[0].DataType, + NotNull: !fun.ReturnTypeNullable, + IsFuncCall: true, + }) + } else { + cols = append(cols, &Column{ + Name: name, + DataType: dataType(fun.ReturnType), + NotNull: !fun.ReturnTypeNullable, + IsFuncCall: true, + }) + } } else { cols = append(cols, &Column{ Name: name, diff --git a/internal/tools/sqlc-pg-gen/main.go b/internal/tools/sqlc-pg-gen/main.go index 990920d1d0..79c4e08882 100644 --- a/internal/tools/sqlc-pg-gen/main.go +++ b/internal/tools/sqlc-pg-gen/main.go @@ -23,7 +23,8 @@ SELECT p.proname as name, format_type(p.prorettype, NULL), array(select format_type(unnest(p.proargtypes), NULL)), p.proargnames, - p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs] + p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs], + CASE WHEN p.prokind = 'a' THEN TRUE ELSE FALSE END FROM pg_catalog.pg_proc p LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace WHERE n.nspname OPERATOR(pg_catalog.~) '^(pg_catalog)$' @@ -49,7 +50,8 @@ SELECT p.proname as name, format_type(p.prorettype, NULL), array(select format_type(unnest(p.proargtypes), NULL)), p.proargnames, - p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs] + p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs], + false FROM pg_catalog.pg_proc p JOIN extension_funcs ef ON ef.oid = p.oid WHERE p.proargmodes IS NULL @@ -86,6 +88,7 @@ func {{.Name}}() *catalog.Schema { {{end}} }, ReturnType: &ast.TypeName{Name: "{{.ReturnType.Name}}"}, + ReturnTypeNullable: {{.ReturnTypeNullable}}, }, {{- end}} } @@ -127,11 +130,12 @@ func main() { } type Proc struct { - Name string - ReturnType string - ArgTypes []string - ArgNames []string - HasDefault []string + Name string + ReturnType string + ArgTypes []string + ArgNames []string + HasDefault []string + ReturnsNull bool } func clean(arg string) string { @@ -144,9 +148,10 @@ func clean(arg string) string { func (p Proc) Func() catalog.Function { return catalog.Function{ - Name: p.Name, - Args: p.Args(), - ReturnType: &ast.TypeName{Name: clean(p.ReturnType)}, + Name: p.Name, + Args: p.Args(), + ReturnType: &ast.TypeName{Name: clean(p.ReturnType)}, + ReturnTypeNullable: p.ReturnsNull, } } @@ -185,6 +190,7 @@ func scanFuncs(rows pgx.Rows) ([]catalog.Function, error) { &p.ArgTypes, &p.ArgNames, &p.HasDefault, + &p.ReturnsNull, ) if err != nil { return nil, err