Skip to content

feat(codegen): Remove Go-specific overrides from codegen proto #2929

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ jobs:
- name: install ./...
run: go install ./...

- name: build internal/endtoend
run: go build ./...
working-directory: internal/endtoend/testdata

- name: test ./...
run: gotestsum --junitfile junit.xml -- --tags=examples -timeout 20m ./...
env:
Expand All @@ -65,16 +69,6 @@ jobs:
CI_SQLC_AUTH_TOKEN: ${{ secrets.CI_SQLC_AUTH_TOKEN }}
SQLC_AUTH_TOKEN: ${{ secrets.CI_SQLC_AUTH_TOKEN }}

- name: build internal/endtoend
run: go build ./...
working-directory: internal/endtoend/testdata

- name: report
if: false
run: ./scripts/report.sh
env:
BUILDKITE_ANALYTICS_TOKEN: ${{ secrets.BUILDKITE_ANALYTICS_TOKEN }}

vuln_check:
runs-on: ubuntu-latest
timeout-minutes: 5
Expand Down
19 changes: 18 additions & 1 deletion internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,19 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re

opts, err := convert.YAMLtoJSON(sql.Plugin.Options)
if err != nil {
return "", nil, fmt.Errorf("invalid plugin options")
return "", nil, fmt.Errorf("invalid plugin options: %w", err)
}
req.PluginOptions = opts

global, found := combo.Global.Options[plug.Name]
if found {
opts, err := convert.YAMLtoJSON(global)
if err != nil {
return "", nil, fmt.Errorf("invalid global options: %w", err)
}
req.GlobalOptions = opts
}

case sql.Gen.Go != nil:
out = combo.Go.Out
handler = ext.HandleFunc(golang.Generate)
Expand All @@ -424,6 +433,14 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
}
req.PluginOptions = opts

if combo.Global.Overrides.Go != nil {
opts, err := json.Marshal(combo.Global.Overrides.Go)
if err != nil {
return "", nil, fmt.Errorf("opts marshal failed: %w", err)
}
req.GlobalOptions = opts
}

case sql.Gen.JSON != nil:
out = combo.JSON.Out
handler = ext.HandleFunc(genjson.Generate)
Expand Down
66 changes: 5 additions & 61 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package cmd

import (
"strings"

"github.com/sqlc-dev/sqlc/internal/compiler"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/config/convert"
Expand All @@ -11,53 +9,13 @@ import (
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
)

func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
var column string
var table plugin.Identifier

if o.Column != "" {
colParts := strings.Split(o.Column, ".")
switch len(colParts) {
case 2:
table.Schema = r.Catalog.DefaultSchema
table.Name = colParts[0]
column = colParts[1]
case 3:
table.Schema = colParts[0]
table.Name = colParts[1]
column = colParts[2]
case 4:
table.Catalog = colParts[0]
table.Schema = colParts[1]
table.Name = colParts[2]
column = colParts[3]
}
}
return &plugin.Override{
CodeType: "", // FIXME
DbType: o.DBType,
Nullable: o.Nullable,
Unsigned: o.Unsigned,
Column: o.Column,
ColumnName: column,
Table: &table,
GoType: pluginGoType(o),
}
}

func pluginSettings(r *compiler.Result, cs config.CombinedSettings) *plugin.Settings {
var over []*plugin.Override
for _, o := range cs.Overrides {
over = append(over, pluginOverride(r, o))
}
return &plugin.Settings{
Version: cs.Global.Version,
Engine: string(cs.Package.Engine),
Schema: []string(cs.Package.Schema),
Queries: []string(cs.Package.Queries),
Overrides: over,
Rename: cs.Rename,
Codegen: pluginCodegen(cs, cs.Codegen),
Version: cs.Global.Version,
Engine: string(cs.Package.Engine),
Schema: []string(cs.Package.Schema),
Queries: []string(cs.Package.Queries),
Codegen: pluginCodegen(cs, cs.Codegen),
}
}

Expand Down Expand Up @@ -101,20 +59,6 @@ func pluginWASM(p config.Plugin) *plugin.Codegen_WASM {
return nil
}

func pluginGoType(o config.Override) *plugin.ParsedGoType {
// Note that there is a slight mismatch between this and the
// proto api. The GoType on the override is the unparsed type,
// which could be a qualified path or an object, as per
// https://docs.sqlc.dev/en/v1.18.0/reference/config.html#type-overriding
return &plugin.ParsedGoType{
ImportPath: o.GoImportPath,
Package: o.GoPackage,
TypeName: o.GoTypeName,
BasicType: o.GoBasicType,
StructTags: o.GoStructTags,
}
}

func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
var schemas []*plugin.Schema
for _, s := range c.Schemas {
Expand Down
11 changes: 5 additions & 6 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
}

func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
options, err := opts.ParseOpts(req)
options, err := opts.Parse(req)
if err != nil {
return nil, err
}
Expand All @@ -129,11 +129,10 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR

func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.CodeGenResponse, error) {
i := &importer{
Settings: req.Settings,
Options: options,
Queries: queries,
Enums: enums,
Structs: structs,
Options: options,
Queries: queries,
Enums: enums,
Structs: structs,
}

tctx := tmplCtx{
Expand Down
18 changes: 11 additions & 7 deletions internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, col *plugin.Column) {
for _, oride := range req.Settings.Overrides {
func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) {
for _, override := range options.Overrides {
oride := override.ShimOverride
if oride.GoType.StructTags == nil {
continue
}
if !sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema) {
if !override.Matches(col.Table, req.Catalog.DefaultSchema) {
// Different table.
continue
}
Expand All @@ -34,15 +35,17 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, co

func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
// Check if the column's type has been overridden
for _, oride := range req.Settings.Overrides {
for _, override := range options.Overrides {
oride := override.ShimOverride

if oride.GoType.TypeName == "" {
continue
}
cname := col.Name
if col.OriginalName != "" {
cname = col.OriginalName
}
sameTable := sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema)
sameTable := override.Matches(col.Table, req.Catalog.DefaultSchema)
if oride.Column != "" && sdk.MatchString(oride.ColumnName, cname) && sameTable {
if col.IsSqlcSlice {
return "[]" + oride.GoType.TypeName
Expand All @@ -65,7 +68,8 @@ func goInnerType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.
notNull := col.NotNull || col.IsArray

// package overrides have a higher precedence
for _, oride := range req.Settings.Overrides {
for _, override := range options.Overrides {
oride := override.ShimOverride
if oride.GoType.TypeName == "" {
continue
}
Expand All @@ -77,7 +81,7 @@ func goInnerType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.
// TODO: Extend the engine interface to handle types
switch req.Settings.Engine {
case "mysql":
return mysqlType(req, col)
return mysqlType(req, options, col)
case "postgresql":
return postgresType(req, options, col)
case "sqlite":
Expand Down
29 changes: 15 additions & 14 deletions internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
"github.com/sqlc-dev/sqlc/internal/metadata"
"github.com/sqlc-dev/sqlc/internal/plugin"
)

type fileImports struct {
Expand Down Expand Up @@ -59,11 +58,10 @@ func mergeImports(imps ...fileImports) [][]ImportSpec {
}

type importer struct {
Settings *plugin.Settings
Options *opts.Options
Queries []Query
Enums []Enum
Structs []Struct
Options *opts.Options
Queries []Query
Enums []Enum
Structs []Struct
}

func (i *importer) usesType(typ string) bool {
Expand Down Expand Up @@ -157,7 +155,7 @@ var pqtypeTypes = map[string]struct{}{
"pqtype.NullRawMessage": {},
}

func buildImports(settings *plugin.Settings, options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
pkg := make(map[ImportSpec]struct{})
std := make(map[string]struct{})

Expand Down Expand Up @@ -201,7 +199,8 @@ func buildImports(settings *plugin.Settings, options *opts.Options, queries []Qu
}

overrideTypes := map[string]string{}
for _, o := range settings.Overrides {
for _, override := range options.Overrides {
o := override.ShimOverride
if o.GoType.BasicType || o.GoType.TypeName == "" {
continue
}
Expand All @@ -226,7 +225,9 @@ func buildImports(settings *plugin.Settings, options *opts.Options, queries []Qu
}

// Custom imports
for _, o := range settings.Overrides {
for _, override := range options.Overrides {
o := override.ShimOverride

if o.GoType.BasicType || o.GoType.TypeName == "" {
continue
}
Expand All @@ -241,7 +242,7 @@ func buildImports(settings *plugin.Settings, options *opts.Options, queries []Qu
}

func (i *importer) interfaceImports() fileImports {
std, pkg := buildImports(i.Settings, i.Options, i.Queries, func(name string) bool {
std, pkg := buildImports(i.Options, i.Queries, func(name string) bool {
for _, q := range i.Queries {
if q.hasRetType() {
if usesBatch([]Query{q}) {
Expand All @@ -266,7 +267,7 @@ func (i *importer) interfaceImports() fileImports {
}

func (i *importer) modelImports() fileImports {
std, pkg := buildImports(i.Settings, i.Options, nil, i.usesType)
std, pkg := buildImports(i.Options, nil, i.usesType)

if len(i.Enums) > 0 {
std["fmt"] = struct{}{}
Expand Down Expand Up @@ -305,7 +306,7 @@ func (i *importer) queryImports(filename string) fileImports {
}
}

std, pkg := buildImports(i.Settings, i.Options, gq, func(name string) bool {
std, pkg := buildImports(i.Options, gq, func(name string) bool {
for _, q := range gq {
if q.hasRetType() {
if q.Ret.EmitStruct() {
Expand Down Expand Up @@ -406,7 +407,7 @@ func (i *importer) copyfromImports() fileImports {
copyFromQueries = append(copyFromQueries, q)
}
}
std, pkg := buildImports(i.Settings, i.Options, copyFromQueries, func(name string) bool {
std, pkg := buildImports(i.Options, copyFromQueries, func(name string) bool {
for _, q := range copyFromQueries {
if q.hasRetType() {
if strings.HasPrefix(q.Ret.Type(), name) {
Expand Down Expand Up @@ -441,7 +442,7 @@ func (i *importer) batchImports() fileImports {
batchQueries = append(batchQueries, q)
}
}
std, pkg := buildImports(i.Settings, i.Options, batchQueries, func(name string) bool {
std, pkg := buildImports(i.Options, batchQueries, func(name string) bool {
for _, q := range batchQueries {
if q.hasRetType() {
if q.Ret.EmitStruct() {
Expand Down
11 changes: 6 additions & 5 deletions internal/codegen/golang/mysql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package golang
import (
"log"

"github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
"github.com/sqlc-dev/sqlc/internal/debug"
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string {
func mysqlType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray
unsigned := col.Unsigned
Expand Down Expand Up @@ -101,14 +102,14 @@ func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string {
if enum.Name == columnType {
if notNull {
if schema.Name == req.Catalog.DefaultSchema {
return StructName(enum.Name, req.Settings)
return StructName(enum.Name, options)
}
return StructName(schema.Name+"_"+enum.Name, req.Settings)
return StructName(schema.Name+"_"+enum.Name, options)
} else {
if schema.Name == req.Catalog.DefaultSchema {
return "Null" + StructName(enum.Name, req.Settings)
return "Null" + StructName(enum.Name, options)
}
return "Null" + StructName(schema.Name+"_"+enum.Name, req.Settings)
return "Null" + StructName(schema.Name+"_"+enum.Name, options)
}
}
}
Expand Down
Loading