Skip to content

Commit 1650e5c

Browse files
committed
Fix ResourceWarning: unclosed database in SqliteReader and unit tests
- SqliteReader was missing a close() method - sqlite3.Connection as a context manager does NOT close the connection it only commits or rolls back the transaction.
1 parent c3f8cd8 commit 1650e5c

2 files changed

Lines changed: 18 additions & 8 deletions

File tree

flow/record/adapter/sqlite.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ def __iter__(self) -> Iterator[Record]:
204204
if match_record_with_context(record, selector, ctx):
205205
yield record
206206

207+
def close(self) -> None:
208+
if self.con:
209+
self.con.close()
210+
self.con = None
211+
207212

208213
class SqliteWriter(AbstractWriter):
209214
"""SQLite writer."""

tests/adapter/test_sqlite_duckdb.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import sqlite3
4+
from contextlib import closing
45
from datetime import datetime, timezone
56
from typing import TYPE_CHECKING, Any, NamedTuple
67

@@ -136,7 +137,7 @@ def test_write_to_sqlite(tmp_path: Path, count: int, db: Database) -> None:
136137
writer.write(record)
137138

138139
record_count = 0
139-
with db.connector.connect(str(db_path)) as con:
140+
with closing(db.connector.connect(str(db_path))) as con:
140141
cursor = con.execute("SELECT COUNT(*) FROM 'test/record'")
141142
record_count = cursor.fetchone()[0]
142143

@@ -157,7 +158,7 @@ def test_read_from_sqlite(tmp_path: Path, db: Database) -> None:
157158
"""Tests basic reading from a SQLite database."""
158159
# Generate a SQLite database
159160
db_path = tmp_path / "records.db"
160-
with db.connector.connect(str(db_path)) as con:
161+
with closing(db.connector.connect(str(db_path))) as con:
161162
con.execute(
162163
"""
163164
CREATE TABLE 'test/record' (
@@ -176,6 +177,7 @@ def test_read_from_sqlite(tmp_path: Path, db: Database) -> None:
176177
""",
177178
(f"record{i}", f"foobar{i}".encode(), dt_isoformat, 3.14 + i),
178179
)
180+
con.commit()
179181

180182
# Read the SQLite database using flow.record
181183
with RecordReader(f"{db.scheme}://{db_path}") as reader:
@@ -251,7 +253,7 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None:
251253
assert writer
252254

253255
# test if it's a valid database
254-
with db.connector.connect(str(db_path)) as con:
256+
with closing(db.connector.connect(str(db_path))) as con:
255257
assert con.execute("SELECT * FROM sqlite_master").fetchall() == []
256258

257259

@@ -272,9 +274,10 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None:
272274
def test_non_strict_sqlite_fields(tmp_path: Path, sqlite_coltype: str, sqlite_value: Any, expected_value: Any) -> None:
273275
"""SQLite by default is non strict, meaning that the value could be of different type than the column type."""
274276
db = tmp_path / "records.db"
275-
with sqlite3.connect(db) as con:
277+
with closing(sqlite3.connect(db)) as con:
276278
con.execute(f"CREATE TABLE 'strict-test' (field {sqlite_coltype})")
277279
con.execute("INSERT INTO 'strict-test' VALUES(?)", (sqlite_value,))
280+
con.commit()
278281

279282
with RecordReader(f"sqlite://{db}") as reader:
280283
record = next(iter(reader))
@@ -294,10 +297,11 @@ def test_invalid_table_names_quoting(tmp_path: Path, invalid_table_name: str) ->
294297

295298
# Creating the tables with these invalid_table_names in SQLite is no problem
296299
db = tmp_path / "records.db"
297-
with sqlite3.connect(db) as con:
300+
with closing(sqlite3.connect(db)) as con:
298301
con.execute(f"CREATE TABLE [{invalid_table_name}] (field TEXT, field2 TEXT)")
299302
con.execute(f"INSERT INTO [{invalid_table_name}] VALUES(?, ?)", ("hello", "world"))
300303
con.execute(f"INSERT INTO [{invalid_table_name}] VALUES(?, ?)", ("goodbye", "planet"))
304+
con.commit()
301305

302306
# However, these invalid_table_names should raise an exception when reading
303307
with (
@@ -320,10 +324,11 @@ def test_invalid_field_names_quoting(tmp_path: Path, invalid_field_name: str) ->
320324

321325
# Creating the table with invalid field name in SQLite is no problem
322326
db = tmp_path / "records.db"
323-
with sqlite3.connect(db) as con:
327+
with closing(sqlite3.connect(db)) as con:
324328
con.execute(f"CREATE TABLE [test] (field TEXT, [{invalid_field_name}] TEXT)")
325329
con.execute("INSERT INTO [test] VALUES(?, ?)", ("hello", "world"))
326330
con.execute("INSERT INTO [test] VALUES(?, ?)", ("goodbye", "planet"))
331+
con.commit()
327332

328333
# However, these field names are invalid in flow.record and should raise an exception
329334
with (
@@ -365,7 +370,7 @@ def test_batch_size(
365370
writer.write(next(records))
366371

367372
# test count of records in table (no flush yet if batch_size > 1)
368-
with db.connector.connect(str(db_path)) as con:
373+
with closing(db.connector.connect(str(db_path))) as con:
369374
x = con.execute('SELECT COUNT(*) FROM "test/record"')
370375
assert x.fetchone()[0] is expected_first
371376

@@ -374,7 +379,7 @@ def test_batch_size(
374379
writer.write(next(records))
375380

376381
# test count of records in table after flush
377-
with db.connector.connect(str(db_path)) as con:
382+
with closing(db.connector.connect(str(db_path))) as con:
378383
x = con.execute('SELECT COUNT(*) FROM "test/record"')
379384
assert x.fetchone()[0] == expected_second
380385

0 commit comments

Comments
 (0)