Skip to content

Commit 93cd8f3

Browse files
authored
TST: refactor drop_table in sql test (#43083)
1 parent a527c01 commit 93cd8f3

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

pandas/tests/io/test_sql.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,24 @@ def create_and_load_types(conn, types_data: list[dict], dialect: str):
249249
conn.execute(stmt)
250250

251251

252+
def count_rows(conn, table_name: str):
253+
stmt = f"SELECT count(*) AS count_1 FROM {table_name}"
254+
if isinstance(conn, sqlite3.Connection):
255+
cur = conn.cursor()
256+
result = cur.execute(stmt)
257+
else:
258+
from sqlalchemy import text
259+
from sqlalchemy.engine import Engine
260+
261+
stmt = text(stmt)
262+
if isinstance(conn, Engine):
263+
with conn.connect() as conn:
264+
result = conn.execute(stmt)
265+
else:
266+
result = conn.execute(stmt)
267+
return result.fetchone()[0]
268+
269+
252270
@pytest.fixture
253271
def iris_path(datapath):
254272
iris_path = datapath("io", "data", "csv", "iris.csv")
@@ -415,12 +433,6 @@ class PandasSQLTest:
415433
416434
"""
417435

418-
def _get_exec(self):
419-
if hasattr(self.conn, "execute"):
420-
return self.conn
421-
else:
422-
return self.conn.cursor()
423-
424436
@pytest.fixture
425437
def load_iris_data(self, iris_path):
426438
if not hasattr(self, "conn"):
@@ -451,14 +463,6 @@ def _check_iris_loaded_frame(self, iris_frame):
451463
assert issubclass(pytype, np.floating)
452464
tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])
453465

454-
def _count_rows(self, table_name):
455-
result = (
456-
self._get_exec()
457-
.execute(f"SELECT count(*) AS count_1 FROM {table_name}")
458-
.fetchone()
459-
)
460-
return result[0]
461-
462466
def _read_sql_iris(self):
463467
iris_frame = self.pandasSQL.read_query("SELECT * FROM iris")
464468
self._check_iris_loaded_frame(iris_frame)
@@ -487,7 +491,7 @@ def _to_sql(self, test_frame1, method=None):
487491
assert self.pandasSQL.has_table("test_frame1")
488492

489493
num_entries = len(test_frame1)
490-
num_rows = self._count_rows("test_frame1")
494+
num_rows = count_rows(self.conn, "test_frame1")
491495
assert num_rows == num_entries
492496

493497
# Nuke table
@@ -518,7 +522,7 @@ def _to_sql_replace(self, test_frame1):
518522
assert self.pandasSQL.has_table("test_frame1")
519523

520524
num_entries = len(test_frame1)
521-
num_rows = self._count_rows("test_frame1")
525+
num_rows = count_rows(self.conn, "test_frame1")
522526

523527
assert num_rows == num_entries
524528
self.drop_table("test_frame1")
@@ -534,7 +538,7 @@ def _to_sql_append(self, test_frame1):
534538
assert self.pandasSQL.has_table("test_frame1")
535539

536540
num_entries = 2 * len(test_frame1)
537-
num_rows = self._count_rows("test_frame1")
541+
num_rows = count_rows(self.conn, "test_frame1")
538542

539543
assert num_rows == num_entries
540544
self.drop_table("test_frame1")
@@ -554,7 +558,7 @@ def sample(pd_table, conn, keys, data_iter):
554558

555559
assert check == [1]
556560
num_entries = len(test_frame1)
557-
num_rows = self._count_rows("test_frame1")
561+
num_rows = count_rows(self.conn, "test_frame1")
558562
assert num_rows == num_entries
559563
# Nuke table
560564
self.drop_table("test_frame1")
@@ -570,7 +574,7 @@ def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs):
570574
assert self.pandasSQL.has_table("test_frame1")
571575

572576
num_entries = len(test_frame1)
573-
num_rows = self._count_rows("test_frame1")
577+
num_rows = count_rows(self.conn, "test_frame1")
574578
assert num_rows == num_entries
575579

576580
# Nuke table
@@ -695,7 +699,7 @@ def test_to_sql_replace(self, test_frame1):
695699
assert sql.has_table("test_frame3", self.conn)
696700

697701
num_entries = len(test_frame1)
698-
num_rows = self._count_rows("test_frame3")
702+
num_rows = count_rows(self.conn, "test_frame3")
699703

700704
assert num_rows == num_entries
701705

@@ -707,7 +711,7 @@ def test_to_sql_append(self, test_frame1):
707711
assert sql.has_table("test_frame4", self.conn)
708712

709713
num_entries = 2 * len(test_frame1)
710-
num_rows = self._count_rows("test_frame4")
714+
num_rows = count_rows(self.conn, "test_frame4")
711715

712716
assert num_rows == num_entries
713717

0 commit comments

Comments
 (0)