diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 1fa02db3b3af41..fb14c0bc71d99e 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -887,6 +887,14 @@ def test_rowcount_executemany(self): self.cu.executemany("insert into test(name) values (?)", [(1,), (2,), (3,)]) self.assertEqual(self.cu.rowcount, 3) + @unittest.skipIf(sqlite.sqlite_version_info < (3, 35, 0), + "Requires SQLite 3.35.0 or newer") + def test_rowcount_update_returning(self): + # gh-93421: rowcount is updated correctly for UPDATE...RETURNING queries + self.cu.execute("update test set name='bar' where name='foo' returning 1") + self.assertEqual(self.cu.fetchone()[0], 1) + self.assertEqual(self.cu.rowcount, 1) + def test_total_changes(self): self.cu.execute("insert into test(name) values ('foo')") self.cu.execute("insert into test(name) values ('foo')") diff --git a/Misc/NEWS.d/next/Library/2022-06-05-22-22-42.gh-issue-93421.43UO_8.rst b/Misc/NEWS.d/next/Library/2022-06-05-22-22-42.gh-issue-93421.43UO_8.rst new file mode 100644 index 00000000000000..085ea94c03c64e --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-06-05-22-22-42.gh-issue-93421.43UO_8.rst @@ -0,0 +1,2 @@ +Fix :data:`sqlite3.Cursor.rowcount` for ``UPDATE ... RETURNING``` SQL +queries. Patch by Erlend E. Aasland. diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c index c58def5f0362f1..115f0c83f095d9 100644 --- a/Modules/_sqlite/cursor.c +++ b/Modules/_sqlite/cursor.c @@ -776,6 +776,16 @@ stmt_mark_dirty(pysqlite_Statement *self) self->in_use = 1; } +static inline sqlite3_int64 +total_changes(sqlite3 *db) +{ +#if SQLITE_VERSION_NUMBER >= 3037000 + return sqlite3_total_changes64(db); +#else + return (sqlite3_int64)sqlite3_total_changes(db); +#endif +} + PyObject * _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation, PyObject* second_argument) { @@ -835,10 +845,9 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation stmt_reset(self->statement); } - /* reset description and rowcount */ + /* reset description */ Py_INCREF(Py_None); Py_SETREF(self->description, Py_None); - self->rowcount = 0L; if (self->statement) { (void)stmt_reset(self->statement); @@ -879,6 +888,14 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation } } + if (self->statement->is_dml) { + // Save current row count + self->rowcount = total_changes(self->connection->db); + } + else { + self->rowcount = -1L; + } + while (1) { parameters = PyIter_Next(parameters_iter); if (!parameters) { @@ -944,12 +961,6 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation } } - if (self->statement->is_dml) { - self->rowcount += (long)sqlite3_changes(self->connection->db); - } else { - self->rowcount= -1L; - } - if (rc == SQLITE_DONE && !multiple) { stmt_reset(self->statement); Py_CLEAR(self->statement); @@ -1308,6 +1319,19 @@ pysqlite_cursor_close_impl(pysqlite_Cursor *self) Py_RETURN_NONE; } +static PyObject * +get_rowcount(pysqlite_Cursor *self, void *Py_UNUSED(closure)) +{ + if (!check_cursor(self)) { + return NULL; + } + if (self->rowcount == -1L) { + return PyLong_FromLong(-1L); + } + sqlite3_int64 changes = total_changes(self->connection->db); + return PyLong_FromLong(changes - self->rowcount); +} + static PyMethodDef cursor_methods[] = { PYSQLITE_CURSOR_CLOSE_METHODDEF PYSQLITE_CURSOR_EXECUTEMANY_METHODDEF @@ -1327,12 +1351,16 @@ static struct PyMemberDef cursor_members[] = {"description", T_OBJECT, offsetof(pysqlite_Cursor, description), READONLY}, {"arraysize", T_INT, offsetof(pysqlite_Cursor, arraysize), 0}, {"lastrowid", T_OBJECT, offsetof(pysqlite_Cursor, lastrowid), READONLY}, - {"rowcount", T_LONG, offsetof(pysqlite_Cursor, rowcount), READONLY}, {"row_factory", T_OBJECT, offsetof(pysqlite_Cursor, row_factory), 0}, {"__weaklistoffset__", T_PYSSIZET, offsetof(pysqlite_Cursor, in_weakreflist), READONLY}, {NULL} }; +static PyGetSetDef cursor_getset[] = { + {"rowcount", (getter)get_rowcount, (setter)NULL}, + {NULL}, +}; + static const char cursor_doc[] = PyDoc_STR("SQLite database cursor class."); @@ -1346,6 +1374,7 @@ static PyType_Slot cursor_slots[] = { {Py_tp_init, pysqlite_cursor_init}, {Py_tp_traverse, cursor_traverse}, {Py_tp_clear, cursor_clear}, + {Py_tp_getset, cursor_getset}, {0, NULL}, }; diff --git a/Modules/_sqlite/cursor.h b/Modules/_sqlite/cursor.h index 0bcdddc3e29595..3ef1caf1a1edd9 100644 --- a/Modules/_sqlite/cursor.h +++ b/Modules/_sqlite/cursor.h @@ -38,7 +38,7 @@ typedef struct PyObject* row_cast_map; int arraysize; PyObject* lastrowid; - long rowcount; + sqlite3_int64 rowcount; // Saved row count PyObject* row_factory; pysqlite_Statement* statement; int closed;