11from __future__ import annotations
22
33import sqlite3
4+ from contextlib import closing
45from datetime import datetime , timezone
56from typing import TYPE_CHECKING , Any , NamedTuple
67
@@ -136,7 +137,7 @@ def test_write_to_sqlite(tmp_path: Path, count: int, db: Database) -> None:
136137 writer .write (record )
137138
138139 record_count = 0
139- with db .connector .connect (str (db_path )) as con :
140+ with closing ( db .connector .connect (str (db_path ) )) as con :
140141 cursor = con .execute ("SELECT COUNT(*) FROM 'test/record'" )
141142 record_count = cursor .fetchone ()[0 ]
142143
@@ -157,7 +158,7 @@ def test_read_from_sqlite(tmp_path: Path, db: Database) -> None:
157158 """Tests basic reading from a SQLite database."""
158159 # Generate a SQLite database
159160 db_path = tmp_path / "records.db"
160- with db .connector .connect (str (db_path )) as con :
161+ with closing ( db .connector .connect (str (db_path ) )) as con :
161162 con .execute (
162163 """
163164 CREATE TABLE 'test/record' (
@@ -176,6 +177,7 @@ def test_read_from_sqlite(tmp_path: Path, db: Database) -> None:
176177 """ ,
177178 (f"record{ i } " , f"foobar{ i } " .encode (), dt_isoformat , 3.14 + i ),
178179 )
180+ con .commit ()
179181
180182 # Read the SQLite database using flow.record
181183 with RecordReader (f"{ db .scheme } ://{ db_path } " ) as reader :
@@ -251,7 +253,7 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None:
251253 assert writer
252254
253255 # test if it's a valid database
254- with db .connector .connect (str (db_path )) as con :
256+ with closing ( db .connector .connect (str (db_path ) )) as con :
255257 assert con .execute ("SELECT * FROM sqlite_master" ).fetchall () == []
256258
257259
@@ -272,9 +274,10 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None:
272274def test_non_strict_sqlite_fields (tmp_path : Path , sqlite_coltype : str , sqlite_value : Any , expected_value : Any ) -> None :
273275 """SQLite by default is non strict, meaning that the value could be of different type than the column type."""
274276 db = tmp_path / "records.db"
275- with sqlite3 .connect (db ) as con :
277+ with closing ( sqlite3 .connect (db ) ) as con :
276278 con .execute (f"CREATE TABLE 'strict-test' (field { sqlite_coltype } )" )
277279 con .execute ("INSERT INTO 'strict-test' VALUES(?)" , (sqlite_value ,))
280+ con .commit ()
278281
279282 with RecordReader (f"sqlite://{ db } " ) as reader :
280283 record = next (iter (reader ))
@@ -294,10 +297,11 @@ def test_invalid_table_names_quoting(tmp_path: Path, invalid_table_name: str) ->
294297
295298 # Creating the tables with these invalid_table_names in SQLite is no problem
296299 db = tmp_path / "records.db"
297- with sqlite3 .connect (db ) as con :
300+ with closing ( sqlite3 .connect (db ) ) as con :
298301 con .execute (f"CREATE TABLE [{ invalid_table_name } ] (field TEXT, field2 TEXT)" )
299302 con .execute (f"INSERT INTO [{ invalid_table_name } ] VALUES(?, ?)" , ("hello" , "world" ))
300303 con .execute (f"INSERT INTO [{ invalid_table_name } ] VALUES(?, ?)" , ("goodbye" , "planet" ))
304+ con .commit ()
301305
302306 # However, these invalid_table_names should raise an exception when reading
303307 with (
@@ -320,10 +324,11 @@ def test_invalid_field_names_quoting(tmp_path: Path, invalid_field_name: str) ->
320324
321325 # Creating the table with invalid field name in SQLite is no problem
322326 db = tmp_path / "records.db"
323- with sqlite3 .connect (db ) as con :
327+ with closing ( sqlite3 .connect (db ) ) as con :
324328 con .execute (f"CREATE TABLE [test] (field TEXT, [{ invalid_field_name } ] TEXT)" )
325329 con .execute ("INSERT INTO [test] VALUES(?, ?)" , ("hello" , "world" ))
326330 con .execute ("INSERT INTO [test] VALUES(?, ?)" , ("goodbye" , "planet" ))
331+ con .commit ()
327332
328333 # However, these field names are invalid in flow.record and should raise an exception
329334 with (
@@ -365,7 +370,7 @@ def test_batch_size(
365370 writer .write (next (records ))
366371
367372 # test count of records in table (no flush yet if batch_size > 1)
368- with db .connector .connect (str (db_path )) as con :
373+ with closing ( db .connector .connect (str (db_path ) )) as con :
369374 x = con .execute ('SELECT COUNT(*) FROM "test/record"' )
370375 assert x .fetchone ()[0 ] is expected_first
371376
@@ -374,7 +379,7 @@ def test_batch_size(
374379 writer .write (next (records ))
375380
376381 # test count of records in table after flush
377- with db .connector .connect (str (db_path )) as con :
382+ with closing ( db .connector .connect (str (db_path ) )) as con :
378383 x = con .execute ('SELECT COUNT(*) FROM "test/record"' )
379384 assert x .fetchone ()[0 ] == expected_second
380385
0 commit comments