@@ -249,6 +249,24 @@ def create_and_load_types(conn, types_data: list[dict], dialect: str):
249
249
conn .execute (stmt )
250
250
251
251
252
+ def count_rows (conn , table_name : str ):
253
+ stmt = f"SELECT count(*) AS count_1 FROM { table_name } "
254
+ if isinstance (conn , sqlite3 .Connection ):
255
+ cur = conn .cursor ()
256
+ result = cur .execute (stmt )
257
+ else :
258
+ from sqlalchemy import text
259
+ from sqlalchemy .engine import Engine
260
+
261
+ stmt = text (stmt )
262
+ if isinstance (conn , Engine ):
263
+ with conn .connect () as conn :
264
+ result = conn .execute (stmt )
265
+ else :
266
+ result = conn .execute (stmt )
267
+ return result .fetchone ()[0 ]
268
+
269
+
252
270
@pytest .fixture
253
271
def iris_path (datapath ):
254
272
iris_path = datapath ("io" , "data" , "csv" , "iris.csv" )
@@ -415,12 +433,6 @@ class PandasSQLTest:
415
433
416
434
"""
417
435
418
- def _get_exec (self ):
419
- if hasattr (self .conn , "execute" ):
420
- return self .conn
421
- else :
422
- return self .conn .cursor ()
423
-
424
436
@pytest .fixture
425
437
def load_iris_data (self , iris_path ):
426
438
if not hasattr (self , "conn" ):
@@ -451,14 +463,6 @@ def _check_iris_loaded_frame(self, iris_frame):
451
463
assert issubclass (pytype , np .floating )
452
464
tm .equalContents (row .values , [5.1 , 3.5 , 1.4 , 0.2 , "Iris-setosa" ])
453
465
454
- def _count_rows (self , table_name ):
455
- result = (
456
- self ._get_exec ()
457
- .execute (f"SELECT count(*) AS count_1 FROM { table_name } " )
458
- .fetchone ()
459
- )
460
- return result [0 ]
461
-
462
466
def _read_sql_iris (self ):
463
467
iris_frame = self .pandasSQL .read_query ("SELECT * FROM iris" )
464
468
self ._check_iris_loaded_frame (iris_frame )
@@ -487,7 +491,7 @@ def _to_sql(self, test_frame1, method=None):
487
491
assert self .pandasSQL .has_table ("test_frame1" )
488
492
489
493
num_entries = len (test_frame1 )
490
- num_rows = self ._count_rows ( "test_frame1" )
494
+ num_rows = count_rows ( self .conn , "test_frame1" )
491
495
assert num_rows == num_entries
492
496
493
497
# Nuke table
@@ -518,7 +522,7 @@ def _to_sql_replace(self, test_frame1):
518
522
assert self .pandasSQL .has_table ("test_frame1" )
519
523
520
524
num_entries = len (test_frame1 )
521
- num_rows = self ._count_rows ( "test_frame1" )
525
+ num_rows = count_rows ( self .conn , "test_frame1" )
522
526
523
527
assert num_rows == num_entries
524
528
self .drop_table ("test_frame1" )
@@ -534,7 +538,7 @@ def _to_sql_append(self, test_frame1):
534
538
assert self .pandasSQL .has_table ("test_frame1" )
535
539
536
540
num_entries = 2 * len (test_frame1 )
537
- num_rows = self ._count_rows ( "test_frame1" )
541
+ num_rows = count_rows ( self .conn , "test_frame1" )
538
542
539
543
assert num_rows == num_entries
540
544
self .drop_table ("test_frame1" )
@@ -554,7 +558,7 @@ def sample(pd_table, conn, keys, data_iter):
554
558
555
559
assert check == [1 ]
556
560
num_entries = len (test_frame1 )
557
- num_rows = self ._count_rows ( "test_frame1" )
561
+ num_rows = count_rows ( self .conn , "test_frame1" )
558
562
assert num_rows == num_entries
559
563
# Nuke table
560
564
self .drop_table ("test_frame1" )
@@ -570,7 +574,7 @@ def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs):
570
574
assert self .pandasSQL .has_table ("test_frame1" )
571
575
572
576
num_entries = len (test_frame1 )
573
- num_rows = self ._count_rows ( "test_frame1" )
577
+ num_rows = count_rows ( self .conn , "test_frame1" )
574
578
assert num_rows == num_entries
575
579
576
580
# Nuke table
@@ -695,7 +699,7 @@ def test_to_sql_replace(self, test_frame1):
695
699
assert sql .has_table ("test_frame3" , self .conn )
696
700
697
701
num_entries = len (test_frame1 )
698
- num_rows = self ._count_rows ( "test_frame3" )
702
+ num_rows = count_rows ( self .conn , "test_frame3" )
699
703
700
704
assert num_rows == num_entries
701
705
@@ -707,7 +711,7 @@ def test_to_sql_append(self, test_frame1):
707
711
assert sql .has_table ("test_frame4" , self .conn )
708
712
709
713
num_entries = 2 * len (test_frame1 )
710
- num_rows = self ._count_rows ( "test_frame4" )
714
+ num_rows = count_rows ( self .conn , "test_frame4" )
711
715
712
716
assert num_rows == num_entries
713
717
0 commit comments