Skip to content

feat(analyzer): Cache query analysis #2889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
546 changes: 546 additions & 0 deletions internal/analysis/analysis.pb.go

Large diffs are not rendered by default.

2,078 changes: 2,078 additions & 0 deletions internal/analysis/analysis_vtproto.pb.go

Large diffs are not rendered by default.

126 changes: 98 additions & 28 deletions internal/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,115 @@ package analyzer

import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"log/slog"
"os"
"path/filepath"

"google.golang.org/protobuf/proto"

"github.com/sqlc-dev/sqlc/internal/analysis"
"github.com/sqlc-dev/sqlc/internal/cache"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/info"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/named"
)

type Column struct {
Name string
OriginalName string
DataType string
NotNull bool
Unsigned bool
IsArray bool
ArrayDims int
Comment string
Length *int
IsNamedParam bool
IsFuncCall bool

// XXX: Figure out what PostgreSQL calls `foo.id`
Scope string
Table *ast.TableName
TableAlias string
Type *ast.TypeName
EmbedTable *ast.TableName

IsSqlcSlice bool // is this sqlc.slice()
// CachedAnalyzer wraps an Analyzer with a disk-backed cache of query
// analysis results, keyed by the sqlc version, configuration, schema and
// query text.
type CachedAnalyzer struct {
	a           Analyzer      // underlying analyzer used on cache misses
	config      config.Config // configuration folded into the cache key
	configBytes []byte        // lazily-marshaled JSON form of config, computed once in analyze
	db          config.Database
}

// Cached wraps a with a disk-backed cache of analysis results. The
// configuration and database settings are folded into the cache key and
// cache-eligibility decision respectively.
func Cached(a Analyzer, c config.Config, db config.Database) *CachedAnalyzer {
	ca := CachedAnalyzer{
		a:      a,
		config: c,
		db:     db,
	}
	return &ca
}

type Parameter struct {
Number int
Column *Column
// Create a new error here

// Analyze resolves the analysis for a query, consulting the on-disk cache
// first. When the cached path cannot be used — unmanaged database, a cache
// error, or any condition flagged for rerun — the underlying analyzer is
// invoked directly.
func (c *CachedAnalyzer) Analyze(ctx context.Context, n ast.Node, q string, schema []string, np *named.ParamSet) (*analysis.Analysis, error) {
	result, rerun, err := c.analyze(ctx, n, q, schema, np)
	if !rerun {
		return result, err
	}
	// A cache failure is logged but never fatal; fall back to a direct run.
	if err != nil {
		slog.Warn("first analysis failed with error", "err", err)
	}
	return c.a.Analyze(ctx, n, q, schema, np)
}

// analyze implements the cached lookup. It returns the analysis, a flag
// telling the caller to rerun the underlying analyzer directly, and any
// error encountered while using the cache.
//
// Results are only cached for managed databases; we can't be certain the
// database is in an unchanged state otherwise.
func (c *CachedAnalyzer) analyze(ctx context.Context, n ast.Node, q string, schema []string, np *named.ParamSet) (*analysis.Analysis, bool, error) {
	if !c.db.Managed {
		return nil, true, nil
	}

	dir, err := cache.AnalysisDir()
	if err != nil {
		return nil, true, err
	}

	// Marshal the configuration once and reuse it for every query.
	if c.configBytes == nil {
		c.configBytes, err = json.Marshal(c.config)
		if err != nil {
			return nil, true, err
		}
	}

	// The cache key covers everything that could change the result: the
	// sqlc version, the configuration, the schema and the query itself.
	h := fnv.New64()
	h.Write([]byte(info.Version))
	h.Write(c.configBytes)
	for _, m := range schema {
		h.Write([]byte(m))
	}
	h.Write([]byte(q))

	key := fmt.Sprintf("%x", h.Sum(nil))
	path := filepath.Join(dir, key)

	// Read the entry directly rather than Stat-then-Read, which races with
	// concurrent cache eviction between the two calls. Any read error
	// (including fs.ErrNotExist) is simply treated as a cache miss.
	if contents, rerr := os.ReadFile(path); rerr == nil {
		var a analysis.Analysis
		if err := proto.Unmarshal(contents, &a); err != nil {
			return nil, true, err
		}
		return &a, false, nil
	}

	result, err := c.a.Analyze(ctx, n, q, schema, np)

	// Best-effort write-back: failing to persist the analysis is logged but
	// never surfaced, since the analysis itself succeeded.
	if err == nil {
		contents, merr := proto.Marshal(result)
		if merr != nil {
			slog.Warn("unable to marshal analysis", "err", merr)
			return result, false, nil
		}
		if werr := os.WriteFile(path, contents, 0644); werr != nil {
			slog.Warn("saving analysis to disk failed", "err", werr)
			return result, false, nil
		}
	}

	return result, false, err
}

type Analysis struct {
Columns []Column
Params []Parameter
// Close releases any resources held by the underlying analyzer. The cache
// itself holds no open resources, so this is a pure delegation.
func (c *CachedAnalyzer) Close(ctx context.Context) error {
	return c.a.Close(ctx)
}

type Analyzer interface {
Analyze(context.Context, ast.Node, string, []string, *named.ParamSet) (*Analysis, error)
Analyze(context.Context, ast.Node, string, []string, *named.ParamSet) (*analysis.Analysis, error)
Close(context.Context) error
}
47 changes: 47 additions & 0 deletions internal/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package cache

import (
"fmt"
"os"
"path/filepath"
)

// Dir reports the root of sqlc's cache. It defaults to a "sqlc"
// subdirectory of os.UserCacheDir(); setting the SQLCCACHE environment
// variable overrides that location entirely.
//
// Currently the cache stores two types of data: plugins and query analysis.
func Dir() (string, error) {
	if override := os.Getenv("SQLCCACHE"); override != "" {
		return override, nil
	}
	base, err := os.UserCacheDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(base, "sqlc"), nil
}

// PluginsDir returns the directory used to cache downloaded plugins,
// creating it if it does not already exist.
func PluginsDir() (string, error) {
	cacheRoot, err := Dir()
	if err != nil {
		return "", err
	}
	dir := filepath.Join(cacheRoot, "plugins")
	// MkdirAll is already a no-op for an existing directory, so any error it
	// returns is real. The previous `!os.IsExist(err)` guard swallowed the
	// EEXIST returned when a regular file occupies the path, reporting
	// success for an unusable location.
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", fmt.Errorf("failed to create %s directory: %w", dir, err)
	}
	return dir, nil
}

// AnalysisDir returns the directory used to cache query analysis results,
// creating it if it does not already exist.
func AnalysisDir() (string, error) {
	cacheRoot, err := Dir()
	if err != nil {
		return "", err
	}
	dir := filepath.Join(cacheRoot, "query_analysis")
	// MkdirAll is already a no-op for an existing directory, so any error it
	// returns is real. The previous `!os.IsExist(err)` guard swallowed the
	// EEXIST returned when a regular file occupies the path, reporting
	// success for an unusable location.
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", fmt.Errorf("failed to create %s directory: %w", dir, err)
	}
	return dir, nil
}
63 changes: 42 additions & 21 deletions internal/compiler/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package compiler
import (
"sort"

"github.com/sqlc-dev/sqlc/internal/analyzer"
analyzer "github.com/sqlc-dev/sqlc/internal/analysis"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/source"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
Expand All @@ -13,32 +13,54 @@ import (
)

type analysis struct {
Table *ast.TableName
Columns []*Column
QueryCatalog *QueryCatalog
Parameters []Parameter
Named *named.ParamSet
Query string
Table *ast.TableName
Columns []*Column
Parameters []Parameter
Named *named.ParamSet
Query string
}

func convertColumn(c analyzer.Column) *Column {
// convertTableName maps a protobuf Identifier onto the AST table-name
// representation. A nil identifier converts to a nil table name.
func convertTableName(id *analyzer.Identifier) *ast.TableName {
	if id == nil {
		return nil
	}
	tn := ast.TableName{
		Catalog: id.Catalog,
		Schema:  id.Schema,
		Name:    id.Name,
	}
	return &tn
}

// convertTypeName maps a protobuf Identifier onto the AST type-name
// representation. A nil identifier converts to a nil type name.
func convertTypeName(id *analyzer.Identifier) *ast.TypeName {
	if id == nil {
		return nil
	}
	tn := ast.TypeName{
		Catalog: id.Catalog,
		Schema:  id.Schema,
		Name:    id.Name,
	}
	return &tn
}

// convertColumn maps a protobuf analysis Column onto the compiler's internal
// Column, translating identifiers via the convert helpers and widening proto
// int32 fields to int.
func convertColumn(c *analyzer.Column) *Column {
	// NOTE(review): c.Length is a proto scalar, so an unset length arrives as
	// 0 yet still yields a non-nil *int pointing at 0 — confirm downstream
	// consumers treat *0 and nil Length equivalently.
	length := int(c.Length)
	out := Column{
		Name:         c.Name,
		OriginalName: c.OriginalName,
		DataType:     c.DataType,
		NotNull:      c.NotNull,
		Unsigned:     c.Unsigned,
		IsArray:      c.IsArray,
		ArrayDims:    int(c.ArrayDims),
		Comment:      c.Comment,
		Length:       &length,
		IsNamedParam: c.IsNamedParam,
		IsFuncCall:   c.IsFuncCall,
		Scope:        c.Scope,
		TableAlias:   c.TableAlias,
		IsSqlcSlice:  c.IsSqlcSlice,
	}
	out.Table = convertTableName(c.Table)
	out.Type = convertTypeName(c.Type)
	out.EmbedTable = convertTableName(c.EmbedTable)
	return &out
}
Expand All @@ -51,8 +73,8 @@ func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis {
var params []Parameter
for _, p := range a.Params {
params = append(params, Parameter{
Number: p.Number,
Column: convertColumn(*p.Column),
Number: int(p.Number),
Column: convertColumn(p.Column),
})
}
if len(prev.Columns) == len(cols) {
Expand Down Expand Up @@ -189,11 +211,10 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
}

return &analysis{
Table: table,
Columns: cols,
Parameters: params,
QueryCatalog: qc,
Query: expanded,
Named: namedParams,
Table: table,
Columns: cols,
Parameters: params,
Query: expanded,
Named: namedParams,
}, rerr
}
6 changes: 5 additions & 1 deletion internal/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err
c.catalog = postgresql.NewCatalog()
if conf.Database != nil {
if conf.Analyzer.Database == nil || *conf.Analyzer.Database {
c.analyzer = pganalyze.New(c.client, *conf.Database)
c.analyzer = analyzer.Cached(
pganalyze.New(c.client, *conf.Database),
combo.Global,
*conf.Database,
)
}
}
default:
Expand Down
1 change: 0 additions & 1 deletion internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,

var anlys *analysis
if c.analyzer != nil {
// TODO: Handle panics
inference, _ := c.inferQuery(raw, rawSQL)
if inference == nil {
inference = &analysis{}
Expand Down
18 changes: 9 additions & 9 deletions internal/engine/postgresql/analyzer/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"

core "github.com/sqlc-dev/sqlc/internal/analyzer"
core "github.com/sqlc-dev/sqlc/internal/analysis"
"github.com/sqlc-dev/sqlc/internal/config"
pb "github.com/sqlc-dev/sqlc/internal/quickdb/v1"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
Expand Down Expand Up @@ -250,14 +250,14 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
dt, isArray, _ := parseType(col.DataType)
notNull := col.NotNull
name := field.Name
result.Columns = append(result.Columns, core.Column{
result.Columns = append(result.Columns, &core.Column{
Name: name,
OriginalName: field.Name,
DataType: dt,
NotNull: notNull,
IsArray: isArray,
ArrayDims: col.ArrayDims,
Table: &ast.TableName{
ArrayDims: int32(col.ArrayDims),
Table: &core.Identifier{
Schema: tbl.SchemaName,
Name: tbl.TableName,
},
Expand All @@ -271,13 +271,13 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
notNull := false
name := field.Name
dt, isArray, dims := parseType(dataType)
result.Columns = append(result.Columns, core.Column{
result.Columns = append(result.Columns, &core.Column{
Name: name,
OriginalName: field.Name,
DataType: dt,
NotNull: notNull,
IsArray: isArray,
ArrayDims: dims,
ArrayDims: int32(dims),
})
}
}
Expand All @@ -293,13 +293,13 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
if ps != nil {
name, _ = ps.NameFor(i + 1)
}
result.Params = append(result.Params, core.Parameter{
Number: i + 1,
result.Params = append(result.Params, &core.Parameter{
Number: int32(i + 1),
Column: &core.Column{
Name: name,
DataType: dt,
IsArray: isArray,
ArrayDims: dims,
ArrayDims: int32(dims),
NotNull: notNull,
},
})
Expand Down
Loading