gh-105539: Explict resource management for connection objects in sqlite3 tests (#108017)

- Use memory_database() helper
- Move test utility functions to util.py
- Add convenience memory database mixin
- Add check() helper for closed connection tests
This commit is contained in:
Erlend E. Aasland 2023-08-17 08:45:48 +02:00 committed by GitHub
parent c9d83f93d8
commit 1344cfac43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 371 additions and 385 deletions

View File

@ -1,6 +1,8 @@
import sqlite3 as sqlite
import unittest
from .util import memory_database
class BackupTests(unittest.TestCase):
def setUp(self):
@ -32,32 +34,32 @@ class BackupTests(unittest.TestCase):
self.cx.backup(self.cx)
def test_bad_target_closed_connection(self):
bck = sqlite.connect(':memory:')
bck.close()
with self.assertRaises(sqlite.ProgrammingError):
self.cx.backup(bck)
with memory_database() as bck:
bck.close()
with self.assertRaises(sqlite.ProgrammingError):
self.cx.backup(bck)
def test_bad_source_closed_connection(self):
bck = sqlite.connect(':memory:')
source = sqlite.connect(":memory:")
source.close()
with self.assertRaises(sqlite.ProgrammingError):
source.backup(bck)
with memory_database() as bck:
source = sqlite.connect(":memory:")
source.close()
with self.assertRaises(sqlite.ProgrammingError):
source.backup(bck)
def test_bad_target_in_transaction(self):
bck = sqlite.connect(':memory:')
bck.execute('CREATE TABLE bar (key INTEGER)')
bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
with self.assertRaises(sqlite.OperationalError) as cm:
self.cx.backup(bck)
with memory_database() as bck:
bck.execute('CREATE TABLE bar (key INTEGER)')
bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
with self.assertRaises(sqlite.OperationalError) as cm:
self.cx.backup(bck)
def test_keyword_only_args(self):
with self.assertRaises(TypeError):
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, 1)
def test_simple(self):
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck)
self.verify_backup(bck)
@ -67,7 +69,7 @@ class BackupTests(unittest.TestCase):
def progress(status, remaining, total):
journal.append(status)
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, pages=1, progress=progress)
self.verify_backup(bck)
@ -81,7 +83,7 @@ class BackupTests(unittest.TestCase):
def progress(status, remaining, total):
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, progress=progress)
self.verify_backup(bck)
@ -94,7 +96,7 @@ class BackupTests(unittest.TestCase):
def progress(status, remaining, total):
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, pages=-1, progress=progress)
self.verify_backup(bck)
@ -103,7 +105,7 @@ class BackupTests(unittest.TestCase):
def test_non_callable_progress(self):
with self.assertRaises(TypeError) as cm:
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, pages=1, progress='bar')
self.assertEqual(str(cm.exception), 'progress argument must be a callable')
@ -116,7 +118,7 @@ class BackupTests(unittest.TestCase):
self.cx.commit()
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, pages=1, progress=progress)
self.verify_backup(bck)
@ -140,12 +142,12 @@ class BackupTests(unittest.TestCase):
self.assertEqual(str(err.exception), 'nearly out of space')
def test_database_source_name(self):
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, name='main')
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, name='temp')
with self.assertRaises(sqlite.OperationalError) as cm:
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, name='non-existing')
self.assertIn("unknown database", str(cm.exception))
@ -153,7 +155,7 @@ class BackupTests(unittest.TestCase):
self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)')
self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)])
self.cx.commit()
with sqlite.connect(':memory:') as bck:
with memory_database() as bck:
self.cx.backup(bck, name='attached_db')
self.verify_backup(bck)

View File

@ -33,26 +33,13 @@ from test.support import (
SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess,
is_emscripten, is_wasi
)
from test.support import gc_collect
from test.support import threading_helper
from _testcapi import INT_MAX, ULLONG_MAX
from os import SEEK_SET, SEEK_CUR, SEEK_END
from test.support.os_helper import TESTFN, TESTFN_UNDECODABLE, unlink, temp_dir, FakePath
# Helper for temporary memory databases
def memory_database(*args, **kwargs):
cx = sqlite.connect(":memory:", *args, **kwargs)
return contextlib.closing(cx)
# Temporarily limit a database connection parameter
@contextlib.contextmanager
def cx_limit(cx, category=sqlite.SQLITE_LIMIT_SQL_LENGTH, limit=128):
try:
_prev = cx.setlimit(category, limit)
yield limit
finally:
cx.setlimit(category, _prev)
from .util import memory_database, cx_limit
class ModuleTests(unittest.TestCase):
@ -326,9 +313,9 @@ class ModuleTests(unittest.TestCase):
self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK")
def test_disallow_instantiation(self):
cx = sqlite.connect(":memory:")
check_disallow_instantiation(self, type(cx("select 1")))
check_disallow_instantiation(self, sqlite.Blob)
with memory_database() as cx:
check_disallow_instantiation(self, type(cx("select 1")))
check_disallow_instantiation(self, sqlite.Blob)
def test_complete_statement(self):
self.assertFalse(sqlite.complete_statement("select t"))
@ -342,6 +329,7 @@ class ConnectionTests(unittest.TestCase):
cu = self.cx.cursor()
cu.execute("create table test(id integer primary key, name text)")
cu.execute("insert into test(name) values (?)", ("foo",))
cu.close()
def tearDown(self):
self.cx.close()
@ -412,21 +400,22 @@ class ConnectionTests(unittest.TestCase):
def test_in_transaction(self):
# Can't use db from setUp because we want to test initial state.
cx = sqlite.connect(":memory:")
cu = cx.cursor()
self.assertEqual(cx.in_transaction, False)
cu.execute("create table transactiontest(id integer primary key, name text)")
self.assertEqual(cx.in_transaction, False)
cu.execute("insert into transactiontest(name) values (?)", ("foo",))
self.assertEqual(cx.in_transaction, True)
cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone()
self.assertEqual(cx.in_transaction, True)
cx.commit()
self.assertEqual(cx.in_transaction, False)
cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone()
self.assertEqual(cx.in_transaction, False)
with memory_database() as cx:
cu = cx.cursor()
self.assertEqual(cx.in_transaction, False)
cu.execute("create table transactiontest(id integer primary key, name text)")
self.assertEqual(cx.in_transaction, False)
cu.execute("insert into transactiontest(name) values (?)", ("foo",))
self.assertEqual(cx.in_transaction, True)
cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone()
self.assertEqual(cx.in_transaction, True)
cx.commit()
self.assertEqual(cx.in_transaction, False)
cu.execute("select name from transactiontest where name=?", ["foo"])
row = cu.fetchone()
self.assertEqual(cx.in_transaction, False)
cu.close()
def test_in_transaction_ro(self):
with self.assertRaises(AttributeError):
@ -450,10 +439,9 @@ class ConnectionTests(unittest.TestCase):
self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc))
def test_interrupt_on_closed_db(self):
cx = sqlite.connect(":memory:")
cx.close()
self.cx.close()
with self.assertRaises(sqlite.ProgrammingError):
cx.interrupt()
self.cx.interrupt()
def test_interrupt(self):
self.assertIsNone(self.cx.interrupt())
@ -521,29 +509,29 @@ class ConnectionTests(unittest.TestCase):
self.assertEqual(cx.isolation_level, level)
def test_connection_reinit(self):
db = ":memory:"
cx = sqlite.connect(db)
cx.text_factory = bytes
cx.row_factory = sqlite.Row
cu = cx.cursor()
cu.execute("create table foo (bar)")
cu.executemany("insert into foo (bar) values (?)",
((str(v),) for v in range(4)))
cu.execute("select bar from foo")
with memory_database() as cx:
cx.text_factory = bytes
cx.row_factory = sqlite.Row
cu = cx.cursor()
cu.execute("create table foo (bar)")
cu.executemany("insert into foo (bar) values (?)",
((str(v),) for v in range(4)))
cu.execute("select bar from foo")
rows = [r for r in cu.fetchmany(2)]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
rows = [r for r in cu.fetchmany(2)]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
cx.__init__(db)
cx.execute("create table foo (bar)")
cx.executemany("insert into foo (bar) values (?)",
((v,) for v in ("a", "b", "c", "d")))
cx.__init__(":memory:")
cx.execute("create table foo (bar)")
cx.executemany("insert into foo (bar) values (?)",
((v,) for v in ("a", "b", "c", "d")))
# This uses the old database, old row factory, but new text factory
rows = [r for r in cu.fetchall()]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], ["2", "3"])
# This uses the old database, old row factory, but new text factory
rows = [r for r in cu.fetchall()]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], ["2", "3"])
cu.close()
def test_connection_bad_reinit(self):
cx = sqlite.connect(":memory:")
@ -591,11 +579,11 @@ class ConnectionTests(unittest.TestCase):
"parameters in Python 3.15."
)
with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
sqlite.connect(":memory:", 1.0)
cx = sqlite.connect(":memory:", 1.0)
cx.close()
self.assertEqual(cm.filename, __file__)
class UninitialisedConnectionTests(unittest.TestCase):
def setUp(self):
self.cx = sqlite.Connection.__new__(sqlite.Connection)
@ -1571,12 +1559,12 @@ class ThreadTests(unittest.TestCase):
except sqlite.Error:
err.append("multi-threading not allowed")
con = sqlite.connect(":memory:", check_same_thread=False)
err = []
t = threading.Thread(target=run, kwargs={"con": con, "err": err})
t.start()
t.join()
self.assertEqual(len(err), 0, "\n".join(err))
with memory_database(check_same_thread=False) as con:
err = []
t = threading.Thread(target=run, kwargs={"con": con, "err": err})
t.start()
t.join()
self.assertEqual(len(err), 0, "\n".join(err))
class ConstructorTests(unittest.TestCase):
@ -1602,9 +1590,16 @@ class ConstructorTests(unittest.TestCase):
b = sqlite.Binary(b"\0'")
class ExtensionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
def tearDown(self):
self.cur.close()
self.con.close()
def test_script_string_sql(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
cur = self.cur
cur.executescript("""
-- bla bla
/* a stupid comment */
@ -1616,40 +1611,40 @@ class ExtensionTests(unittest.TestCase):
self.assertEqual(res, 5)
def test_script_syntax_error(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(sqlite.OperationalError):
cur.executescript("create table test(x); asdf; create table test2(x)")
self.cur.executescript("""
CREATE TABLE test(x);
asdf;
CREATE TABLE test2(x)
""")
def test_script_error_normal(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(sqlite.OperationalError):
cur.executescript("create table test(sadfsadfdsa); select foo from hurz;")
self.cur.executescript("""
CREATE TABLE test(sadfsadfdsa);
SELECT foo FROM hurz;
""")
def test_cursor_executescript_as_bytes(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(TypeError):
cur.executescript(b"create table test(foo); insert into test(foo) values (5);")
self.cur.executescript(b"""
CREATE TABLE test(foo);
INSERT INTO test(foo) VALUES (5);
""")
def test_cursor_executescript_with_null_characters(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(ValueError):
cur.executescript("""
create table a(i);\0
insert into a(i) values (5);
""")
self.cur.executescript("""
CREATE TABLE a(i);\0
INSERT INTO a(i) VALUES (5);
""")
def test_cursor_executescript_with_surrogates(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(UnicodeEncodeError):
cur.executescript("""
create table a(s);
insert into a(s) values ('\ud8ff');
""")
self.cur.executescript("""
CREATE TABLE a(s);
INSERT INTO a(s) VALUES ('\ud8ff');
""")
def test_cursor_executescript_too_large_script(self):
msg = "query string is too large"
@ -1659,19 +1654,18 @@ class ExtensionTests(unittest.TestCase):
cx.executescript("select 'too large'".ljust(lim+1))
def test_cursor_executescript_tx_control(self):
con = sqlite.connect(":memory:")
con = self.con
con.execute("begin")
self.assertTrue(con.in_transaction)
con.executescript("select 1")
self.assertFalse(con.in_transaction)
def test_connection_execute(self):
con = sqlite.connect(":memory:")
result = con.execute("select 5").fetchone()[0]
result = self.con.execute("select 5").fetchone()[0]
self.assertEqual(result, 5, "Basic test of Connection.execute")
def test_connection_executemany(self):
con = sqlite.connect(":memory:")
con = self.con
con.execute("create table test(foo)")
con.executemany("insert into test(foo) values (?)", [(3,), (4,)])
result = con.execute("select foo from test order by foo").fetchall()
@ -1679,47 +1673,44 @@ class ExtensionTests(unittest.TestCase):
self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany")
def test_connection_executescript(self):
con = sqlite.connect(":memory:")
con.executescript("create table test(foo); insert into test(foo) values (5);")
con = self.con
con.executescript("""
CREATE TABLE test(foo);
INSERT INTO test(foo) VALUES (5);
""")
result = con.execute("select foo from test").fetchone()[0]
self.assertEqual(result, 5, "Basic test of Connection.executescript")
class ClosedConTests(unittest.TestCase):
def check(self, fn, *args, **kwds):
regex = "Cannot operate on a closed database."
with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
fn(*args, **kwds)
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
self.con.close()
def test_closed_con_cursor(self):
con = sqlite.connect(":memory:")
con.close()
with self.assertRaises(sqlite.ProgrammingError):
cur = con.cursor()
self.check(self.con.cursor)
def test_closed_con_commit(self):
con = sqlite.connect(":memory:")
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con.commit()
self.check(self.con.commit)
def test_closed_con_rollback(self):
con = sqlite.connect(":memory:")
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con.rollback()
self.check(self.con.rollback)
def test_closed_cur_execute(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
con.close()
with self.assertRaises(sqlite.ProgrammingError):
cur.execute("select 4")
self.check(self.cur.execute, "select 4")
def test_closed_create_function(self):
con = sqlite.connect(":memory:")
con.close()
def f(x): return 17
with self.assertRaises(sqlite.ProgrammingError):
con.create_function("foo", 1, f)
def f(x):
return 17
self.check(self.con.create_function, "foo", 1, f)
def test_closed_create_aggregate(self):
con = sqlite.connect(":memory:")
con.close()
class Agg:
def __init__(self):
pass
@ -1727,29 +1718,21 @@ class ClosedConTests(unittest.TestCase):
pass
def finalize(self):
return 17
with self.assertRaises(sqlite.ProgrammingError):
con.create_aggregate("foo", 1, Agg)
self.check(self.con.create_aggregate, "foo", 1, Agg)
def test_closed_set_authorizer(self):
con = sqlite.connect(":memory:")
con.close()
def authorizer(*args):
return sqlite.DENY
with self.assertRaises(sqlite.ProgrammingError):
con.set_authorizer(authorizer)
self.check(self.con.set_authorizer, authorizer)
def test_closed_set_progress_callback(self):
con = sqlite.connect(":memory:")
con.close()
def progress(): pass
with self.assertRaises(sqlite.ProgrammingError):
con.set_progress_handler(progress, 100)
def progress():
pass
self.check(self.con.set_progress_handler, progress, 100)
def test_closed_call(self):
con = sqlite.connect(":memory:")
con.close()
with self.assertRaises(sqlite.ProgrammingError):
con()
self.check(self.con)
class ClosedCurTests(unittest.TestCase):
def test_closed(self):

View File

@ -2,16 +2,12 @@
import unittest
import sqlite3 as sqlite
from .test_dbapi import memory_database
from .util import memory_database
from .util import MemoryDatabaseMixin
class DumpTests(unittest.TestCase):
def setUp(self):
self.cx = sqlite.connect(":memory:")
self.cu = self.cx.cursor()
def tearDown(self):
self.cx.close()
class DumpTests(MemoryDatabaseMixin, unittest.TestCase):
def test_table_dump(self):
expected_sqls = [

View File

@ -24,6 +24,9 @@ import unittest
import sqlite3 as sqlite
from collections.abc import Sequence
from .util import memory_database
from .util import MemoryDatabaseMixin
def dict_factory(cursor, row):
d = {}
@ -45,10 +48,12 @@ class ConnectionFactoryTests(unittest.TestCase):
def __init__(self, *args, **kwargs):
sqlite.Connection.__init__(self, *args, **kwargs)
for factory in DefectFactory, OkFactory:
with self.subTest(factory=factory):
con = sqlite.connect(":memory:", factory=factory)
self.assertIsInstance(con, factory)
with memory_database(factory=OkFactory) as con:
self.assertIsInstance(con, OkFactory)
regex = "Base Connection.__init__ not called."
with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
with memory_database(factory=DefectFactory) as con:
self.assertIsInstance(con, DefectFactory)
def test_connection_factory_relayed_call(self):
# gh-95132: keyword args must not be passed as positional args
@ -57,9 +62,9 @@ class ConnectionFactoryTests(unittest.TestCase):
kwargs["isolation_level"] = None
super(Factory, self).__init__(*args, **kwargs)
con = sqlite.connect(":memory:", factory=Factory)
self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory)
with memory_database(factory=Factory) as con:
self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory)
def test_connection_factory_as_positional_arg(self):
class Factory(sqlite.Connection):
@ -74,18 +79,13 @@ class ConnectionFactoryTests(unittest.TestCase):
r"parameters in Python 3.15."
)
with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
con = sqlite.connect(":memory:", 5.0, 0, None, True, Factory)
with memory_database(5.0, 0, None, True, Factory) as con:
self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory)
self.assertEqual(cm.filename, __file__)
self.assertIsNone(con.isolation_level)
self.assertIsInstance(con, Factory)
class CursorFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def tearDown(self):
self.con.close()
class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def test_is_instance(self):
cur = self.con.cursor()
@ -103,9 +103,8 @@ class CursorFactoryTests(unittest.TestCase):
# invalid callable returning non-cursor
self.assertRaises(TypeError, self.con.cursor, lambda con: None)
class RowFactoryTestsBackwardsCompat(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
class RowFactoryTestsBackwardsCompat(MemoryDatabaseMixin, unittest.TestCase):
def test_is_produced_by_factory(self):
cur = self.con.cursor(factory=MyCursor)
@ -114,12 +113,8 @@ class RowFactoryTestsBackwardsCompat(unittest.TestCase):
self.assertIsInstance(row, dict)
cur.close()
def tearDown(self):
self.con.close()
class RowFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
class RowFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def test_custom_factory(self):
self.con.row_factory = lambda cur, row: list(row)
@ -265,12 +260,8 @@ class RowFactoryTests(unittest.TestCase):
self.assertRaises(TypeError, self.con.cursor, FakeCursor)
self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
def tearDown(self):
self.con.close()
class TextFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
class TextFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
def test_unicode(self):
austria = "Österreich"
@ -291,15 +282,17 @@ class TextFactoryTests(unittest.TestCase):
self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
self.assertTrue(row[0].endswith("reich"), "column must contain original data")
def tearDown(self):
self.con.close()
class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.con.execute("create table test (value text)")
self.con.execute("insert into test (value) values (?)", ("a\x00b",))
def tearDown(self):
self.con.close()
def test_string(self):
# text_factory defaults to str
row = self.con.execute("select value from test").fetchone()
@ -325,9 +318,6 @@ class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
self.assertIs(type(row[0]), bytes)
self.assertEqual(row[0], b"a\x00b")
def tearDown(self):
self.con.close()
if __name__ == "__main__":
unittest.main()

View File

@ -26,34 +26,31 @@ import unittest
from test.support.os_helper import TESTFN, unlink
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from test.test_sqlite3.test_userfunctions import with_tracebacks
from .util import memory_database, cx_limit, with_tracebacks
from .util import MemoryDatabaseMixin
class CollationTests(unittest.TestCase):
class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
def test_create_collation_not_string(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError):
con.create_collation(None, lambda x, y: (x > y) - (x < y))
self.con.create_collation(None, lambda x, y: (x > y) - (x < y))
def test_create_collation_not_callable(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError) as cm:
con.create_collation("X", 42)
self.con.create_collation("X", 42)
self.assertEqual(str(cm.exception), 'parameter must be callable')
def test_create_collation_not_ascii(self):
con = sqlite.connect(":memory:")
con.create_collation("collä", lambda x, y: (x > y) - (x < y))
self.con.create_collation("collä", lambda x, y: (x > y) - (x < y))
def test_create_collation_bad_upper(self):
class BadUpperStr(str):
def upper(self):
return None
con = sqlite.connect(":memory:")
mycoll = lambda x, y: -((x > y) - (x < y))
con.create_collation(BadUpperStr("mycoll"), mycoll)
result = con.execute("""
self.con.create_collation(BadUpperStr("mycoll"), mycoll)
result = self.con.execute("""
select x from (
select 'a' as x
union
@ -68,8 +65,7 @@ class CollationTests(unittest.TestCase):
# reverse order
return -((x > y) - (x < y))
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
self.con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
@ -79,21 +75,20 @@ class CollationTests(unittest.TestCase):
select 'c' as x
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg='the expected order was not returned')
con.create_collation("mycoll", None)
self.con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
def test_collation_returns_large_integer(self):
def mycoll(x, y):
# reverse order
return -((x > y) - (x < y)) * 2**32
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
self.con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
@ -103,7 +98,7 @@ class CollationTests(unittest.TestCase):
select 'c' as x
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg="the expected order was not returned")
@ -112,7 +107,7 @@ class CollationTests(unittest.TestCase):
Register two different collation functions under the same name.
Verify that the last one is actually used.
"""
con = sqlite.connect(":memory:")
con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
result = con.execute("""
@ -126,25 +121,26 @@ class CollationTests(unittest.TestCase):
Register a collation, then deregister it. Make sure an error is raised if we try
to use it.
"""
con = sqlite.connect(":memory:")
con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
class ProgressTests(unittest.TestCase):
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
def test_progress_handler_used(self):
"""
Test that the progress handler is invoked once it is set.
"""
con = sqlite.connect(":memory:")
progress_calls = []
def progress():
progress_calls.append(None)
return 0
con.set_progress_handler(progress, 1)
con.execute("""
self.con.set_progress_handler(progress, 1)
self.con.execute("""
create table foo(a, b)
""")
self.assertTrue(progress_calls)
@ -153,7 +149,7 @@ class ProgressTests(unittest.TestCase):
"""
Test that the opcode argument is respected.
"""
con = sqlite.connect(":memory:")
con = self.con
progress_calls = []
def progress():
progress_calls.append(None)
@ -176,11 +172,10 @@ class ProgressTests(unittest.TestCase):
"""
Test that returning a non-zero value stops the operation in progress.
"""
con = sqlite.connect(":memory:")
def progress():
return 1
con.set_progress_handler(progress, 1)
curs = con.cursor()
self.con.set_progress_handler(progress, 1)
curs = self.con.cursor()
self.assertRaises(
sqlite.OperationalError,
curs.execute,
@ -190,7 +185,7 @@ class ProgressTests(unittest.TestCase):
"""
Test that setting the progress handler to None clears the previously set handler.
"""
con = sqlite.connect(":memory:")
con = self.con
action = 0
def progress():
nonlocal action
@ -203,31 +198,30 @@ class ProgressTests(unittest.TestCase):
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:")
def bad_progress():
1 / 0
con.set_progress_handler(bad_progress, 1)
self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
self.con.execute("""
create table foo(a, b)
""")
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:")
class BadBool:
def __bool__(self):
1 / 0
def bad_progress():
return BadBool()
con.set_progress_handler(bad_progress, 1)
self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
self.con.execute("""
create table foo(a, b)
""")
class TraceCallbackTests(unittest.TestCase):
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
@contextlib.contextmanager
def check_stmt_trace(self, cx, expected):
try:
@ -242,12 +236,11 @@ class TraceCallbackTests(unittest.TestCase):
"""
Test that the trace callback is invoked once it is set.
"""
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.execute("create table foo(a, b)")
self.con.set_trace_callback(trace)
self.con.execute("create table foo(a, b)")
self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
@ -255,7 +248,7 @@ class TraceCallbackTests(unittest.TestCase):
"""
Test that setting the trace callback to None clears the previously set callback.
"""
con = sqlite.connect(":memory:")
con = self.con
traced_statements = []
def trace(statement):
traced_statements.append(statement)
@ -269,7 +262,7 @@ class TraceCallbackTests(unittest.TestCase):
Test that the statement can contain unicode literals.
"""
unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
con = sqlite.connect(":memory:")
con = self.con
traced_statements = []
def trace(statement):
traced_statements.append(statement)

View File

@ -28,15 +28,12 @@ import functools
from test import support
from unittest.mock import patch
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from .util import memory_database, cx_limit
from .util import MemoryDatabaseMixin
class RegressionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def tearDown(self):
self.con.close()
class RegressionTests(MemoryDatabaseMixin, unittest.TestCase):
def test_pragma_user_version(self):
# This used to crash pysqlite because this pragma command returns NULL for the column name
@ -45,28 +42,24 @@ class RegressionTests(unittest.TestCase):
def test_pragma_schema_version(self):
# This still crashed pysqlite <= 2.2.1
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
try:
with memory_database(detect_types=sqlite.PARSE_COLNAMES) as con:
cur = self.con.cursor()
cur.execute("pragma schema_version")
finally:
cur.close()
con.close()
def test_statement_reset(self):
# pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are
# reset before a rollback, but only those that are still in the
# statement cache. The others are not accessible from the connection object.
con = sqlite.connect(":memory:", cached_statements=5)
cursors = [con.cursor() for x in range(5)]
cursors[0].execute("create table test(x)")
for i in range(10):
cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
with memory_database(cached_statements=5) as con:
cursors = [con.cursor() for x in range(5)]
cursors[0].execute("create table test(x)")
for i in range(10):
cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
for i in range(5):
cursors[i].execute(" " * i + "select x from test")
for i in range(5):
cursors[i].execute(" " * i + "select x from test")
con.rollback()
con.rollback()
def test_column_name_with_spaces(self):
cur = self.con.cursor()
@ -81,17 +74,15 @@ class RegressionTests(unittest.TestCase):
# cache when closing the database. statements that were still
# referenced in cursors weren't closed and could provoke "
# "OperationalError: Unable to close due to unfinalised statements".
con = sqlite.connect(":memory:")
cursors = []
# default statement cache size is 100
for i in range(105):
cur = con.cursor()
cur = self.con.cursor()
cursors.append(cur)
cur.execute("select 1 x union select " + str(i))
con.close()
def test_on_conflict_rollback(self):
con = sqlite.connect(":memory:")
con = self.con
con.execute("create table foo(x, unique(x) on conflict rollback)")
con.execute("insert into foo(x) values (1)")
try:
@ -126,16 +117,16 @@ class RegressionTests(unittest.TestCase):
a statement. This test exhibits the problem.
"""
SELECT = "select * from foo"
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
cur = con.cursor()
cur.execute("create table foo(bar timestamp)")
with self.assertWarnsRegex(DeprecationWarning, "adapter"):
cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
cur.execute(SELECT)
cur.execute("drop table foo")
cur.execute("create table foo(bar integer)")
cur.execute("insert into foo(bar) values (5)")
cur.execute(SELECT)
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
cur = con.cursor()
cur.execute("create table foo(bar timestamp)")
with self.assertWarnsRegex(DeprecationWarning, "adapter"):
cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
cur.execute(SELECT)
cur.execute("drop table foo")
cur.execute("create table foo(bar integer)")
cur.execute("insert into foo(bar) values (5)")
cur.execute(SELECT)
def test_bind_mutating_list(self):
# Issue41662: Crash when mutate a list of parameters during iteration.
@ -144,11 +135,11 @@ class RegressionTests(unittest.TestCase):
parameters.clear()
return "..."
parameters = [X(), 0]
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
con.execute("create table foo(bar X, baz integer)")
# Should not crash
with self.assertRaises(IndexError):
con.execute("insert into foo(bar, baz) values (?, ?)", parameters)
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
con.execute("create table foo(bar X, baz integer)")
# Should not crash
with self.assertRaises(IndexError):
con.execute("insert into foo(bar, baz) values (?, ?)", parameters)
def test_error_msg_decode_error(self):
# When porting the module to Python 3.0, the error message about
@ -173,7 +164,7 @@ class RegressionTests(unittest.TestCase):
def __del__(self):
con.isolation_level = ""
con = sqlite.connect(":memory:")
con = self.con
con.isolation_level = None
for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE":
with self.subTest(level=level):
@ -204,8 +195,7 @@ class RegressionTests(unittest.TestCase):
def __init__(self, con):
pass
con = sqlite.connect(":memory:")
cur = Cursor(con)
cur = Cursor(self.con)
with self.assertRaises(sqlite.ProgrammingError):
cur.execute("select 4+5").fetchall()
with self.assertRaisesRegex(sqlite.ProgrammingError,
@ -238,7 +228,9 @@ class RegressionTests(unittest.TestCase):
2.5.3 introduced a regression so that these could no longer
be created.
"""
con = sqlite.connect(":memory:", isolation_level=None)
with memory_database(isolation_level=None) as con:
self.assertIsNone(con.isolation_level)
self.assertFalse(con.in_transaction)
def test_pragma_autocommit(self):
"""
@ -273,9 +265,7 @@ class RegressionTests(unittest.TestCase):
Recursively using a cursor, such as when reusing it from a generator led to segfaults.
Now we catch recursive cursor usage and raise a ProgrammingError.
"""
con = sqlite.connect(":memory:")
cur = con.cursor()
cur = self.con.cursor()
cur.execute("create table a (bar)")
cur.execute("create table b (baz)")
@ -295,29 +285,30 @@ class RegressionTests(unittest.TestCase):
since the microsecond string "456" actually represents "456000".
"""
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
cur = con.cursor()
cur.execute("CREATE TABLE t (x TIMESTAMP)")
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
cur = con.cursor()
cur.execute("CREATE TABLE t (x TIMESTAMP)")
# Microseconds should be 456000
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
# Microseconds should be 456000
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
# Microseconds should be truncated to 123456
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
# Microseconds should be truncated to 123456
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
cur.execute("SELECT * FROM t")
with self.assertWarnsRegex(DeprecationWarning, "converter"):
values = [x[0] for x in cur.fetchall()]
cur.execute("SELECT * FROM t")
with self.assertWarnsRegex(DeprecationWarning, "converter"):
values = [x[0] for x in cur.fetchall()]
self.assertEqual(values, [
datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
])
self.assertEqual(values, [
datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
])
def test_invalid_isolation_level_type(self):
# isolation level is a string, not an integer
self.assertRaises(TypeError,
sqlite.connect, ":memory:", isolation_level=123)
regex = "isolation_level must be str or None"
with self.assertRaisesRegex(TypeError, regex):
memory_database(isolation_level=123).__enter__()
def test_null_character(self):
@ -333,7 +324,7 @@ class RegressionTests(unittest.TestCase):
cur.execute, query)
def test_surrogates(self):
con = sqlite.connect(":memory:")
con = self.con
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
cur = con.cursor()
@ -359,7 +350,7 @@ class RegressionTests(unittest.TestCase):
to return rows multiple times when fetched from cursors
after commit. See issues 10513 and 23129 for details.
"""
con = sqlite.connect(":memory:")
con = self.con
con.executescript("""
create table t(c);
create table t2(c);
@ -391,10 +382,9 @@ class RegressionTests(unittest.TestCase):
"""
def callback(*args):
pass
con = sqlite.connect(":memory:")
cur = sqlite.Cursor(con)
cur = sqlite.Cursor(self.con)
ref = weakref.ref(cur, callback)
cur.__init__(con)
cur.__init__(self.con)
del cur
# The interpreter shouldn't crash when ref is collected.
del ref
@ -425,6 +415,7 @@ class RegressionTests(unittest.TestCase):
def test_table_lock_cursor_replace_stmt(self):
with memory_database() as con:
con = self.con
cur = con.cursor()
cur.execute("create table t(t)")
cur.executemany("insert into t values(?)",

View File

@ -28,7 +28,8 @@ from test.support import LOOPBACK_TIMEOUT
from test.support.os_helper import TESTFN, unlink
from test.support.script_helper import assert_python_ok
from test.test_sqlite3.test_dbapi import memory_database
from .util import memory_database
from .util import MemoryDatabaseMixin
TIMEOUT = LOOPBACK_TIMEOUT / 10
@ -132,14 +133,14 @@ class TransactionTests(unittest.TestCase):
def test_rollback_cursor_consistency(self):
"""Check that cursors behave correctly after rollback."""
con = sqlite.connect(":memory:")
cur = con.cursor()
cur.execute("create table test(x)")
cur.execute("insert into test(x) values (5)")
cur.execute("select 1 union select 2 union select 3")
with memory_database() as con:
cur = con.cursor()
cur.execute("create table test(x)")
cur.execute("insert into test(x) values (5)")
cur.execute("select 1 union select 2 union select 3")
con.rollback()
self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)])
con.rollback()
self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)])
def test_multiple_cursors_and_iternext(self):
# gh-94028: statements are cleared and reset in cursor iternext.
@ -218,10 +219,7 @@ class RollbackTests(unittest.TestCase):
class SpecialCommandTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
class SpecialCommandTests(MemoryDatabaseMixin, unittest.TestCase):
def test_drop_table(self):
self.cur.execute("create table test(i)")
@ -233,14 +231,8 @@ class SpecialCommandTests(unittest.TestCase):
self.cur.execute("insert into test(i) values (5)")
self.cur.execute("pragma count_changes=1")
def tearDown(self):
self.cur.close()
self.con.close()
class TransactionalDDL(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
class TransactionalDDL(MemoryDatabaseMixin, unittest.TestCase):
def test_ddl_does_not_autostart_transaction(self):
# For backwards compatibility reasons, DDL statements should not
@ -268,9 +260,6 @@ class TransactionalDDL(unittest.TestCase):
with self.assertRaises(sqlite.OperationalError):
self.con.execute("select * from test")
def tearDown(self):
self.con.close()
class IsolationLevelFromInit(unittest.TestCase):
CREATE = "create table t(t)"

View File

@ -21,54 +21,15 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import contextlib
import functools
import io
import re
import sys
import unittest
import sqlite3 as sqlite
from unittest.mock import Mock, patch
from test.support import bigmemtest, catch_unraisable_exception, gc_collect
from test.support import bigmemtest, gc_collect
from test.test_sqlite3.test_dbapi import cx_limit
def with_tracebacks(exc, regex="", name=""):
"""Convenience decorator for testing callback tracebacks."""
def decorator(func):
_regex = re.compile(regex) if regex else None
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with catch_unraisable_exception() as cm:
# First, run the test with traceback enabled.
with check_tracebacks(self, cm, exc, _regex, name):
func(self, *args, **kwargs)
# Then run the test with traceback disabled.
func(self, *args, **kwargs)
return wrapper
return decorator
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
"""Convenience context manager for testing callback tracebacks."""
sqlite.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
self.assertEqual(cm.unraisable.exc_type, exc)
if regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(regex.search(msg))
if obj_name:
self.assertEqual(cm.unraisable.object.__name__, obj_name)
finally:
sqlite.enable_callback_tracebacks(False)
from .util import cx_limit, memory_database
from .util import with_tracebacks, check_tracebacks
def func_returntext():
@ -405,19 +366,19 @@ class FunctionTests(unittest.TestCase):
def test_function_destructor_via_gc(self):
# See bpo-44304: The destructor of the user function can
# crash if is called without the GIL from the gc functions
dest = sqlite.connect(':memory:')
def md5sum(t):
return
dest.create_function("md5", 1, md5sum)
x = dest("create table lang (name, first_appeared)")
del md5sum, dest
with memory_database() as dest:
dest.create_function("md5", 1, md5sum)
x = dest("create table lang (name, first_appeared)")
del md5sum, dest
y = [x]
y.append(y)
y = [x]
y.append(y)
del x,y
gc_collect()
del x,y
gc_collect()
@with_tracebacks(OverflowError)
def test_func_return_too_large_int(self):
@ -514,6 +475,10 @@ class WindowFunctionTests(unittest.TestCase):
"""
self.con.create_window_function("sumint", 1, WindowSumInt)
def tearDown(self):
self.cur.close()
self.con.close()
def test_win_sum_int(self):
self.cur.execute(self.query % "sumint")
self.assertEqual(self.cur.fetchall(), self.expected)
@ -634,6 +599,7 @@ class AggregateTests(unittest.TestCase):
""")
cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
("foo", 5, 3.14, None, memoryview(b"blob"),))
cur.close()
self.con.create_aggregate("nostep", 1, AggrNoStep)
self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
@ -646,9 +612,7 @@ class AggregateTests(unittest.TestCase):
self.con.create_aggregate("aggtxt", 1, AggrText)
def tearDown(self):
#self.cur.close()
#self.con.close()
pass
self.con.close()
def test_aggr_error_on_create(self):
with self.assertRaises(sqlite.OperationalError):
@ -775,7 +739,7 @@ class AuthorizerTests(unittest.TestCase):
self.con.set_authorizer(self.authorizer_cb)
def tearDown(self):
pass
self.con.close()
def test_table_access(self):
with self.assertRaises(sqlite.DatabaseError) as cm:

View File

@ -0,0 +1,78 @@
import contextlib
import functools
import io
import re
import sqlite3
import test.support
import unittest
# Helper for temporary memory databases
def memory_database(*args, **kwargs):
cx = sqlite3.connect(":memory:", *args, **kwargs)
return contextlib.closing(cx)
# Temporarily limit a database connection parameter
@contextlib.contextmanager
def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
try:
_prev = cx.setlimit(category, limit)
yield limit
finally:
cx.setlimit(category, _prev)
def with_tracebacks(exc, regex="", name=""):
"""Convenience decorator for testing callback tracebacks."""
def decorator(func):
_regex = re.compile(regex) if regex else None
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with test.support.catch_unraisable_exception() as cm:
# First, run the test with traceback enabled.
with check_tracebacks(self, cm, exc, _regex, name):
func(self, *args, **kwargs)
# Then run the test with traceback disabled.
func(self, *args, **kwargs)
return wrapper
return decorator
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
"""Convenience context manager for testing callback tracebacks."""
sqlite3.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
self.assertEqual(cm.unraisable.exc_type, exc)
if regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(regex.search(msg))
if obj_name:
self.assertEqual(cm.unraisable.object.__name__, obj_name)
finally:
sqlite3.enable_callback_tracebacks(False)
class MemoryDatabaseMixin:
def setUp(self):
self.con = sqlite3.connect(":memory:")
self.cur = self.con.cursor()
def tearDown(self):
self.cur.close()
self.con.close()
@property
def cx(self):
return self.con
@property
def cu(self):
return self.cur