mirror of
https://github.com/python/cpython.git
synced 2024-11-28 04:15:11 +08:00
gh-105539: Explict resource management for connection objects in sqlite3 tests (#108017)
- Use memory_database() helper - Move test utility functions to util.py - Add convenience memory database mixin - Add check() helper for closed connection tests
This commit is contained in:
parent
c9d83f93d8
commit
1344cfac43
@ -1,6 +1,8 @@
|
||||
import sqlite3 as sqlite
|
||||
import unittest
|
||||
|
||||
from .util import memory_database
|
||||
|
||||
|
||||
class BackupTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@ -32,32 +34,32 @@ class BackupTests(unittest.TestCase):
|
||||
self.cx.backup(self.cx)
|
||||
|
||||
def test_bad_target_closed_connection(self):
|
||||
bck = sqlite.connect(':memory:')
|
||||
bck.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
self.cx.backup(bck)
|
||||
with memory_database() as bck:
|
||||
bck.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
self.cx.backup(bck)
|
||||
|
||||
def test_bad_source_closed_connection(self):
|
||||
bck = sqlite.connect(':memory:')
|
||||
source = sqlite.connect(":memory:")
|
||||
source.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
source.backup(bck)
|
||||
with memory_database() as bck:
|
||||
source = sqlite.connect(":memory:")
|
||||
source.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
source.backup(bck)
|
||||
|
||||
def test_bad_target_in_transaction(self):
|
||||
bck = sqlite.connect(':memory:')
|
||||
bck.execute('CREATE TABLE bar (key INTEGER)')
|
||||
bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
|
||||
with self.assertRaises(sqlite.OperationalError) as cm:
|
||||
self.cx.backup(bck)
|
||||
with memory_database() as bck:
|
||||
bck.execute('CREATE TABLE bar (key INTEGER)')
|
||||
bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
|
||||
with self.assertRaises(sqlite.OperationalError) as cm:
|
||||
self.cx.backup(bck)
|
||||
|
||||
def test_keyword_only_args(self):
|
||||
with self.assertRaises(TypeError):
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, 1)
|
||||
|
||||
def test_simple(self):
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck)
|
||||
self.verify_backup(bck)
|
||||
|
||||
@ -67,7 +69,7 @@ class BackupTests(unittest.TestCase):
|
||||
def progress(status, remaining, total):
|
||||
journal.append(status)
|
||||
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, pages=1, progress=progress)
|
||||
self.verify_backup(bck)
|
||||
|
||||
@ -81,7 +83,7 @@ class BackupTests(unittest.TestCase):
|
||||
def progress(status, remaining, total):
|
||||
journal.append(remaining)
|
||||
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, progress=progress)
|
||||
self.verify_backup(bck)
|
||||
|
||||
@ -94,7 +96,7 @@ class BackupTests(unittest.TestCase):
|
||||
def progress(status, remaining, total):
|
||||
journal.append(remaining)
|
||||
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, pages=-1, progress=progress)
|
||||
self.verify_backup(bck)
|
||||
|
||||
@ -103,7 +105,7 @@ class BackupTests(unittest.TestCase):
|
||||
|
||||
def test_non_callable_progress(self):
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, pages=1, progress='bar')
|
||||
self.assertEqual(str(cm.exception), 'progress argument must be a callable')
|
||||
|
||||
@ -116,7 +118,7 @@ class BackupTests(unittest.TestCase):
|
||||
self.cx.commit()
|
||||
journal.append(remaining)
|
||||
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, pages=1, progress=progress)
|
||||
self.verify_backup(bck)
|
||||
|
||||
@ -140,12 +142,12 @@ class BackupTests(unittest.TestCase):
|
||||
self.assertEqual(str(err.exception), 'nearly out of space')
|
||||
|
||||
def test_database_source_name(self):
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, name='main')
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, name='temp')
|
||||
with self.assertRaises(sqlite.OperationalError) as cm:
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, name='non-existing')
|
||||
self.assertIn("unknown database", str(cm.exception))
|
||||
|
||||
@ -153,7 +155,7 @@ class BackupTests(unittest.TestCase):
|
||||
self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)')
|
||||
self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)])
|
||||
self.cx.commit()
|
||||
with sqlite.connect(':memory:') as bck:
|
||||
with memory_database() as bck:
|
||||
self.cx.backup(bck, name='attached_db')
|
||||
self.verify_backup(bck)
|
||||
|
||||
|
@ -33,26 +33,13 @@ from test.support import (
|
||||
SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess,
|
||||
is_emscripten, is_wasi
|
||||
)
|
||||
from test.support import gc_collect
|
||||
from test.support import threading_helper
|
||||
from _testcapi import INT_MAX, ULLONG_MAX
|
||||
from os import SEEK_SET, SEEK_CUR, SEEK_END
|
||||
from test.support.os_helper import TESTFN, TESTFN_UNDECODABLE, unlink, temp_dir, FakePath
|
||||
|
||||
|
||||
# Helper for temporary memory databases
|
||||
def memory_database(*args, **kwargs):
|
||||
cx = sqlite.connect(":memory:", *args, **kwargs)
|
||||
return contextlib.closing(cx)
|
||||
|
||||
|
||||
# Temporarily limit a database connection parameter
|
||||
@contextlib.contextmanager
|
||||
def cx_limit(cx, category=sqlite.SQLITE_LIMIT_SQL_LENGTH, limit=128):
|
||||
try:
|
||||
_prev = cx.setlimit(category, limit)
|
||||
yield limit
|
||||
finally:
|
||||
cx.setlimit(category, _prev)
|
||||
from .util import memory_database, cx_limit
|
||||
|
||||
|
||||
class ModuleTests(unittest.TestCase):
|
||||
@ -326,9 +313,9 @@ class ModuleTests(unittest.TestCase):
|
||||
self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK")
|
||||
|
||||
def test_disallow_instantiation(self):
|
||||
cx = sqlite.connect(":memory:")
|
||||
check_disallow_instantiation(self, type(cx("select 1")))
|
||||
check_disallow_instantiation(self, sqlite.Blob)
|
||||
with memory_database() as cx:
|
||||
check_disallow_instantiation(self, type(cx("select 1")))
|
||||
check_disallow_instantiation(self, sqlite.Blob)
|
||||
|
||||
def test_complete_statement(self):
|
||||
self.assertFalse(sqlite.complete_statement("select t"))
|
||||
@ -342,6 +329,7 @@ class ConnectionTests(unittest.TestCase):
|
||||
cu = self.cx.cursor()
|
||||
cu.execute("create table test(id integer primary key, name text)")
|
||||
cu.execute("insert into test(name) values (?)", ("foo",))
|
||||
cu.close()
|
||||
|
||||
def tearDown(self):
|
||||
self.cx.close()
|
||||
@ -412,21 +400,22 @@ class ConnectionTests(unittest.TestCase):
|
||||
|
||||
def test_in_transaction(self):
|
||||
# Can't use db from setUp because we want to test initial state.
|
||||
cx = sqlite.connect(":memory:")
|
||||
cu = cx.cursor()
|
||||
self.assertEqual(cx.in_transaction, False)
|
||||
cu.execute("create table transactiontest(id integer primary key, name text)")
|
||||
self.assertEqual(cx.in_transaction, False)
|
||||
cu.execute("insert into transactiontest(name) values (?)", ("foo",))
|
||||
self.assertEqual(cx.in_transaction, True)
|
||||
cu.execute("select name from transactiontest where name=?", ["foo"])
|
||||
row = cu.fetchone()
|
||||
self.assertEqual(cx.in_transaction, True)
|
||||
cx.commit()
|
||||
self.assertEqual(cx.in_transaction, False)
|
||||
cu.execute("select name from transactiontest where name=?", ["foo"])
|
||||
row = cu.fetchone()
|
||||
self.assertEqual(cx.in_transaction, False)
|
||||
with memory_database() as cx:
|
||||
cu = cx.cursor()
|
||||
self.assertEqual(cx.in_transaction, False)
|
||||
cu.execute("create table transactiontest(id integer primary key, name text)")
|
||||
self.assertEqual(cx.in_transaction, False)
|
||||
cu.execute("insert into transactiontest(name) values (?)", ("foo",))
|
||||
self.assertEqual(cx.in_transaction, True)
|
||||
cu.execute("select name from transactiontest where name=?", ["foo"])
|
||||
row = cu.fetchone()
|
||||
self.assertEqual(cx.in_transaction, True)
|
||||
cx.commit()
|
||||
self.assertEqual(cx.in_transaction, False)
|
||||
cu.execute("select name from transactiontest where name=?", ["foo"])
|
||||
row = cu.fetchone()
|
||||
self.assertEqual(cx.in_transaction, False)
|
||||
cu.close()
|
||||
|
||||
def test_in_transaction_ro(self):
|
||||
with self.assertRaises(AttributeError):
|
||||
@ -450,10 +439,9 @@ class ConnectionTests(unittest.TestCase):
|
||||
self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc))
|
||||
|
||||
def test_interrupt_on_closed_db(self):
|
||||
cx = sqlite.connect(":memory:")
|
||||
cx.close()
|
||||
self.cx.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
cx.interrupt()
|
||||
self.cx.interrupt()
|
||||
|
||||
def test_interrupt(self):
|
||||
self.assertIsNone(self.cx.interrupt())
|
||||
@ -521,29 +509,29 @@ class ConnectionTests(unittest.TestCase):
|
||||
self.assertEqual(cx.isolation_level, level)
|
||||
|
||||
def test_connection_reinit(self):
|
||||
db = ":memory:"
|
||||
cx = sqlite.connect(db)
|
||||
cx.text_factory = bytes
|
||||
cx.row_factory = sqlite.Row
|
||||
cu = cx.cursor()
|
||||
cu.execute("create table foo (bar)")
|
||||
cu.executemany("insert into foo (bar) values (?)",
|
||||
((str(v),) for v in range(4)))
|
||||
cu.execute("select bar from foo")
|
||||
with memory_database() as cx:
|
||||
cx.text_factory = bytes
|
||||
cx.row_factory = sqlite.Row
|
||||
cu = cx.cursor()
|
||||
cu.execute("create table foo (bar)")
|
||||
cu.executemany("insert into foo (bar) values (?)",
|
||||
((str(v),) for v in range(4)))
|
||||
cu.execute("select bar from foo")
|
||||
|
||||
rows = [r for r in cu.fetchmany(2)]
|
||||
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
|
||||
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
|
||||
rows = [r for r in cu.fetchmany(2)]
|
||||
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
|
||||
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
|
||||
|
||||
cx.__init__(db)
|
||||
cx.execute("create table foo (bar)")
|
||||
cx.executemany("insert into foo (bar) values (?)",
|
||||
((v,) for v in ("a", "b", "c", "d")))
|
||||
cx.__init__(":memory:")
|
||||
cx.execute("create table foo (bar)")
|
||||
cx.executemany("insert into foo (bar) values (?)",
|
||||
((v,) for v in ("a", "b", "c", "d")))
|
||||
|
||||
# This uses the old database, old row factory, but new text factory
|
||||
rows = [r for r in cu.fetchall()]
|
||||
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
|
||||
self.assertEqual([r[0] for r in rows], ["2", "3"])
|
||||
# This uses the old database, old row factory, but new text factory
|
||||
rows = [r for r in cu.fetchall()]
|
||||
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
|
||||
self.assertEqual([r[0] for r in rows], ["2", "3"])
|
||||
cu.close()
|
||||
|
||||
def test_connection_bad_reinit(self):
|
||||
cx = sqlite.connect(":memory:")
|
||||
@ -591,11 +579,11 @@ class ConnectionTests(unittest.TestCase):
|
||||
"parameters in Python 3.15."
|
||||
)
|
||||
with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
|
||||
sqlite.connect(":memory:", 1.0)
|
||||
cx = sqlite.connect(":memory:", 1.0)
|
||||
cx.close()
|
||||
self.assertEqual(cm.filename, __file__)
|
||||
|
||||
|
||||
|
||||
class UninitialisedConnectionTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.cx = sqlite.Connection.__new__(sqlite.Connection)
|
||||
@ -1571,12 +1559,12 @@ class ThreadTests(unittest.TestCase):
|
||||
except sqlite.Error:
|
||||
err.append("multi-threading not allowed")
|
||||
|
||||
con = sqlite.connect(":memory:", check_same_thread=False)
|
||||
err = []
|
||||
t = threading.Thread(target=run, kwargs={"con": con, "err": err})
|
||||
t.start()
|
||||
t.join()
|
||||
self.assertEqual(len(err), 0, "\n".join(err))
|
||||
with memory_database(check_same_thread=False) as con:
|
||||
err = []
|
||||
t = threading.Thread(target=run, kwargs={"con": con, "err": err})
|
||||
t.start()
|
||||
t.join()
|
||||
self.assertEqual(len(err), 0, "\n".join(err))
|
||||
|
||||
|
||||
class ConstructorTests(unittest.TestCase):
|
||||
@ -1602,9 +1590,16 @@ class ConstructorTests(unittest.TestCase):
|
||||
b = sqlite.Binary(b"\0'")
|
||||
|
||||
class ExtensionTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
self.cur = self.con.cursor()
|
||||
|
||||
def tearDown(self):
|
||||
self.cur.close()
|
||||
self.con.close()
|
||||
|
||||
def test_script_string_sql(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = con.cursor()
|
||||
cur = self.cur
|
||||
cur.executescript("""
|
||||
-- bla bla
|
||||
/* a stupid comment */
|
||||
@ -1616,40 +1611,40 @@ class ExtensionTests(unittest.TestCase):
|
||||
self.assertEqual(res, 5)
|
||||
|
||||
def test_script_syntax_error(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = con.cursor()
|
||||
with self.assertRaises(sqlite.OperationalError):
|
||||
cur.executescript("create table test(x); asdf; create table test2(x)")
|
||||
self.cur.executescript("""
|
||||
CREATE TABLE test(x);
|
||||
asdf;
|
||||
CREATE TABLE test2(x)
|
||||
""")
|
||||
|
||||
def test_script_error_normal(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = con.cursor()
|
||||
with self.assertRaises(sqlite.OperationalError):
|
||||
cur.executescript("create table test(sadfsadfdsa); select foo from hurz;")
|
||||
self.cur.executescript("""
|
||||
CREATE TABLE test(sadfsadfdsa);
|
||||
SELECT foo FROM hurz;
|
||||
""")
|
||||
|
||||
def test_cursor_executescript_as_bytes(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = con.cursor()
|
||||
with self.assertRaises(TypeError):
|
||||
cur.executescript(b"create table test(foo); insert into test(foo) values (5);")
|
||||
self.cur.executescript(b"""
|
||||
CREATE TABLE test(foo);
|
||||
INSERT INTO test(foo) VALUES (5);
|
||||
""")
|
||||
|
||||
def test_cursor_executescript_with_null_characters(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = con.cursor()
|
||||
with self.assertRaises(ValueError):
|
||||
cur.executescript("""
|
||||
create table a(i);\0
|
||||
insert into a(i) values (5);
|
||||
""")
|
||||
self.cur.executescript("""
|
||||
CREATE TABLE a(i);\0
|
||||
INSERT INTO a(i) VALUES (5);
|
||||
""")
|
||||
|
||||
def test_cursor_executescript_with_surrogates(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = con.cursor()
|
||||
with self.assertRaises(UnicodeEncodeError):
|
||||
cur.executescript("""
|
||||
create table a(s);
|
||||
insert into a(s) values ('\ud8ff');
|
||||
""")
|
||||
self.cur.executescript("""
|
||||
CREATE TABLE a(s);
|
||||
INSERT INTO a(s) VALUES ('\ud8ff');
|
||||
""")
|
||||
|
||||
def test_cursor_executescript_too_large_script(self):
|
||||
msg = "query string is too large"
|
||||
@ -1659,19 +1654,18 @@ class ExtensionTests(unittest.TestCase):
|
||||
cx.executescript("select 'too large'".ljust(lim+1))
|
||||
|
||||
def test_cursor_executescript_tx_control(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
con.execute("begin")
|
||||
self.assertTrue(con.in_transaction)
|
||||
con.executescript("select 1")
|
||||
self.assertFalse(con.in_transaction)
|
||||
|
||||
def test_connection_execute(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
result = con.execute("select 5").fetchone()[0]
|
||||
result = self.con.execute("select 5").fetchone()[0]
|
||||
self.assertEqual(result, 5, "Basic test of Connection.execute")
|
||||
|
||||
def test_connection_executemany(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
con.execute("create table test(foo)")
|
||||
con.executemany("insert into test(foo) values (?)", [(3,), (4,)])
|
||||
result = con.execute("select foo from test order by foo").fetchall()
|
||||
@ -1679,47 +1673,44 @@ class ExtensionTests(unittest.TestCase):
|
||||
self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany")
|
||||
|
||||
def test_connection_executescript(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.executescript("create table test(foo); insert into test(foo) values (5);")
|
||||
con = self.con
|
||||
con.executescript("""
|
||||
CREATE TABLE test(foo);
|
||||
INSERT INTO test(foo) VALUES (5);
|
||||
""")
|
||||
result = con.execute("select foo from test").fetchone()[0]
|
||||
self.assertEqual(result, 5, "Basic test of Connection.executescript")
|
||||
|
||||
|
||||
class ClosedConTests(unittest.TestCase):
|
||||
def check(self, fn, *args, **kwds):
|
||||
regex = "Cannot operate on a closed database."
|
||||
with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
|
||||
fn(*args, **kwds)
|
||||
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
self.cur = self.con.cursor()
|
||||
self.con.close()
|
||||
|
||||
def test_closed_con_cursor(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
cur = con.cursor()
|
||||
self.check(self.con.cursor)
|
||||
|
||||
def test_closed_con_commit(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
con.commit()
|
||||
self.check(self.con.commit)
|
||||
|
||||
def test_closed_con_rollback(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
con.rollback()
|
||||
self.check(self.con.rollback)
|
||||
|
||||
def test_closed_cur_execute(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = con.cursor()
|
||||
con.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
cur.execute("select 4")
|
||||
self.check(self.cur.execute, "select 4")
|
||||
|
||||
def test_closed_create_function(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.close()
|
||||
def f(x): return 17
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
con.create_function("foo", 1, f)
|
||||
def f(x):
|
||||
return 17
|
||||
self.check(self.con.create_function, "foo", 1, f)
|
||||
|
||||
def test_closed_create_aggregate(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.close()
|
||||
class Agg:
|
||||
def __init__(self):
|
||||
pass
|
||||
@ -1727,29 +1718,21 @@ class ClosedConTests(unittest.TestCase):
|
||||
pass
|
||||
def finalize(self):
|
||||
return 17
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
con.create_aggregate("foo", 1, Agg)
|
||||
self.check(self.con.create_aggregate, "foo", 1, Agg)
|
||||
|
||||
def test_closed_set_authorizer(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.close()
|
||||
def authorizer(*args):
|
||||
return sqlite.DENY
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
con.set_authorizer(authorizer)
|
||||
self.check(self.con.set_authorizer, authorizer)
|
||||
|
||||
def test_closed_set_progress_callback(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.close()
|
||||
def progress(): pass
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
con.set_progress_handler(progress, 100)
|
||||
def progress():
|
||||
pass
|
||||
self.check(self.con.set_progress_handler, progress, 100)
|
||||
|
||||
def test_closed_call(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.close()
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
con()
|
||||
self.check(self.con)
|
||||
|
||||
|
||||
class ClosedCurTests(unittest.TestCase):
|
||||
def test_closed(self):
|
||||
|
@ -2,16 +2,12 @@
|
||||
|
||||
import unittest
|
||||
import sqlite3 as sqlite
|
||||
from .test_dbapi import memory_database
|
||||
|
||||
from .util import memory_database
|
||||
from .util import MemoryDatabaseMixin
|
||||
|
||||
|
||||
class DumpTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.cx = sqlite.connect(":memory:")
|
||||
self.cu = self.cx.cursor()
|
||||
|
||||
def tearDown(self):
|
||||
self.cx.close()
|
||||
class DumpTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_table_dump(self):
|
||||
expected_sqls = [
|
||||
|
@ -24,6 +24,9 @@ import unittest
|
||||
import sqlite3 as sqlite
|
||||
from collections.abc import Sequence
|
||||
|
||||
from .util import memory_database
|
||||
from .util import MemoryDatabaseMixin
|
||||
|
||||
|
||||
def dict_factory(cursor, row):
|
||||
d = {}
|
||||
@ -45,10 +48,12 @@ class ConnectionFactoryTests(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
sqlite.Connection.__init__(self, *args, **kwargs)
|
||||
|
||||
for factory in DefectFactory, OkFactory:
|
||||
with self.subTest(factory=factory):
|
||||
con = sqlite.connect(":memory:", factory=factory)
|
||||
self.assertIsInstance(con, factory)
|
||||
with memory_database(factory=OkFactory) as con:
|
||||
self.assertIsInstance(con, OkFactory)
|
||||
regex = "Base Connection.__init__ not called."
|
||||
with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
|
||||
with memory_database(factory=DefectFactory) as con:
|
||||
self.assertIsInstance(con, DefectFactory)
|
||||
|
||||
def test_connection_factory_relayed_call(self):
|
||||
# gh-95132: keyword args must not be passed as positional args
|
||||
@ -57,9 +62,9 @@ class ConnectionFactoryTests(unittest.TestCase):
|
||||
kwargs["isolation_level"] = None
|
||||
super(Factory, self).__init__(*args, **kwargs)
|
||||
|
||||
con = sqlite.connect(":memory:", factory=Factory)
|
||||
self.assertIsNone(con.isolation_level)
|
||||
self.assertIsInstance(con, Factory)
|
||||
with memory_database(factory=Factory) as con:
|
||||
self.assertIsNone(con.isolation_level)
|
||||
self.assertIsInstance(con, Factory)
|
||||
|
||||
def test_connection_factory_as_positional_arg(self):
|
||||
class Factory(sqlite.Connection):
|
||||
@ -74,18 +79,13 @@ class ConnectionFactoryTests(unittest.TestCase):
|
||||
r"parameters in Python 3.15."
|
||||
)
|
||||
with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
|
||||
con = sqlite.connect(":memory:", 5.0, 0, None, True, Factory)
|
||||
with memory_database(5.0, 0, None, True, Factory) as con:
|
||||
self.assertIsNone(con.isolation_level)
|
||||
self.assertIsInstance(con, Factory)
|
||||
self.assertEqual(cm.filename, __file__)
|
||||
self.assertIsNone(con.isolation_level)
|
||||
self.assertIsInstance(con, Factory)
|
||||
|
||||
|
||||
class CursorFactoryTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
|
||||
def tearDown(self):
|
||||
self.con.close()
|
||||
class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_is_instance(self):
|
||||
cur = self.con.cursor()
|
||||
@ -103,9 +103,8 @@ class CursorFactoryTests(unittest.TestCase):
|
||||
# invalid callable returning non-cursor
|
||||
self.assertRaises(TypeError, self.con.cursor, lambda con: None)
|
||||
|
||||
class RowFactoryTestsBackwardsCompat(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
|
||||
class RowFactoryTestsBackwardsCompat(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_is_produced_by_factory(self):
|
||||
cur = self.con.cursor(factory=MyCursor)
|
||||
@ -114,12 +113,8 @@ class RowFactoryTestsBackwardsCompat(unittest.TestCase):
|
||||
self.assertIsInstance(row, dict)
|
||||
cur.close()
|
||||
|
||||
def tearDown(self):
|
||||
self.con.close()
|
||||
|
||||
class RowFactoryTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
class RowFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_custom_factory(self):
|
||||
self.con.row_factory = lambda cur, row: list(row)
|
||||
@ -265,12 +260,8 @@ class RowFactoryTests(unittest.TestCase):
|
||||
self.assertRaises(TypeError, self.con.cursor, FakeCursor)
|
||||
self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
|
||||
|
||||
def tearDown(self):
|
||||
self.con.close()
|
||||
|
||||
class TextFactoryTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
class TextFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_unicode(self):
|
||||
austria = "Österreich"
|
||||
@ -291,15 +282,17 @@ class TextFactoryTests(unittest.TestCase):
|
||||
self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
|
||||
self.assertTrue(row[0].endswith("reich"), "column must contain original data")
|
||||
|
||||
def tearDown(self):
|
||||
self.con.close()
|
||||
|
||||
class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
self.con.execute("create table test (value text)")
|
||||
self.con.execute("insert into test (value) values (?)", ("a\x00b",))
|
||||
|
||||
def tearDown(self):
|
||||
self.con.close()
|
||||
|
||||
def test_string(self):
|
||||
# text_factory defaults to str
|
||||
row = self.con.execute("select value from test").fetchone()
|
||||
@ -325,9 +318,6 @@ class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
|
||||
self.assertIs(type(row[0]), bytes)
|
||||
self.assertEqual(row[0], b"a\x00b")
|
||||
|
||||
def tearDown(self):
|
||||
self.con.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -26,34 +26,31 @@ import unittest
|
||||
|
||||
from test.support.os_helper import TESTFN, unlink
|
||||
|
||||
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
|
||||
from test.test_sqlite3.test_userfunctions import with_tracebacks
|
||||
from .util import memory_database, cx_limit, with_tracebacks
|
||||
from .util import MemoryDatabaseMixin
|
||||
|
||||
|
||||
class CollationTests(unittest.TestCase):
|
||||
class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_create_collation_not_string(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
with self.assertRaises(TypeError):
|
||||
con.create_collation(None, lambda x, y: (x > y) - (x < y))
|
||||
self.con.create_collation(None, lambda x, y: (x > y) - (x < y))
|
||||
|
||||
def test_create_collation_not_callable(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
con.create_collation("X", 42)
|
||||
self.con.create_collation("X", 42)
|
||||
self.assertEqual(str(cm.exception), 'parameter must be callable')
|
||||
|
||||
def test_create_collation_not_ascii(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con.create_collation("collä", lambda x, y: (x > y) - (x < y))
|
||||
self.con.create_collation("collä", lambda x, y: (x > y) - (x < y))
|
||||
|
||||
def test_create_collation_bad_upper(self):
|
||||
class BadUpperStr(str):
|
||||
def upper(self):
|
||||
return None
|
||||
con = sqlite.connect(":memory:")
|
||||
mycoll = lambda x, y: -((x > y) - (x < y))
|
||||
con.create_collation(BadUpperStr("mycoll"), mycoll)
|
||||
result = con.execute("""
|
||||
self.con.create_collation(BadUpperStr("mycoll"), mycoll)
|
||||
result = self.con.execute("""
|
||||
select x from (
|
||||
select 'a' as x
|
||||
union
|
||||
@ -68,8 +65,7 @@ class CollationTests(unittest.TestCase):
|
||||
# reverse order
|
||||
return -((x > y) - (x < y))
|
||||
|
||||
con = sqlite.connect(":memory:")
|
||||
con.create_collation("mycoll", mycoll)
|
||||
self.con.create_collation("mycoll", mycoll)
|
||||
sql = """
|
||||
select x from (
|
||||
select 'a' as x
|
||||
@ -79,21 +75,20 @@ class CollationTests(unittest.TestCase):
|
||||
select 'c' as x
|
||||
) order by x collate mycoll
|
||||
"""
|
||||
result = con.execute(sql).fetchall()
|
||||
result = self.con.execute(sql).fetchall()
|
||||
self.assertEqual(result, [('c',), ('b',), ('a',)],
|
||||
msg='the expected order was not returned')
|
||||
|
||||
con.create_collation("mycoll", None)
|
||||
self.con.create_collation("mycoll", None)
|
||||
with self.assertRaises(sqlite.OperationalError) as cm:
|
||||
result = con.execute(sql).fetchall()
|
||||
result = self.con.execute(sql).fetchall()
|
||||
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
|
||||
|
||||
def test_collation_returns_large_integer(self):
|
||||
def mycoll(x, y):
|
||||
# reverse order
|
||||
return -((x > y) - (x < y)) * 2**32
|
||||
con = sqlite.connect(":memory:")
|
||||
con.create_collation("mycoll", mycoll)
|
||||
self.con.create_collation("mycoll", mycoll)
|
||||
sql = """
|
||||
select x from (
|
||||
select 'a' as x
|
||||
@ -103,7 +98,7 @@ class CollationTests(unittest.TestCase):
|
||||
select 'c' as x
|
||||
) order by x collate mycoll
|
||||
"""
|
||||
result = con.execute(sql).fetchall()
|
||||
result = self.con.execute(sql).fetchall()
|
||||
self.assertEqual(result, [('c',), ('b',), ('a',)],
|
||||
msg="the expected order was not returned")
|
||||
|
||||
@ -112,7 +107,7 @@ class CollationTests(unittest.TestCase):
|
||||
Register two different collation functions under the same name.
|
||||
Verify that the last one is actually used.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
|
||||
con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
|
||||
result = con.execute("""
|
||||
@ -126,25 +121,26 @@ class CollationTests(unittest.TestCase):
|
||||
Register a collation, then deregister it. Make sure an error is raised if we try
|
||||
to use it.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
|
||||
con.create_collation("mycoll", None)
|
||||
with self.assertRaises(sqlite.OperationalError) as cm:
|
||||
con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
|
||||
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
|
||||
|
||||
class ProgressTests(unittest.TestCase):
|
||||
|
||||
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_progress_handler_used(self):
|
||||
"""
|
||||
Test that the progress handler is invoked once it is set.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
progress_calls = []
|
||||
def progress():
|
||||
progress_calls.append(None)
|
||||
return 0
|
||||
con.set_progress_handler(progress, 1)
|
||||
con.execute("""
|
||||
self.con.set_progress_handler(progress, 1)
|
||||
self.con.execute("""
|
||||
create table foo(a, b)
|
||||
""")
|
||||
self.assertTrue(progress_calls)
|
||||
@ -153,7 +149,7 @@ class ProgressTests(unittest.TestCase):
|
||||
"""
|
||||
Test that the opcode argument is respected.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
progress_calls = []
|
||||
def progress():
|
||||
progress_calls.append(None)
|
||||
@ -176,11 +172,10 @@ class ProgressTests(unittest.TestCase):
|
||||
"""
|
||||
Test that returning a non-zero value stops the operation in progress.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
def progress():
|
||||
return 1
|
||||
con.set_progress_handler(progress, 1)
|
||||
curs = con.cursor()
|
||||
self.con.set_progress_handler(progress, 1)
|
||||
curs = self.con.cursor()
|
||||
self.assertRaises(
|
||||
sqlite.OperationalError,
|
||||
curs.execute,
|
||||
@ -190,7 +185,7 @@ class ProgressTests(unittest.TestCase):
|
||||
"""
|
||||
Test that setting the progress handler to None clears the previously set handler.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
action = 0
|
||||
def progress():
|
||||
nonlocal action
|
||||
@ -203,31 +198,30 @@ class ProgressTests(unittest.TestCase):
|
||||
|
||||
@with_tracebacks(ZeroDivisionError, name="bad_progress")
|
||||
def test_error_in_progress_handler(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
def bad_progress():
|
||||
1 / 0
|
||||
con.set_progress_handler(bad_progress, 1)
|
||||
self.con.set_progress_handler(bad_progress, 1)
|
||||
with self.assertRaises(sqlite.OperationalError):
|
||||
con.execute("""
|
||||
self.con.execute("""
|
||||
create table foo(a, b)
|
||||
""")
|
||||
|
||||
@with_tracebacks(ZeroDivisionError, name="bad_progress")
|
||||
def test_error_in_progress_handler_result(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
class BadBool:
|
||||
def __bool__(self):
|
||||
1 / 0
|
||||
def bad_progress():
|
||||
return BadBool()
|
||||
con.set_progress_handler(bad_progress, 1)
|
||||
self.con.set_progress_handler(bad_progress, 1)
|
||||
with self.assertRaises(sqlite.OperationalError):
|
||||
con.execute("""
|
||||
self.con.execute("""
|
||||
create table foo(a, b)
|
||||
""")
|
||||
|
||||
|
||||
class TraceCallbackTests(unittest.TestCase):
|
||||
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def check_stmt_trace(self, cx, expected):
|
||||
try:
|
||||
@ -242,12 +236,11 @@ class TraceCallbackTests(unittest.TestCase):
|
||||
"""
|
||||
Test that the trace callback is invoked once it is set.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
traced_statements = []
|
||||
def trace(statement):
|
||||
traced_statements.append(statement)
|
||||
con.set_trace_callback(trace)
|
||||
con.execute("create table foo(a, b)")
|
||||
self.con.set_trace_callback(trace)
|
||||
self.con.execute("create table foo(a, b)")
|
||||
self.assertTrue(traced_statements)
|
||||
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
|
||||
|
||||
@ -255,7 +248,7 @@ class TraceCallbackTests(unittest.TestCase):
|
||||
"""
|
||||
Test that setting the trace callback to None clears the previously set callback.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
traced_statements = []
|
||||
def trace(statement):
|
||||
traced_statements.append(statement)
|
||||
@ -269,7 +262,7 @@ class TraceCallbackTests(unittest.TestCase):
|
||||
Test that the statement can contain unicode literals.
|
||||
"""
|
||||
unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
traced_statements = []
|
||||
def trace(statement):
|
||||
traced_statements.append(statement)
|
||||
|
@ -28,15 +28,12 @@ import functools
|
||||
|
||||
from test import support
|
||||
from unittest.mock import patch
|
||||
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
|
||||
|
||||
from .util import memory_database, cx_limit
|
||||
from .util import MemoryDatabaseMixin
|
||||
|
||||
|
||||
class RegressionTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
|
||||
def tearDown(self):
|
||||
self.con.close()
|
||||
class RegressionTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_pragma_user_version(self):
|
||||
# This used to crash pysqlite because this pragma command returns NULL for the column name
|
||||
@ -45,28 +42,24 @@ class RegressionTests(unittest.TestCase):
|
||||
|
||||
def test_pragma_schema_version(self):
|
||||
# This still crashed pysqlite <= 2.2.1
|
||||
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
|
||||
try:
|
||||
with memory_database(detect_types=sqlite.PARSE_COLNAMES) as con:
|
||||
cur = self.con.cursor()
|
||||
cur.execute("pragma schema_version")
|
||||
finally:
|
||||
cur.close()
|
||||
con.close()
|
||||
|
||||
def test_statement_reset(self):
|
||||
# pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are
|
||||
# reset before a rollback, but only those that are still in the
|
||||
# statement cache. The others are not accessible from the connection object.
|
||||
con = sqlite.connect(":memory:", cached_statements=5)
|
||||
cursors = [con.cursor() for x in range(5)]
|
||||
cursors[0].execute("create table test(x)")
|
||||
for i in range(10):
|
||||
cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
|
||||
with memory_database(cached_statements=5) as con:
|
||||
cursors = [con.cursor() for x in range(5)]
|
||||
cursors[0].execute("create table test(x)")
|
||||
for i in range(10):
|
||||
cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
|
||||
|
||||
for i in range(5):
|
||||
cursors[i].execute(" " * i + "select x from test")
|
||||
for i in range(5):
|
||||
cursors[i].execute(" " * i + "select x from test")
|
||||
|
||||
con.rollback()
|
||||
con.rollback()
|
||||
|
||||
def test_column_name_with_spaces(self):
|
||||
cur = self.con.cursor()
|
||||
@ -81,17 +74,15 @@ class RegressionTests(unittest.TestCase):
|
||||
# cache when closing the database. statements that were still
|
||||
# referenced in cursors weren't closed and could provoke "
|
||||
# "OperationalError: Unable to close due to unfinalised statements".
|
||||
con = sqlite.connect(":memory:")
|
||||
cursors = []
|
||||
# default statement cache size is 100
|
||||
for i in range(105):
|
||||
cur = con.cursor()
|
||||
cur = self.con.cursor()
|
||||
cursors.append(cur)
|
||||
cur.execute("select 1 x union select " + str(i))
|
||||
con.close()
|
||||
|
||||
def test_on_conflict_rollback(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
con.execute("create table foo(x, unique(x) on conflict rollback)")
|
||||
con.execute("insert into foo(x) values (1)")
|
||||
try:
|
||||
@ -126,16 +117,16 @@ class RegressionTests(unittest.TestCase):
|
||||
a statement. This test exhibits the problem.
|
||||
"""
|
||||
SELECT = "select * from foo"
|
||||
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
|
||||
cur = con.cursor()
|
||||
cur.execute("create table foo(bar timestamp)")
|
||||
with self.assertWarnsRegex(DeprecationWarning, "adapter"):
|
||||
cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
|
||||
cur.execute(SELECT)
|
||||
cur.execute("drop table foo")
|
||||
cur.execute("create table foo(bar integer)")
|
||||
cur.execute("insert into foo(bar) values (5)")
|
||||
cur.execute(SELECT)
|
||||
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
|
||||
cur = con.cursor()
|
||||
cur.execute("create table foo(bar timestamp)")
|
||||
with self.assertWarnsRegex(DeprecationWarning, "adapter"):
|
||||
cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
|
||||
cur.execute(SELECT)
|
||||
cur.execute("drop table foo")
|
||||
cur.execute("create table foo(bar integer)")
|
||||
cur.execute("insert into foo(bar) values (5)")
|
||||
cur.execute(SELECT)
|
||||
|
||||
def test_bind_mutating_list(self):
|
||||
# Issue41662: Crash when mutate a list of parameters during iteration.
|
||||
@ -144,11 +135,11 @@ class RegressionTests(unittest.TestCase):
|
||||
parameters.clear()
|
||||
return "..."
|
||||
parameters = [X(), 0]
|
||||
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
|
||||
con.execute("create table foo(bar X, baz integer)")
|
||||
# Should not crash
|
||||
with self.assertRaises(IndexError):
|
||||
con.execute("insert into foo(bar, baz) values (?, ?)", parameters)
|
||||
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
|
||||
con.execute("create table foo(bar X, baz integer)")
|
||||
# Should not crash
|
||||
with self.assertRaises(IndexError):
|
||||
con.execute("insert into foo(bar, baz) values (?, ?)", parameters)
|
||||
|
||||
def test_error_msg_decode_error(self):
|
||||
# When porting the module to Python 3.0, the error message about
|
||||
@ -173,7 +164,7 @@ class RegressionTests(unittest.TestCase):
|
||||
def __del__(self):
|
||||
con.isolation_level = ""
|
||||
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
con.isolation_level = None
|
||||
for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE":
|
||||
with self.subTest(level=level):
|
||||
@ -204,8 +195,7 @@ class RegressionTests(unittest.TestCase):
|
||||
def __init__(self, con):
|
||||
pass
|
||||
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = Cursor(con)
|
||||
cur = Cursor(self.con)
|
||||
with self.assertRaises(sqlite.ProgrammingError):
|
||||
cur.execute("select 4+5").fetchall()
|
||||
with self.assertRaisesRegex(sqlite.ProgrammingError,
|
||||
@ -238,7 +228,9 @@ class RegressionTests(unittest.TestCase):
|
||||
2.5.3 introduced a regression so that these could no longer
|
||||
be created.
|
||||
"""
|
||||
con = sqlite.connect(":memory:", isolation_level=None)
|
||||
with memory_database(isolation_level=None) as con:
|
||||
self.assertIsNone(con.isolation_level)
|
||||
self.assertFalse(con.in_transaction)
|
||||
|
||||
def test_pragma_autocommit(self):
|
||||
"""
|
||||
@ -273,9 +265,7 @@ class RegressionTests(unittest.TestCase):
|
||||
Recursively using a cursor, such as when reusing it from a generator led to segfaults.
|
||||
Now we catch recursive cursor usage and raise a ProgrammingError.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
|
||||
cur = con.cursor()
|
||||
cur = self.con.cursor()
|
||||
cur.execute("create table a (bar)")
|
||||
cur.execute("create table b (baz)")
|
||||
|
||||
@ -295,29 +285,30 @@ class RegressionTests(unittest.TestCase):
|
||||
since the microsecond string "456" actually represents "456000".
|
||||
"""
|
||||
|
||||
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
|
||||
cur = con.cursor()
|
||||
cur.execute("CREATE TABLE t (x TIMESTAMP)")
|
||||
with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
|
||||
cur = con.cursor()
|
||||
cur.execute("CREATE TABLE t (x TIMESTAMP)")
|
||||
|
||||
# Microseconds should be 456000
|
||||
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
|
||||
# Microseconds should be 456000
|
||||
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
|
||||
|
||||
# Microseconds should be truncated to 123456
|
||||
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
|
||||
# Microseconds should be truncated to 123456
|
||||
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
|
||||
|
||||
cur.execute("SELECT * FROM t")
|
||||
with self.assertWarnsRegex(DeprecationWarning, "converter"):
|
||||
values = [x[0] for x in cur.fetchall()]
|
||||
cur.execute("SELECT * FROM t")
|
||||
with self.assertWarnsRegex(DeprecationWarning, "converter"):
|
||||
values = [x[0] for x in cur.fetchall()]
|
||||
|
||||
self.assertEqual(values, [
|
||||
datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
|
||||
datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
|
||||
])
|
||||
self.assertEqual(values, [
|
||||
datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
|
||||
datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
|
||||
])
|
||||
|
||||
def test_invalid_isolation_level_type(self):
|
||||
# isolation level is a string, not an integer
|
||||
self.assertRaises(TypeError,
|
||||
sqlite.connect, ":memory:", isolation_level=123)
|
||||
regex = "isolation_level must be str or None"
|
||||
with self.assertRaisesRegex(TypeError, regex):
|
||||
memory_database(isolation_level=123).__enter__()
|
||||
|
||||
|
||||
def test_null_character(self):
|
||||
@ -333,7 +324,7 @@ class RegressionTests(unittest.TestCase):
|
||||
cur.execute, query)
|
||||
|
||||
def test_surrogates(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
|
||||
self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
|
||||
cur = con.cursor()
|
||||
@ -359,7 +350,7 @@ class RegressionTests(unittest.TestCase):
|
||||
to return rows multiple times when fetched from cursors
|
||||
after commit. See issues 10513 and 23129 for details.
|
||||
"""
|
||||
con = sqlite.connect(":memory:")
|
||||
con = self.con
|
||||
con.executescript("""
|
||||
create table t(c);
|
||||
create table t2(c);
|
||||
@ -391,10 +382,9 @@ class RegressionTests(unittest.TestCase):
|
||||
"""
|
||||
def callback(*args):
|
||||
pass
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = sqlite.Cursor(con)
|
||||
cur = sqlite.Cursor(self.con)
|
||||
ref = weakref.ref(cur, callback)
|
||||
cur.__init__(con)
|
||||
cur.__init__(self.con)
|
||||
del cur
|
||||
# The interpreter shouldn't crash when ref is collected.
|
||||
del ref
|
||||
@ -425,6 +415,7 @@ class RegressionTests(unittest.TestCase):
|
||||
|
||||
def test_table_lock_cursor_replace_stmt(self):
|
||||
with memory_database() as con:
|
||||
con = self.con
|
||||
cur = con.cursor()
|
||||
cur.execute("create table t(t)")
|
||||
cur.executemany("insert into t values(?)",
|
||||
|
@ -28,7 +28,8 @@ from test.support import LOOPBACK_TIMEOUT
|
||||
from test.support.os_helper import TESTFN, unlink
|
||||
from test.support.script_helper import assert_python_ok
|
||||
|
||||
from test.test_sqlite3.test_dbapi import memory_database
|
||||
from .util import memory_database
|
||||
from .util import MemoryDatabaseMixin
|
||||
|
||||
|
||||
TIMEOUT = LOOPBACK_TIMEOUT / 10
|
||||
@ -132,14 +133,14 @@ class TransactionTests(unittest.TestCase):
|
||||
|
||||
def test_rollback_cursor_consistency(self):
|
||||
"""Check that cursors behave correctly after rollback."""
|
||||
con = sqlite.connect(":memory:")
|
||||
cur = con.cursor()
|
||||
cur.execute("create table test(x)")
|
||||
cur.execute("insert into test(x) values (5)")
|
||||
cur.execute("select 1 union select 2 union select 3")
|
||||
with memory_database() as con:
|
||||
cur = con.cursor()
|
||||
cur.execute("create table test(x)")
|
||||
cur.execute("insert into test(x) values (5)")
|
||||
cur.execute("select 1 union select 2 union select 3")
|
||||
|
||||
con.rollback()
|
||||
self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)])
|
||||
con.rollback()
|
||||
self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)])
|
||||
|
||||
def test_multiple_cursors_and_iternext(self):
|
||||
# gh-94028: statements are cleared and reset in cursor iternext.
|
||||
@ -218,10 +219,7 @@ class RollbackTests(unittest.TestCase):
|
||||
|
||||
|
||||
|
||||
class SpecialCommandTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
self.cur = self.con.cursor()
|
||||
class SpecialCommandTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_drop_table(self):
|
||||
self.cur.execute("create table test(i)")
|
||||
@ -233,14 +231,8 @@ class SpecialCommandTests(unittest.TestCase):
|
||||
self.cur.execute("insert into test(i) values (5)")
|
||||
self.cur.execute("pragma count_changes=1")
|
||||
|
||||
def tearDown(self):
|
||||
self.cur.close()
|
||||
self.con.close()
|
||||
|
||||
|
||||
class TransactionalDDL(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
class TransactionalDDL(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def test_ddl_does_not_autostart_transaction(self):
|
||||
# For backwards compatibility reasons, DDL statements should not
|
||||
@ -268,9 +260,6 @@ class TransactionalDDL(unittest.TestCase):
|
||||
with self.assertRaises(sqlite.OperationalError):
|
||||
self.con.execute("select * from test")
|
||||
|
||||
def tearDown(self):
|
||||
self.con.close()
|
||||
|
||||
|
||||
class IsolationLevelFromInit(unittest.TestCase):
|
||||
CREATE = "create table t(t)"
|
||||
|
@ -21,54 +21,15 @@
|
||||
# misrepresented as being the original software.
|
||||
# 3. This notice may not be removed or altered from any source distribution.
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import io
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
import sqlite3 as sqlite
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from test.support import bigmemtest, catch_unraisable_exception, gc_collect
|
||||
from test.support import bigmemtest, gc_collect
|
||||
|
||||
from test.test_sqlite3.test_dbapi import cx_limit
|
||||
|
||||
|
||||
def with_tracebacks(exc, regex="", name=""):
|
||||
"""Convenience decorator for testing callback tracebacks."""
|
||||
def decorator(func):
|
||||
_regex = re.compile(regex) if regex else None
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with catch_unraisable_exception() as cm:
|
||||
# First, run the test with traceback enabled.
|
||||
with check_tracebacks(self, cm, exc, _regex, name):
|
||||
func(self, *args, **kwargs)
|
||||
|
||||
# Then run the test with traceback disabled.
|
||||
func(self, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def check_tracebacks(self, cm, exc, regex, obj_name):
|
||||
"""Convenience context manager for testing callback tracebacks."""
|
||||
sqlite.enable_callback_tracebacks(True)
|
||||
try:
|
||||
buf = io.StringIO()
|
||||
with contextlib.redirect_stderr(buf):
|
||||
yield
|
||||
|
||||
self.assertEqual(cm.unraisable.exc_type, exc)
|
||||
if regex:
|
||||
msg = str(cm.unraisable.exc_value)
|
||||
self.assertIsNotNone(regex.search(msg))
|
||||
if obj_name:
|
||||
self.assertEqual(cm.unraisable.object.__name__, obj_name)
|
||||
finally:
|
||||
sqlite.enable_callback_tracebacks(False)
|
||||
from .util import cx_limit, memory_database
|
||||
from .util import with_tracebacks, check_tracebacks
|
||||
|
||||
|
||||
def func_returntext():
|
||||
@ -405,19 +366,19 @@ class FunctionTests(unittest.TestCase):
|
||||
def test_function_destructor_via_gc(self):
|
||||
# See bpo-44304: The destructor of the user function can
|
||||
# crash if is called without the GIL from the gc functions
|
||||
dest = sqlite.connect(':memory:')
|
||||
def md5sum(t):
|
||||
return
|
||||
|
||||
dest.create_function("md5", 1, md5sum)
|
||||
x = dest("create table lang (name, first_appeared)")
|
||||
del md5sum, dest
|
||||
with memory_database() as dest:
|
||||
dest.create_function("md5", 1, md5sum)
|
||||
x = dest("create table lang (name, first_appeared)")
|
||||
del md5sum, dest
|
||||
|
||||
y = [x]
|
||||
y.append(y)
|
||||
y = [x]
|
||||
y.append(y)
|
||||
|
||||
del x,y
|
||||
gc_collect()
|
||||
del x,y
|
||||
gc_collect()
|
||||
|
||||
@with_tracebacks(OverflowError)
|
||||
def test_func_return_too_large_int(self):
|
||||
@ -514,6 +475,10 @@ class WindowFunctionTests(unittest.TestCase):
|
||||
"""
|
||||
self.con.create_window_function("sumint", 1, WindowSumInt)
|
||||
|
||||
def tearDown(self):
|
||||
self.cur.close()
|
||||
self.con.close()
|
||||
|
||||
def test_win_sum_int(self):
|
||||
self.cur.execute(self.query % "sumint")
|
||||
self.assertEqual(self.cur.fetchall(), self.expected)
|
||||
@ -634,6 +599,7 @@ class AggregateTests(unittest.TestCase):
|
||||
""")
|
||||
cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
|
||||
("foo", 5, 3.14, None, memoryview(b"blob"),))
|
||||
cur.close()
|
||||
|
||||
self.con.create_aggregate("nostep", 1, AggrNoStep)
|
||||
self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
|
||||
@ -646,9 +612,7 @@ class AggregateTests(unittest.TestCase):
|
||||
self.con.create_aggregate("aggtxt", 1, AggrText)
|
||||
|
||||
def tearDown(self):
|
||||
#self.cur.close()
|
||||
#self.con.close()
|
||||
pass
|
||||
self.con.close()
|
||||
|
||||
def test_aggr_error_on_create(self):
|
||||
with self.assertRaises(sqlite.OperationalError):
|
||||
@ -775,7 +739,7 @@ class AuthorizerTests(unittest.TestCase):
|
||||
self.con.set_authorizer(self.authorizer_cb)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
self.con.close()
|
||||
|
||||
def test_table_access(self):
|
||||
with self.assertRaises(sqlite.DatabaseError) as cm:
|
||||
|
78
Lib/test/test_sqlite3/util.py
Normal file
78
Lib/test/test_sqlite3/util.py
Normal file
@ -0,0 +1,78 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import io
|
||||
import re
|
||||
import sqlite3
|
||||
import test.support
|
||||
import unittest
|
||||
|
||||
|
||||
# Helper for temporary memory databases
|
||||
def memory_database(*args, **kwargs):
|
||||
cx = sqlite3.connect(":memory:", *args, **kwargs)
|
||||
return contextlib.closing(cx)
|
||||
|
||||
|
||||
# Temporarily limit a database connection parameter
|
||||
@contextlib.contextmanager
|
||||
def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
|
||||
try:
|
||||
_prev = cx.setlimit(category, limit)
|
||||
yield limit
|
||||
finally:
|
||||
cx.setlimit(category, _prev)
|
||||
|
||||
|
||||
def with_tracebacks(exc, regex="", name=""):
|
||||
"""Convenience decorator for testing callback tracebacks."""
|
||||
def decorator(func):
|
||||
_regex = re.compile(regex) if regex else None
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with test.support.catch_unraisable_exception() as cm:
|
||||
# First, run the test with traceback enabled.
|
||||
with check_tracebacks(self, cm, exc, _regex, name):
|
||||
func(self, *args, **kwargs)
|
||||
|
||||
# Then run the test with traceback disabled.
|
||||
func(self, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def check_tracebacks(self, cm, exc, regex, obj_name):
|
||||
"""Convenience context manager for testing callback tracebacks."""
|
||||
sqlite3.enable_callback_tracebacks(True)
|
||||
try:
|
||||
buf = io.StringIO()
|
||||
with contextlib.redirect_stderr(buf):
|
||||
yield
|
||||
|
||||
self.assertEqual(cm.unraisable.exc_type, exc)
|
||||
if regex:
|
||||
msg = str(cm.unraisable.exc_value)
|
||||
self.assertIsNotNone(regex.search(msg))
|
||||
if obj_name:
|
||||
self.assertEqual(cm.unraisable.object.__name__, obj_name)
|
||||
finally:
|
||||
sqlite3.enable_callback_tracebacks(False)
|
||||
|
||||
|
||||
class MemoryDatabaseMixin:
|
||||
|
||||
def setUp(self):
|
||||
self.con = sqlite3.connect(":memory:")
|
||||
self.cur = self.con.cursor()
|
||||
|
||||
def tearDown(self):
|
||||
self.cur.close()
|
||||
self.con.close()
|
||||
|
||||
@property
|
||||
def cx(self):
|
||||
return self.con
|
||||
|
||||
@property
|
||||
def cu(self):
|
||||
return self.cur
|
Loading…
Reference in New Issue
Block a user