Skip to content

fix: closing db on finalizer and memory leak #177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ func (sc *statementCache) newDB(sqldb *sql.DB) *DB {
}

// lookupStmt checks if a Statement has been prepared on the db driver with the
// given primedSQL. If it has, the driver-prepared sql.Stmt is returned.
func (sc *statementCache) lookupStmt(db *DB, s *Statement, primedSQL string) (stmt *sql.Stmt, ok bool) {
// given primedSQL. If it has, the driverStmt is returned.
func (sc *statementCache) lookupStmt(db *DB, s *Statement, primedSQL string) (dStmt *driverStmt, ok bool) {
// The Statement cache ID is only removed from stmtDBCache when the
// finalizer is run. The Statement's cache ID must be in the stmtDBCache
// since we hold a reference to the Statement. It is therefore safe to
Expand All @@ -116,27 +116,32 @@ func (sc *statementCache) lookupStmt(db *DB, s *Statement, primedSQL string) (st
if !ok || ds.sql != primedSQL {
return nil, false
}
return ds.stmt, ok
return ds, ok
}

// driverPrepareStatement prepares a statement on the database and then stores
// the prepared *sql.Stmt in the cache.
func (sc *statementCache) driverPrepareStmt(ctx context.Context, db *DB, s *Statement, primedSQL string) (*sql.Stmt, error) {
// the prepared driverStmt in the cache.
func (sc *statementCache) driverPrepareStmt(ctx context.Context, db *DB, s *Statement, primedSQL string) (*driverStmt, error) {
sqlstmt, err := db.sqldb.PrepareContext(ctx, primedSQL)
if err != nil {
return nil, err
}

sc.mutex.Lock()
defer sc.mutex.Unlock()
// If there is already a statement in the cache, replace it with ours.
if driverStmt, ok := sc.stmtDBCache[s.cacheID][db.cacheID]; ok {
// Set a finalizer on the statement we evict from the cache to close it
// once current users have finished with it.
runtime.SetFinalizer(driverStmt.stmt, (*sql.Stmt).Close)

// If there is already a driver statement in the cache for this Statement's
// cache ID, set a finalizer on the driverStmt and evict it from the cache,
// replacing it with the newly generated driverStmt. The finalizer ensures
// that the sql.Stmt in the driverStmt is closed once concurrent users have
// finished with it.
if ds, ok := sc.stmtDBCache[s.cacheID][db.cacheID]; ok {
runtime.SetFinalizer(ds, closeDriverStmt)
}
sc.stmtDBCache[s.cacheID][db.cacheID] = &driverStmt{sql: primedSQL, stmt: sqlstmt}
ds := &driverStmt{sql: primedSQL, stmt: sqlstmt}
sc.stmtDBCache[s.cacheID][db.cacheID] = ds
sc.dbStmtCache[db.cacheID][s.cacheID] = true
return sqlstmt, nil
return ds, nil
}

// removeAndCloseStmtFunc removes and closes all sql.Stmt objects associated
Expand All @@ -153,8 +158,7 @@ func (sc *statementCache) removeAndCloseStmtFunc(s *Statement) {
}

// removeAndCloseDBFunc closes and removes from the cache all sql.Stmt objects
// prepared on the database, removes the database from then cache, then closes
// the sql.DB.
// prepared on the database, removes the database from then cache.
func (sc *statementCache) removeAndCloseDBFunc(db *DB) {
sc.mutex.Lock()
defer sc.mutex.Unlock()
Expand All @@ -165,5 +169,9 @@ func (sc *statementCache) removeAndCloseDBFunc(db *DB) {
delete(dbCache, db.cacheID)
}
delete(sc.dbStmtCache, db.cacheID)
db.sqldb.Close()
}

// closeDriverStmt closes the underlying sql.Stmt of the driverStmt.
func closeDriverStmt(ds *driverStmt) {
ds.stmt.Close()
}
38 changes: 0 additions & 38 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,44 +79,6 @@ func (s *CacheSuite) TestPreparedStatementReuse(c *C) {
s.checkDriverStmtsAllClosed(c)
}

func (s *CacheSuite) TestClosingDB(c *C) {
stmt, err := Prepare(`SELECT 'test'`)
c.Assert(err, IsNil)

var dbID uint64
// For a Statement or DB to be removed from the cache it needs to go out of
// scope and be garbage collected. A function is used to "forget" the
// statement.
func() {
db := s.openDB(c)
dbID = db.cacheID

// Start a query with stmt on db. This will prepare the stmt on the db.
err = db.Query(nil, stmt).Run()
c.Assert(err, IsNil)

// Check a statement is in the cache and a prepared statement has been
// opened on the DB.
s.checkStmtInCache(c, db.cacheID, stmt.cacheID)
s.checkNumDBStmts(c, db.cacheID, 1)
s.checkDriverStmtsOpened(c, 1)
}()

s.triggerFinalizers()
s.checkDBNotInCache(c, dbID)
s.checkDriverStmtsAllClosed(c)

// Check that the statement runs fine on a new DB.
db := s.openDB(c)
err = db.Query(nil, stmt).Run()
c.Assert(err, IsNil)

// Check the statement has been added to the cache for the new DB.
s.checkStmtInCache(c, db.cacheID, stmt.cacheID)
s.checkNumDBStmts(c, db.cacheID, 1)
s.checkDriverStmtsOpened(c, 2)
}

func (s *CacheSuite) TestStatementPreparedAndClosed(c *C) {
db := s.openDB(c)

Expand Down
38 changes: 22 additions & 16 deletions sqlair.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ func (db *DB) PlainDB() *sql.DB {
// Query represents a query on a database. It is designed to be run once and
// used immediately since it contains the query context.
type Query struct {
// run executes the Query against the DB or the TX.
run func(context.Context) (*sql.Rows, sql.Result, error)
// run executes the Query against the DB or the TX. It returns the results
// and a pointer to the driverStmt used to run the query if it needs to be
// kept in memory.
run func(context.Context) (*sql.Rows, sql.Result, *driverStmt, error)
ctx context.Context
err error
pq *expr.PrimedQuery
Expand All @@ -120,6 +122,10 @@ type Iterator struct {
err error
result sql.Result
started bool
// ds is the driverStmt used to run the query. The Iterator holds onto this
// so that it cannot be closed by finalizer while the rows are being
// iterated over. This finalizer can be set in the cache.
ds *driverStmt
}

// Query builds a new query from a context, a [Statement] and the input
Expand All @@ -138,22 +144,22 @@ func (db *DB) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query
return &Query{ctx: ctx, err: err}
}

run := func(innerCtx context.Context) (rows *sql.Rows, result sql.Result, err error) {
run := func(innerCtx context.Context) (rows *sql.Rows, result sql.Result, ds *driverStmt, err error) {
primedSQL := pq.SQL()
sqlstmt, ok := stmtCache.lookupStmt(db, s, primedSQL)
ds, ok := stmtCache.lookupStmt(db, s, primedSQL)
if !ok {
sqlstmt, err = stmtCache.driverPrepareStmt(ctx, db, s, primedSQL)
ds, err = stmtCache.driverPrepareStmt(ctx, db, s, primedSQL)
if err != nil {
return nil, nil, err
return nil, nil, ds, err
}
}

if pq.HasOutputs() {
rows, err = sqlstmt.QueryContext(innerCtx, pq.Params()...)
rows, err = ds.stmt.QueryContext(innerCtx, pq.Params()...)
} else {
result, err = sqlstmt.ExecContext(innerCtx, pq.Params()...)
result, err = ds.stmt.ExecContext(innerCtx, pq.Params()...)
}
return rows, result, err
return rows, result, ds, err
}

return &Query{pq: pq, run: run, ctx: ctx, err: nil}
Expand Down Expand Up @@ -215,7 +221,7 @@ func (q *Query) Iter() *Iterator {
}

var cols []string
rows, result, err := q.run(q.ctx)
rows, result, ds, err := q.run(q.ctx)
if q.pq.HasOutputs() {
if err == nil { // if err IS nil
cols, err = rows.Columns()
Expand All @@ -225,7 +231,7 @@ func (q *Query) Iter() *Iterator {
return &Iterator{pq: q.pq, err: err}
}

return &Iterator{pq: q.pq, rows: rows, cols: cols, err: err, result: result}
return &Iterator{pq: q.pq, rows: rows, cols: cols, err: err, result: result, ds: ds}
}

// Next prepares the next row for [Iterator.Get]. If an error occurs during
Expand Down Expand Up @@ -486,28 +492,28 @@ func (tx *TX) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query
return &Query{ctx: ctx, err: err}
}

run := func(innerCtx context.Context) (rows *sql.Rows, result sql.Result, err error) {
sqlstmt, ok := stmtCache.lookupStmt(tx.db, s, pq.SQL())
run := func(innerCtx context.Context) (rows *sql.Rows, result sql.Result, ds *driverStmt, err error) {
ds, ok := stmtCache.lookupStmt(tx.db, s, pq.SQL())
if ok {
// Register the prepared statement on the transaction. This function
// does not resend the prepare request to the database.
// The txstmt is closed by database/sql when the transaction is
// commited or rolled back.
txstmt := tx.sqltx.Stmt(sqlstmt)
txstmt := tx.sqltx.Stmt(ds.stmt)
if pq.HasOutputs() {
rows, err = txstmt.QueryContext(innerCtx, pq.Params()...)
} else {
result, err = txstmt.ExecContext(innerCtx, pq.Params()...)
}
return rows, result, err
return rows, result, ds, err
}

if pq.HasOutputs() {
rows, err = tx.sqltx.QueryContext(innerCtx, pq.SQL(), pq.Params()...)
} else {
result, err = tx.sqltx.ExecContext(innerCtx, pq.SQL(), pq.Params()...)
}
return rows, result, err
return rows, result, nil, err
}

return &Query{pq: pq, ctx: ctx, run: run, err: nil}
Expand Down
Loading