cpython/Lib/test/pickletester.py
Guido van Rossum 5d9113d8be Implement appropriate __getnewargs__ for all immutable subclassable builtin
types.  The special handling for these can now be removed from save_newobj().
Add some testing for this.

Also add support for setting the 'fast' flag on the Python Pickler class,
which suppresses use of the memo.
2003-01-29 17:58:45 +00:00

513 lines
14 KiB
Python

import unittest
from test.test_support import TestFailed, have_unicode, TESTFN
class C:
    """Classic class whose instances compare by their attribute dicts."""

    def __cmp__(self, other):
        mine = self.__dict__
        theirs = other.__dict__
        return cmp(mine, theirs)

# Make C look as if it were defined in __main__, so pickles reference it
# as __main__.C (the canned DATA/BINDATA pickles contain "__main__\nC").
import __main__
setattr(__main__, "C", C)
C.__module__ = "__main__"
class myint(int):
    """int subclass that also records str() of its constructor argument."""

    def __init__(self, x):
        text = str(x)
        self.str = text
class initarg(C):
    """C subclass whose constructor takes two required arguments.

    __getinitargs__ hands (a, b) back so pickle can call __init__ again
    when the instance is rebuilt.
    """

    __safe_for_unpickling__ = 1

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __getinitargs__(self):
        return (self.a, self.b)
class metaclass(type):
    """Trivial custom metaclass; it adds no behaviour of its own."""


class use_metaclass(object):
    """New-style class created through the custom `metaclass`."""
    __metaclass__ = metaclass
# Canned pickles of the object built by create_data().  test_constant loads
# both strings and compares the result with AbstractPickleTests._testdata,
# so they must not be edited: they pin the serialized format.
# DATA is a text-mode pickle (opcodes such as "(", "l", "p", "I", "a", "S").
# break into multiple strings to avoid confusing font-lock-mode
DATA = """(lp1
I0
aL1L
aF2
ac__builtin__
complex
p2
""" + \
"""(F3
F0
tRp3
aI1
aI-1
aI255
aI-255
aI-256
aI65535
aI-65535
aI-65536
aI2147483647
aI-2147483647
aI-2147483648
a""" + \
"""(S'abc'
p4
g4
""" + \
"""(i__main__
C
p5
""" + \
"""(dp6
S'foo'
p7
I1
sS'bar'
p8
I2
sbg5
tp9
ag9
aI5
a.
"""
# BINDATA is the same object as a binary-mode pickle (opcodes such as
# "]", "q", "K", "J", "M", "G", "U", "h").
BINDATA = ']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00' + \
'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00' + \
'\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff' + \
'\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff' + \
'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00' + \
'\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n' + \
'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh' + \
'\x06tq\nh\nK\x05e.'
def create_data():
c = C()
c.foo = 1
c.bar = 2
x = [0, 1L, 2.0, 3.0+0j]
# Append some integer test cases at cPickle.c's internal size
# cutoffs.
uint1max = 0xff
uint2max = 0xffff
int4max = 0x7fffffff
x.extend([1, -1,
uint1max, -uint1max, -uint1max-1,
uint2max, -uint2max, -uint2max-1,
int4max, -int4max, -int4max-1])
y = ('abc', 'abc', c, c)
x.append(y)
x.append(y)
x.append(5)
return x
class AbstractPickleTests(unittest.TestCase):
_testdata = create_data()
def setUp(self):
# subclass must define self.dumps, self.loads, self.error
pass
def test_misc(self):
# test various datatypes not tested by testdata
x = myint(4)
s = self.dumps(x)
y = self.loads(s)
self.assertEqual(x, y)
x = (1, ())
s = self.dumps(x)
y = self.loads(s)
self.assertEqual(x, y)
x = initarg(1, x)
s = self.dumps(x)
y = self.loads(s)
self.assertEqual(x, y)
# XXX test __reduce__ protocol?
def test_identity(self):
s = self.dumps(self._testdata)
x = self.loads(s)
self.assertEqual(x, self._testdata)
def test_constant(self):
x = self.loads(DATA)
self.assertEqual(x, self._testdata)
x = self.loads(BINDATA)
self.assertEqual(x, self._testdata)
def test_binary(self):
s = self.dumps(self._testdata, 1)
x = self.loads(s)
self.assertEqual(x, self._testdata)
def test_recursive_list(self):
l = []
l.append(l)
s = self.dumps(l)
x = self.loads(s)
self.assertEqual(x, l)
self.assertEqual(x, x[0])
self.assertEqual(id(x), id(x[0]))
def test_recursive_dict(self):
d = {}
d[1] = d
s = self.dumps(d)
x = self.loads(s)
self.assertEqual(x, d)
self.assertEqual(x[1], x)
self.assertEqual(id(x[1]), id(x))
def test_recursive_inst(self):
i = C()
i.attr = i
s = self.dumps(i)
x = self.loads(s)
self.assertEqual(x, i)
self.assertEqual(x.attr, x)
self.assertEqual(id(x.attr), id(x))
def test_recursive_multi(self):
l = []
d = {1:l}
i = C()
i.attr = d
l.append(i)
s = self.dumps(l)
x = self.loads(s)
self.assertEqual(x, l)
self.assertEqual(x[0], i)
self.assertEqual(x[0].attr, d)
self.assertEqual(x[0].attr[1], x)
self.assertEqual(x[0].attr[1][0], i)
self.assertEqual(x[0].attr[1][0].attr, d)
def test_garyp(self):
self.assertRaises(self.error, self.loads, 'garyp')
def test_insecure_strings(self):
insecure = ["abc", "2 + 2", # not quoted
#"'abc' + 'def'", # not a single quoted string
"'abc", # quote is not closed
"'abc\"", # open quote and close quote don't match
"'abc' ?", # junk after close quote
"'\\'", # trailing backslash
# some tests of the quoting rules
#"'abc\"\''",
#"'\\\\a\'\'\'\\\'\\\\\''",
]
for s in insecure:
buf = "S" + s + "\012p0\012."
self.assertRaises(ValueError, self.loads, buf)
if have_unicode:
def test_unicode(self):
endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'),
unicode('<\n>'), unicode('<\\>')]
for u in endcases:
p = self.dumps(u)
u2 = self.loads(p)
self.assertEqual(u2, u)
def test_ints(self):
import sys
n = sys.maxint
while n:
for expected in (-n, n):
s = self.dumps(expected)
n2 = self.loads(s)
self.assertEqual(expected, n2)
n = n >> 1
def test_maxint64(self):
maxint64 = (1L << 63) - 1
data = 'I' + str(maxint64) + '\n.'
got = self.loads(data)
self.assertEqual(got, maxint64)
# Try too with a bogus literal.
data = 'I' + str(maxint64) + 'JUNK\n.'
self.assertRaises(ValueError, self.loads, data)
def test_reduce(self):
pass
def test_getinitargs(self):
pass
def test_metaclass(self):
a = use_metaclass()
s = self.dumps(a)
b = self.loads(s)
self.assertEqual(a.__class__, b.__class__)
def test_structseq(self):
import time
t = time.localtime()
s = self.dumps(t)
u = self.loads(s)
self.assertEqual(t, u)
import os
if hasattr(os, "stat"):
t = os.stat(os.curdir)
s = self.dumps(t)
u = self.loads(s)
self.assertEqual(t, u)
if hasattr(os, "statvfs"):
t = os.statvfs(os.curdir)
s = self.dumps(t)
u = self.loads(s)
self.assertEqual(t, u)
# Tests for protocol 2
def test_long1(self):
x = 12345678910111213141516178920L
s = self.dumps(x, 2)
y = self.loads(s)
self.assertEqual(x, y)
def test_long4(self):
x = 12345678910111213141516178920L << (256*8)
s = self.dumps(x, 2)
y = self.loads(s)
self.assertEqual(x, y)
def test_short_tuples(self):
a = ()
b = (1,)
c = (1, 2)
d = (1, 2, 3)
e = (1, 2, 3, 4)
for proto in 0, 1, 2:
for x in a, b, c, d, e:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y, (proto, x, s, y))
def test_singletons(self):
for proto in 0, 1, 2:
for x in None, False, True:
s = self.dumps(x, proto)
y = self.loads(s)
self.assert_(x is y, (proto, x, s, y))
def test_newobj_tuple(self):
x = MyTuple([1, 2, 3])
x.foo = 42
x.bar = "hello"
s = self.dumps(x, 2)
y = self.loads(s)
self.assertEqual(tuple(x), tuple(y))
self.assertEqual(x.__dict__, y.__dict__)
## import pickletools
## print
## pickletools.dis(s)
def test_newobj_list(self):
x = MyList([1, 2, 3])
x.foo = 42
x.bar = "hello"
s = self.dumps(x, 2)
y = self.loads(s)
self.assertEqual(list(x), list(y))
self.assertEqual(x.__dict__, y.__dict__)
## import pickletools
## print
## pickletools.dis(s)
def test_newobj_generic(self):
for proto in [0, 1, 2]:
for C in myclasses:
B = C.__base__
x = C(C.sample)
x.foo = 42
s = self.dumps(x, proto)
## import pickletools
## print
## pickletools.dis(s)
y = self.loads(s)
detail = (proto, C, B, x, y, type(y))
self.assertEqual(B(x), B(y), detail)
self.assertEqual(x.__dict__, y.__dict__, detail)
# XXX Temporary hack, so long as the C implementation of pickle protocol
# XXX 2 isn't ready. When it is, move the methods in TempAbstractPickleTests
# XXX into AbstractPickleTests above, and get rid of TempAbstractPickleTests
# XXX along with the references to it in test_pickle.py.
class TempAbstractPickleTests(unittest.TestCase):
    """Protocol-2 tests not yet run against every pickler implementation.

    Subclasses must provide self.dumps(obj, proto) and self.loads(s).
    """

    def test_newobj_list_slots(self):
        x = SlotList([1, 2, 3])
        x.foo = 42       # stored in SlotList's __slots__ entry
        x.bar = "hello"  # stored in the __dict__ inherited from MyList
        s = self.dumps(x, 2)
        y = self.loads(s)
        self.assertEqual(list(x), list(y))
        self.assertEqual(x.__dict__, y.__dict__)
        self.assertEqual(x.foo, y.foo)
        self.assertEqual(x.bar, y.bar)

    def _check_global_ext(self, code, x, compare_proto1=False):
        """Round-trip the MyList x while MyList is registered under `code`.

        With protocol 2 the pickle must refer to MyList only through the
        extension code, so neither this module's name nor "MyList" may
        appear in it.  If compare_proto1 is true, first check that a
        protocol-1 pickle still names the class explicitly.  The
        registration is removed again even if an assertion fails.
        """
        import copy_reg
        copy_reg.add_extension(__name__, "MyList", code)
        try:
            if compare_proto1:
                s1 = self.dumps(x, 1)
                y = self.loads(s1)
                self.assertEqual(list(x), list(y))
                self.assertEqual(x.__dict__, y.__dict__)
                self.assert_(s1.find(__name__) >= 0)
                self.assert_(s1.find("MyList") >= 0)

            s2 = self.dumps(x, 2)
            self.assertEqual(s2.find(__name__), -1)
            self.assertEqual(s2.find("MyList"), -1)
            y = self.loads(s2)
            self.assertEqual(list(x), list(y))
            self.assertEqual(x.__dict__, y.__dict__)
        finally:
            copy_reg.remove_extension(__name__, "MyList", code)

    # The codes below fit in 1, 2 and more than 2 bytes respectively;
    # presumably chosen to exercise the different extension-code widths.

    def test_global_ext1(self):
        x = MyList([1, 2, 3])
        x.foo = 42
        x.bar = "hello"
        self._check_global_ext(0xf0, x, compare_proto1=True)

    def test_global_ext2(self):
        self._check_global_ext(0xfff0, MyList())

    def test_global_ext4(self):
        self._check_global_ext(0xfffff0, MyList())
class MyInt(int):
sample = 1
class MyLong(long):
sample = 1L
class MyFloat(float):
sample = 1.0
class MyComplex(complex):
sample = 1.0 + 0.0j
class MyStr(str):
sample = "hello"
class MyUnicode(unicode):
sample = u"hello \u1234"
class MyTuple(tuple):
sample = (1, 2, 3)
class MyList(list):
sample = [1, 2, 3]
class MyDict(dict):
sample = {"a": 1, "b": 2}
myclasses = [MyInt, MyLong, MyFloat,
# MyComplex, # XXX complex somehow doesn't work here :-(
MyStr, MyUnicode,
MyTuple, MyList, MyDict]
class SlotList(MyList):
__slots__ = ["foo"]
class AbstractPickleModuleTests(unittest.TestCase):
    """Tests of a pickle module's top-level dump()/load() functions.

    self.module (the module under test) must be provided, e.g. by a
    subclass.
    """

    def test_dump_closed_file(self):
        # dump() into a file that has already been closed must raise
        # ValueError.
        import os
        f = open(TESTFN, "w")
        try:
            f.close()
            self.assertRaises(ValueError, self.module.dump, 123, f)
        finally:
            os.remove(TESTFN)

    def test_load_closed_file(self):
        # load() from a file that has already been closed must raise
        # ValueError.  (This used to call dump() by mistake, duplicating
        # test_dump_closed_file and never exercising load().)
        import os
        f = open(TESTFN, "w")
        try:
            f.close()
            self.assertRaises(ValueError, self.module.load, f)
        finally:
            os.remove(TESTFN)
class AbstractPersistentPicklerTests(unittest.TestCase):
# This class defines persistent_id() and persistent_load()
# functions that should be used by the pickler. All even integers
# are pickled using persistent ids.
def persistent_id(self, object):
if isinstance(object, int) and object % 2 == 0:
self.id_count += 1
return str(object)
else:
return None
def persistent_load(self, oid):
self.load_count += 1
object = int(oid)
assert object % 2 == 0
return object
def test_persistence(self):
self.id_count = 0
self.load_count = 0
L = range(10)
self.assertEqual(self.loads(self.dumps(L)), L)
self.assertEqual(self.id_count, 5)
self.assertEqual(self.load_count, 5)
def test_bin_persistence(self):
self.id_count = 0
self.load_count = 0
L = range(10)
self.assertEqual(self.loads(self.dumps(L, 1)), L)
self.assertEqual(self.id_count, 5)
self.assertEqual(self.load_count, 5)