# cpython/Lib/test/pickletester.py
# Amaury Forgeot d'Arc 3e4e72f66f #4298: pickle.load() can segfault on invalid or truncated input.
# Patch and test by Hirokazu Yamamoto.
# 2008-11-11 20:05:06 +00:00
#
# 1093 lines
# 34 KiB
# Python

import unittest
import pickle
import pickletools
import copyreg
from test.support import TestFailed, TESTFN, run_with_locale
from pickle import bytes_types
# Every pickle protocol number, from 0 up to and including
# pickle.HIGHEST_PROTOCOL.  Tests that exercise several protocols should
# wrap their body in an outer loop of the form:
#     for proto in protocols:
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
def opcode_in_pickle(code, pickle):
    """Return True if opcode *code* appears in *pickle*, else False.

    *code* is the opcode as a bytes object (it is decoded as latin-1 before
    being compared with the opcode names yielded by pickletools.genops);
    *pickle* is the pickled data to scan.
    """
    return any(op.code == code.decode("latin-1")
               for op, _arg, _pos in pickletools.genops(pickle))
def count_opcode(code, pickle):
    """Return the number of times opcode *code* appears in *pickle*.

    *code* is the opcode as a bytes object (decoded as latin-1 for the
    comparison); *pickle* is the pickled data walked with
    pickletools.genops.
    """
    return sum(1 for op, _arg, _pos in pickletools.genops(pickle)
               if op.code == code.decode("latin-1"))
class ExtensionSaver:
    """Temporarily take over one copyreg extension code.

    We can't very well test the extension registry without putting known
    stuff in it, but we have to be careful to restore its original state.
    Code should do this:

        e = ExtensionSaver(extension_code)
        try:
            fiddle w/ the extension registry's stuff for extension_code
        finally:
            e.restore()
    """

    def __init__(self, code):
        """Remember the current registration for *code* and remove it.

        self.pair holds the saved (module, name) key, or None when *code*
        was not registered.
        """
        self.code = code
        saved = copyreg._inverted_registry.get(code)
        self.pair = saved
        if saved is not None:
            module, name = saved
            copyreg.remove_extension(module, name, code)

    def restore(self):
        """Restore the registration that was in place for the code."""
        code = self.code
        # Drop whatever was registered under the code in the meantime.
        current = copyreg._inverted_registry.get(code)
        if current is not None:
            module, name = current
            copyreg.remove_extension(module, name, code)
        # Then put the original registration (if there was one) back.
        if self.pair is not None:
            module, name = self.pair
            copyreg.add_extension(module, name, code)
class C:
    """Plain attribute container; instances compare by their __dict__."""

    def __eq__(self, other):
        mine, theirs = self.__dict__, other.__dict__
        return mine == theirs


# The reference pickles name this class as "__main__ C", so publish it
# under that module name and make it importable from there.
C.__module__ = "__main__"
import __main__
__main__.C = C
class myint(int):
    """int subclass that also stores str() of its constructor argument."""

    def __init__(self, x):
        text = str(x)
        self.str = text
class initarg(C):
    """C subclass built from two values and exposing __getinitargs__."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __getinitargs__(self):
        # Same two values, in constructor order.
        return (self.a, self.b)
class metaclass(type):
    """Do-nothing metaclass used by test_metaclass."""


class use_metaclass(object, metaclass=metaclass):
    """Empty class whose type is ``metaclass`` rather than ``type``."""
# DATA0 .. DATA2 are the pickles we expect under the various protocols, for
# the object returned by create_data().
DATA0 = (
b'(lp0\nL0\naL1\naF2.0\nac'
b'builtins\ncomplex\n'
b'p1\n(F3.0\nF0.0\ntp2\nRp'
b'3\naL1\naL-1\naL255\naL-'
b'255\naL-256\naL65535\na'
b'L-65535\naL-65536\naL2'
b'147483647\naL-2147483'
b'647\naL-2147483648\na('
b'Vabc\np4\ng4\nccopyreg'
b'\n_reconstructor\np5\n('
b'c__main__\nC\np6\ncbu'
b'iltins\nobject\np7\nNt'
b'p8\nRp9\n(dp10\nVfoo\np1'
b'1\nL1\nsVbar\np12\nL2\nsb'
b'g9\ntp13\nag13\naL5\na.'
)
# Disassembly of DATA0
DATA0_DIS = """\
0: ( MARK
1: l LIST (MARK at 0)
2: p PUT 0
5: L LONG 0
8: a APPEND
9: L LONG 1
12: a APPEND
13: F FLOAT 2.0
18: a APPEND
19: c GLOBAL 'builtins complex'
37: p PUT 1
40: ( MARK
41: F FLOAT 3.0
46: F FLOAT 0.0
51: t TUPLE (MARK at 40)
52: p PUT 2
55: R REDUCE
56: p PUT 3
59: a APPEND
60: L LONG 1
63: a APPEND
64: L LONG -1
68: a APPEND
69: L LONG 255
74: a APPEND
75: L LONG -255
81: a APPEND
82: L LONG -256
88: a APPEND
89: L LONG 65535
96: a APPEND
97: L LONG -65535
105: a APPEND
106: L LONG -65536
114: a APPEND
115: L LONG 2147483647
127: a APPEND
128: L LONG -2147483647
141: a APPEND
142: L LONG -2147483648
155: a APPEND
156: ( MARK
157: V UNICODE 'abc'
162: p PUT 4
165: g GET 4
168: c GLOBAL 'copyreg _reconstructor'
192: p PUT 5
195: ( MARK
196: c GLOBAL '__main__ C'
208: p PUT 6
211: c GLOBAL 'builtins object'
228: p PUT 7
231: N NONE
232: t TUPLE (MARK at 195)
233: p PUT 8
236: R REDUCE
237: p PUT 9
240: ( MARK
241: d DICT (MARK at 240)
242: p PUT 10
246: V UNICODE 'foo'
251: p PUT 11
255: L LONG 1
258: s SETITEM
259: V UNICODE 'bar'
264: p PUT 12
268: L LONG 2
271: s SETITEM
272: b BUILD
273: g GET 9
276: t TUPLE (MARK at 156)
277: p PUT 13
281: a APPEND
282: g GET 13
286: a APPEND
287: L LONG 5
290: a APPEND
291: . STOP
highest protocol among opcodes = 0
"""
DATA1 = (
b']q\x00(K\x00K\x01G@\x00\x00\x00\x00\x00\x00\x00c'
b'builtins\ncomplex\nq\x01'
b'(G@\x08\x00\x00\x00\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00\x00t'
b'q\x02Rq\x03K\x01J\xff\xff\xff\xffK\xffJ\x01\xff\xff\xffJ'
b'\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff'
b'\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00\x80(X\x03\x00\x00\x00ab'
b'cq\x04h\x04ccopyreg\n_reco'
b'nstructor\nq\x05(c__main'
b'__\nC\nq\x06cbuiltins\n'
b'object\nq\x07Ntq\x08Rq\t}q\n('
b'X\x03\x00\x00\x00fooq\x0bK\x01X\x03\x00\x00\x00bar'
b'q\x0cK\x02ubh\ttq\rh\rK\x05e.'
)
# Disassembly of DATA1
DATA1_DIS = """\
0: ] EMPTY_LIST
1: q BINPUT 0
3: ( MARK
4: K BININT1 0
6: K BININT1 1
8: G BINFLOAT 2.0
17: c GLOBAL 'builtins complex'
35: q BINPUT 1
37: ( MARK
38: G BINFLOAT 3.0
47: G BINFLOAT 0.0
56: t TUPLE (MARK at 37)
57: q BINPUT 2
59: R REDUCE
60: q BINPUT 3
62: K BININT1 1
64: J BININT -1
69: K BININT1 255
71: J BININT -255
76: J BININT -256
81: M BININT2 65535
84: J BININT -65535
89: J BININT -65536
94: J BININT 2147483647
99: J BININT -2147483647
104: J BININT -2147483648
109: ( MARK
110: X BINUNICODE 'abc'
118: q BINPUT 4
120: h BINGET 4
122: c GLOBAL 'copyreg _reconstructor'
146: q BINPUT 5
148: ( MARK
149: c GLOBAL '__main__ C'
161: q BINPUT 6
163: c GLOBAL 'builtins object'
180: q BINPUT 7
182: N NONE
183: t TUPLE (MARK at 148)
184: q BINPUT 8
186: R REDUCE
187: q BINPUT 9
189: } EMPTY_DICT
190: q BINPUT 10
192: ( MARK
193: X BINUNICODE 'foo'
201: q BINPUT 11
203: K BININT1 1
205: X BINUNICODE 'bar'
213: q BINPUT 12
215: K BININT1 2
217: u SETITEMS (MARK at 192)
218: b BUILD
219: h BINGET 9
221: t TUPLE (MARK at 109)
222: q BINPUT 13
224: h BINGET 13
226: K BININT1 5
228: e APPENDS (MARK at 3)
229: . STOP
highest protocol among opcodes = 1
"""
DATA2 = (
b'\x80\x02]q\x00(K\x00K\x01G@\x00\x00\x00\x00\x00\x00\x00c'
b'builtins\ncomplex\n'
b'q\x01G@\x08\x00\x00\x00\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x86q\x02Rq\x03K\x01J\xff\xff\xff\xffK\xffJ\x01\xff\xff\xff'
b'J\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff'
b'\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00\x80(X\x03\x00\x00\x00a'
b'bcq\x04h\x04c__main__\nC\nq\x05'
b')\x81q\x06}q\x07(X\x03\x00\x00\x00fooq\x08K\x01'
b'X\x03\x00\x00\x00barq\tK\x02ubh\x06tq\nh'
b'\nK\x05e.'
)
# Disassembly of DATA2
DATA2_DIS = """\
0: \x80 PROTO 2
2: ] EMPTY_LIST
3: q BINPUT 0
5: ( MARK
6: K BININT1 0
8: K BININT1 1
10: G BINFLOAT 2.0
19: c GLOBAL 'builtins complex'
37: q BINPUT 1
39: G BINFLOAT 3.0
48: G BINFLOAT 0.0
57: \x86 TUPLE2
58: q BINPUT 2
60: R REDUCE
61: q BINPUT 3
63: K BININT1 1
65: J BININT -1
70: K BININT1 255
72: J BININT -255
77: J BININT -256
82: M BININT2 65535
85: J BININT -65535
90: J BININT -65536
95: J BININT 2147483647
100: J BININT -2147483647
105: J BININT -2147483648
110: ( MARK
111: X BINUNICODE 'abc'
119: q BINPUT 4
121: h BINGET 4
123: c GLOBAL '__main__ C'
135: q BINPUT 5
137: ) EMPTY_TUPLE
138: \x81 NEWOBJ
139: q BINPUT 6
141: } EMPTY_DICT
142: q BINPUT 7
144: ( MARK
145: X BINUNICODE 'foo'
153: q BINPUT 8
155: K BININT1 1
157: X BINUNICODE 'bar'
165: q BINPUT 9
167: K BININT1 2
169: u SETITEMS (MARK at 144)
170: b BUILD
171: h BINGET 6
173: t TUPLE (MARK at 110)
174: q BINPUT 10
176: h BINGET 10
178: K BININT1 5
180: e APPENDS (MARK at 5)
181: . STOP
highest protocol among opcodes = 2
"""
def create_data():
    """Build the reference object that DATA0 .. DATA2 are pickles of.

    Returns a list holding a few ints, floats and a complex number, integers
    at cPickle.c's internal size cutoffs, the same tuple twice, and a final 5.
    """
    inst = C()
    inst.foo = 1
    inst.bar = 2

    # Integer test cases at cPickle.c's internal size cutoffs.
    uint1max = 0xff
    uint2max = 0xffff
    int4max = 0x7fffffff
    boundary_ints = [1, -1]
    for limit in (uint1max, uint2max, int4max):
        boundary_ints += [limit, -limit, -limit - 1]

    # One tuple object, listed twice, that itself repeats one string and one
    # instance: the result contains shared references, not just equal values.
    shared = ('abc', 'abc', inst, inst)
    return [0, 1, 2.0, 3.0+0j] + boundary_ints + [shared, shared, 5]
class AbstractPickleTests(unittest.TestCase):
    """Round-trip and opcode-level tests shared by pickle implementations.

    Subclass must define self.dumps, self.loads.  Every test goes through
    those two hooks, and tests that depend on the protocol pass the protocol
    number to self.dumps explicitly.
    """

    _testdata = create_data()

    def setUp(self):
        pass

    def test_misc(self):
        # test various datatypes not tested by testdata
        for proto in protocols:
            x = myint(4)
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)

            x = (1, ())
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)

            x = initarg(1, x)
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)

        # XXX test __reduce__ protocol?

    def test_roundtrip_equality(self):
        expected = self._testdata
        for proto in protocols:
            s = self.dumps(expected, proto)
            got = self.loads(s)
            self.assertEqual(expected, got)

    def test_load_from_data0(self):
        self.assertEqual(self._testdata, self.loads(DATA0))

    def test_load_from_data1(self):
        self.assertEqual(self._testdata, self.loads(DATA1))

    def test_load_from_data2(self):
        self.assertEqual(self._testdata, self.loads(DATA2))

    # There are gratuitous differences between pickles produced by
    # pickle and cPickle, largely because cPickle starts PUT indices at
    # 1 and pickle starts them at 0.  See XXX comment in cPickle's put2() --
    # there's a comment with an exclamation point there whose meaning
    # is a mystery.  cPickle also suppresses PUT for objects with a refcount
    # of 1.
    def dont_test_disassembly(self):
        # Not collected by unittest: the name does not start with "test".
        from io import StringIO
        from pickletools import dis

        for proto, expected in (0, DATA0_DIS), (1, DATA1_DIS):
            s = self.dumps(self._testdata, proto)
            filelike = StringIO()
            dis(s, out=filelike)
            got = filelike.getvalue()
            self.assertEqual(expected, got)

    def test_recursive_list(self):
        l = []
        l.append(l)
        for proto in protocols:
            s = self.dumps(l, proto)
            x = self.loads(s)
            self.assertEqual(len(x), 1)
            self.assertTrue(x is x[0])

    def test_recursive_dict(self):
        d = {}
        d[1] = d
        for proto in protocols:
            s = self.dumps(d, proto)
            x = self.loads(s)
            self.assertEqual(list(x.keys()), [1])
            self.assertTrue(x[1] is x)

    def test_recursive_inst(self):
        i = C()
        i.attr = i
        for proto in protocols:
            # Dump with the loop's protocol so each one is really exercised.
            s = self.dumps(i, proto)
            x = self.loads(s)
            self.assertEqual(dir(x), dir(i))
            self.assertTrue(x.attr is x)

    def test_recursive_multi(self):
        l = []
        d = {1: l}
        i = C()
        i.attr = d
        l.append(i)
        for proto in protocols:
            s = self.dumps(l, proto)
            x = self.loads(s)
            self.assertEqual(len(x), 1)
            self.assertEqual(dir(x[0]), dir(i))
            self.assertEqual(list(x[0].attr.keys()), [1])
            self.assertTrue(x[0].attr[1] is x)

    def test_get(self):
        self.assertRaises(KeyError, self.loads, b'g0\np0')
        self.assertEqual(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)])

    def test_insecure_strings(self):
        # XXX Some of these tests are temporarily disabled
        insecure = [b"abc", b"2 + 2",  # not quoted
                    ## b"'abc' + 'def'", # not a single quoted string
                    b"'abc",  # quote is not closed
                    b"'abc\"",  # open quote and close quote don't match
                    b"'abc' ?",  # junk after close quote
                    b"'\\'",  # trailing backslash
                    # some tests of the quoting rules
                    ## b"'abc\"\''",
                    ## b"'\\\\a\'\'\'\\\'\\\\\''",
                    ]
        for b in insecure:
            buf = b"S" + b + b"\012p0\012."
            self.assertRaises(ValueError, self.loads, buf)

    def test_unicode(self):
        endcases = ['', '<\\u>', '<\\\u1234>', '<\n>', '<\\>']
        for proto in protocols:
            for u in endcases:
                p = self.dumps(u, proto)
                u2 = self.loads(p)
                self.assertEqual(u2, u)

    def test_bytes(self):
        for proto in protocols:
            for u in b'', b'xyz', b'xyz'*100:
                # Pass the protocol through; otherwise the outer loop would
                # repeat the same default-protocol dump every time.
                p = self.dumps(u, proto)
                self.assertEqual(self.loads(p), u)

    def test_ints(self):
        import sys
        for proto in protocols:
            n = sys.maxsize
            while n:
                for expected in (-n, n):
                    s = self.dumps(expected, proto)
                    n2 = self.loads(s)
                    self.assertEqual(expected, n2)
                n = n >> 1

    def test_maxint64(self):
        maxint64 = (1 << 63) - 1
        data = b'I' + str(maxint64).encode("ascii") + b'\n.'
        got = self.loads(data)
        self.assertEqual(got, maxint64)

        # Try too with a bogus literal.
        data = b'I' + str(maxint64).encode("ascii") + b'JUNK\n.'
        self.assertRaises(ValueError, self.loads, data)

    def test_long(self):
        for proto in protocols:
            # 256 bytes is where LONG4 begins.
            for nbits in 1, 8, 8*254, 8*255, 8*256, 8*257:
                nbase = 1 << nbits
                for npos in nbase-1, nbase, nbase+1:
                    for n in npos, -npos:
                        # Named "s" so the pickle module is not shadowed.
                        s = self.dumps(n, proto)
                        got = self.loads(s)
                        self.assertEqual(n, got)
        # Try a monster.  This is quadratic-time in protos 0 & 1, so don't
        # bother with those.
        nbase = int("deadbeeffeedface", 16)
        nbase += nbase << 1000000
        for n in nbase, -nbase:
            p = self.dumps(n, 2)
            got = self.loads(p)
            self.assertEqual(n, got)

    @run_with_locale('LC_ALL', 'de_DE', 'fr_FR')
    def test_float_format(self):
        # make sure that floats are formatted locale independent with proto 0
        self.assertEqual(self.dumps(1.2, 0)[0:3], b'F1.')

    def test_reduce(self):
        pass

    def test_getinitargs(self):
        pass

    def test_metaclass(self):
        a = use_metaclass()
        for proto in protocols:
            s = self.dumps(a, proto)
            b = self.loads(s)
            self.assertEqual(a.__class__, b.__class__)

    def test_structseq(self):
        import time
        import os

        for proto in protocols:
            # t is rebound below, so take a fresh localtime() on every pass;
            # that way each struct sequence type meets every protocol.
            t = time.localtime()
            s = self.dumps(t, proto)
            u = self.loads(s)
            self.assertEqual(t, u)
            if hasattr(os, "stat"):
                t = os.stat(os.curdir)
                s = self.dumps(t, proto)
                u = self.loads(s)
                self.assertEqual(t, u)
            if hasattr(os, "statvfs"):
                t = os.statvfs(os.curdir)
                s = self.dumps(t, proto)
                u = self.loads(s)
                self.assertEqual(t, u)

    # Tests for protocol 2

    def test_proto(self):
        build_none = pickle.NONE + pickle.STOP
        for proto in protocols:
            expected = build_none
            if proto >= 2:
                expected = pickle.PROTO + bytes([proto]) + expected
            p = self.dumps(None, proto)
            self.assertEqual(p, expected)

        oob = protocols[-1] + 1     # a future protocol
        badpickle = pickle.PROTO + bytes([oob]) + build_none
        try:
            self.loads(badpickle)
        except ValueError as detail:
            self.assertTrue(str(detail).startswith(
                "unsupported pickle protocol"))
        else:
            self.fail("expected bad protocol number to raise ValueError")

    def test_long1(self):
        x = 12345678910111213141516178920
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2)

    def test_long4(self):
        x = 12345678910111213141516178920 << (256*8)
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2)

    def test_short_tuples(self):
        # Map (proto, len(tuple)) to expected opcode.
        expected_opcode = {(0, 0): pickle.TUPLE,
                           (0, 1): pickle.TUPLE,
                           (0, 2): pickle.TUPLE,
                           (0, 3): pickle.TUPLE,
                           (0, 4): pickle.TUPLE,

                           (1, 0): pickle.EMPTY_TUPLE,
                           (1, 1): pickle.TUPLE,
                           (1, 2): pickle.TUPLE,
                           (1, 3): pickle.TUPLE,
                           (1, 4): pickle.TUPLE,

                           (2, 0): pickle.EMPTY_TUPLE,
                           (2, 1): pickle.TUPLE1,
                           (2, 2): pickle.TUPLE2,
                           (2, 3): pickle.TUPLE3,
                           (2, 4): pickle.TUPLE,

                           (3, 0): pickle.EMPTY_TUPLE,
                           (3, 1): pickle.TUPLE1,
                           (3, 2): pickle.TUPLE2,
                           (3, 3): pickle.TUPLE3,
                           (3, 4): pickle.TUPLE,
                           }
        a = ()
        b = (1,)
        c = (1, 2)
        d = (1, 2, 3)
        e = (1, 2, 3, 4)
        for proto in protocols:
            for x in a, b, c, d, e:
                s = self.dumps(x, proto)
                y = self.loads(s)
                self.assertEqual(x, y, (proto, x, s, y))
                expected = expected_opcode[proto, len(x)]
                self.assertEqual(opcode_in_pickle(expected, s), True)

    def test_singletons(self):
        # Map (proto, singleton) to expected opcode.
        expected_opcode = {(0, None): pickle.NONE,
                           (1, None): pickle.NONE,
                           (2, None): pickle.NONE,
                           (3, None): pickle.NONE,

                           (0, True): pickle.INT,
                           (1, True): pickle.INT,
                           (2, True): pickle.NEWTRUE,
                           (3, True): pickle.NEWTRUE,

                           (0, False): pickle.INT,
                           (1, False): pickle.INT,
                           (2, False): pickle.NEWFALSE,
                           (3, False): pickle.NEWFALSE,
                           }
        for proto in protocols:
            for x in None, False, True:
                s = self.dumps(x, proto)
                y = self.loads(s)
                self.assertTrue(x is y, (proto, x, s, y))
                expected = expected_opcode[proto, x]
                self.assertEqual(opcode_in_pickle(expected, s), True)

    def test_newobj_tuple(self):
        x = MyTuple([1, 2, 3])
        x.foo = 42
        x.bar = "hello"
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(tuple(x), tuple(y))
            self.assertEqual(x.__dict__, y.__dict__)

    def test_newobj_list(self):
        x = MyList([1, 2, 3])
        x.foo = 42
        x.bar = "hello"
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(list(x), list(y))
            self.assertEqual(x.__dict__, y.__dict__)

    def test_newobj_generic(self):
        for proto in protocols:
            for C in myclasses:
                B = C.__base__
                x = C(C.sample)
                x.foo = 42
                s = self.dumps(x, proto)
                y = self.loads(s)
                detail = (proto, C, B, x, y, type(y))
                self.assertEqual(B(x), B(y), detail)
                self.assertEqual(x.__dict__, y.__dict__, detail)

    # Register a type with copyreg, with extension code extcode.  Pickle
    # an object of that type.  Check that the resulting pickle uses opcode
    # (EXT[124]) under proto 2, and not in proto 1.

    def produce_global_ext(self, extcode, opcode):
        e = ExtensionSaver(extcode)
        try:
            copyreg.add_extension(__name__, "MyList", extcode)
            x = MyList([1, 2, 3])
            x.foo = 42
            x.bar = "hello"

            # Dump using protocol 1 for comparison.
            s1 = self.dumps(x, 1)
            self.assertTrue(__name__.encode("utf-8") in s1)
            self.assertTrue(b"MyList" in s1)
            self.assertEqual(opcode_in_pickle(opcode, s1), False)

            y = self.loads(s1)
            self.assertEqual(list(x), list(y))
            self.assertEqual(x.__dict__, y.__dict__)

            # Dump using protocol 2 for test.
            s2 = self.dumps(x, 2)
            self.assertTrue(__name__.encode("utf-8") not in s2)
            self.assertTrue(b"MyList" not in s2)
            self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2))

            y = self.loads(s2)
            self.assertEqual(list(x), list(y))
            self.assertEqual(x.__dict__, y.__dict__)

        finally:
            e.restore()

    def test_global_ext1(self):
        self.produce_global_ext(0x00000001, pickle.EXT1)  # smallest EXT1 code
        self.produce_global_ext(0x000000ff, pickle.EXT1)  # largest EXT1 code

    def test_global_ext2(self):
        self.produce_global_ext(0x00000100, pickle.EXT2)  # smallest EXT2 code
        self.produce_global_ext(0x0000ffff, pickle.EXT2)  # largest EXT2 code
        self.produce_global_ext(0x0000abcd, pickle.EXT2)  # check endianness

    def test_global_ext4(self):
        self.produce_global_ext(0x00010000, pickle.EXT4)  # smallest EXT4 code
        self.produce_global_ext(0x7fffffff, pickle.EXT4)  # largest EXT4 code
        self.produce_global_ext(0x12abcdef, pickle.EXT4)  # check endianness

    def test_list_chunking(self):
        n = 10  # too small to chunk
        x = list(range(n))
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            num_appends = count_opcode(pickle.APPENDS, s)
            self.assertEqual(num_appends, proto > 0)

        n = 2500  # expect at least two chunks when proto > 0
        x = list(range(n))
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            num_appends = count_opcode(pickle.APPENDS, s)
            if proto == 0:
                self.assertEqual(num_appends, 0)
            else:
                self.assertTrue(num_appends >= 2)

    def test_dict_chunking(self):
        n = 10  # too small to chunk
        x = dict.fromkeys(range(n))
        for proto in protocols:
            s = self.dumps(x, proto)
            # A unittest assertion, so the check survives "python -O".
            self.assertTrue(isinstance(s, bytes_types))
            y = self.loads(s)
            self.assertEqual(x, y)
            num_setitems = count_opcode(pickle.SETITEMS, s)
            self.assertEqual(num_setitems, proto > 0)

        n = 2500  # expect at least two chunks when proto > 0
        x = dict.fromkeys(range(n))
        for proto in protocols:
            s = self.dumps(x, proto)
            y = self.loads(s)
            self.assertEqual(x, y)
            num_setitems = count_opcode(pickle.SETITEMS, s)
            if proto == 0:
                self.assertEqual(num_setitems, 0)
            else:
                self.assertTrue(num_setitems >= 2)

    def test_simple_newobj(self):
        x = object.__new__(SimpleNewObj)  # avoid __init__
        x.abc = 666
        for proto in protocols:
            s = self.dumps(x, proto)
            self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2)
            y = self.loads(s)   # will raise TypeError if __init__ called
            self.assertEqual(y.abc, 666)
            self.assertEqual(x.__dict__, y.__dict__)

    def test_newobj_list_slots(self):
        x = SlotList([1, 2, 3])
        x.foo = 42
        x.bar = "hello"
        s = self.dumps(x, 2)
        y = self.loads(s)
        self.assertEqual(list(x), list(y))
        self.assertEqual(x.__dict__, y.__dict__)
        self.assertEqual(x.foo, y.foo)
        self.assertEqual(x.bar, y.bar)

    def test_reduce_overrides_default_reduce_ex(self):
        for proto in 0, 1, 2:
            x = REX_one()
            self.assertEqual(x._reduce_called, 0)
            s = self.dumps(x, proto)
            self.assertEqual(x._reduce_called, 1)
            y = self.loads(s)
            self.assertEqual(y._reduce_called, 0)

    def test_reduce_ex_called(self):
        for proto in 0, 1, 2:
            x = REX_two()
            self.assertEqual(x._proto, None)
            s = self.dumps(x, proto)
            self.assertEqual(x._proto, proto)
            y = self.loads(s)
            self.assertEqual(y._proto, None)

    def test_reduce_ex_overrides_reduce(self):
        for proto in 0, 1, 2:
            x = REX_three()
            self.assertEqual(x._proto, None)
            s = self.dumps(x, proto)
            self.assertEqual(x._proto, proto)
            y = self.loads(s)
            self.assertEqual(y._proto, None)

    def test_reduce_ex_calls_base(self):
        for proto in 0, 1, 2:
            x = REX_four()
            self.assertEqual(x._proto, None)
            s = self.dumps(x, proto)
            self.assertEqual(x._proto, proto)
            y = self.loads(s)
            self.assertEqual(y._proto, proto)

    def test_reduce_calls_base(self):
        for proto in 0, 1, 2:
            x = REX_five()
            self.assertEqual(x._reduce_called, 0)
            s = self.dumps(x, proto)
            self.assertEqual(x._reduce_called, 1)
            y = self.loads(s)
            self.assertEqual(y._reduce_called, 1)

    def test_bad_getattr(self):
        x = BadGetattr()
        for proto in 0, 1:
            self.assertRaises(RuntimeError, self.dumps, x, proto)
        # protocol 2 don't raise a RuntimeError.
        d = self.dumps(x, 2)
        self.assertRaises(RuntimeError, self.loads, d)

    def test_reduce_bad_iterator(self):
        # Issue4176: crash when 4th and 5th items of __reduce__()
        # are not iterators
        class C(object):
            def __reduce__(self):
                # 4th item is not an iterator
                return list, (), None, [], None

        class D(object):
            def __reduce__(self):
                # 5th item is not an iterator
                return dict, (), None, None, []

        # Protocol 0 is less strict and also accept iterables.
        # Either outcome is fine here: success or a PickleError.  The test
        # only guards against the interpreter crashing.
        for proto in protocols:
            try:
                self.dumps(C(), proto)
            except (pickle.PickleError):
                pass
            try:
                self.dumps(D(), proto)
            except (pickle.PickleError):
                pass
# Test classes for reduce_ex

class REX_one(object):
    """Defines __reduce__ only.

    No __reduce_ex__ here, but inheriting it from object.
    """
    _reduce_called = 0

    def __reduce__(self):
        # Flag the call on the instance; the reduction rebuilds REX_one().
        self._reduce_called = 1
        return (REX_one, ())
class REX_two(object):
    """Defines __reduce_ex__ only.

    No __reduce__ here, but inheriting it from object.
    """
    _proto = None

    def __reduce_ex__(self, proto):
        # Record which protocol was requested; rebuild as REX_two().
        self._proto = proto
        return (REX_two, ())
class REX_three(object):
    """Defines both hooks; __reduce__ raises TestFailed if it is called."""
    _proto = None

    def __reduce_ex__(self, proto):
        self._proto = proto
        # NOTE(review): this rebuilds as REX_two, not REX_three -- looks like
        # a copy/paste leftover; confirm before relying on the loaded type.
        return (REX_two, ())

    def __reduce__(self):
        message = "This __reduce__ shouldn't be called"
        raise TestFailed(message)
class REX_four(object):
    """__reduce_ex__ records the protocol, then defers to object's version.

    Calling base class method should succeed.
    """
    _proto = None

    def __reduce_ex__(self, proto):
        self._proto = proto
        reduced = object.__reduce_ex__(self, proto)
        return reduced
class REX_five(object):
    """__reduce__ flags the call, then defers to object.__reduce__.

    This one used to fail with infinite recursion.
    """
    _reduce_called = 0

    def __reduce__(self):
        self._reduce_called = 1
        reduced = object.__reduce__(self)
        return reduced
# Test classes for newobj
#
# Each class subclasses one builtin type and carries a ``sample`` value that
# its constructor accepts.

class MyInt(int):
    sample = 1


class MyLong(int):
    sample = 1


class MyFloat(float):
    sample = 1.0


class MyComplex(complex):
    sample = 1.0 + 0.0j


class MyStr(str):
    sample = "hello"


class MyUnicode(str):
    sample = "hello \u1234"


class MyTuple(tuple):
    sample = (1, 2, 3)


class MyList(list):
    sample = [1, 2, 3]


class MyDict(dict):
    sample = {"a": 1, "b": 2}


# All of the sample classes above, in definition order.
myclasses = [
    MyInt,
    MyLong,
    MyFloat,
    MyComplex,
    MyStr,
    MyUnicode,
    MyTuple,
    MyList,
    MyDict,
]


class SlotList(MyList):
    """MyList subclass that declares a single slot named "foo"."""
    __slots__ = ["foo"]
class SimpleNewObj(object):
    """Class whose __init__ always raises TypeError."""

    def __init__(self, a, b, c):
        # raise an error, to make sure this isn't called
        message = "SimpleNewObj.__init__() didn't expect to get called"
        raise TypeError(message)
class BadGetattr:
    """Class whose __getattr__ looks up a missing attribute on itself."""

    def __getattr__(self, key):
        # "foo" is never set, so this lookup re-enters __getattr__ and the
        # recursion only ends when the interpreter raises.
        self.foo
class AbstractPickleModuleTests(unittest.TestCase):
    """Tests of the pickle module's top-level API (dump, load, dumps, ...)."""

    def test_dump_closed_file(self):
        # Dumping to an already closed file must raise ValueError.
        import os
        f = open(TESTFN, "wb")
        try:
            f.close()
            self.assertRaises(ValueError, pickle.dump, 123, f)
        finally:
            os.remove(TESTFN)

    def test_load_closed_file(self):
        # Loading from an already closed file must raise ValueError.
        import os
        f = open(TESTFN, "wb")
        try:
            # Store a valid pickle first, so the closed state of the file is
            # the only thing that can make pickle.load() fail.
            pickle.dump(123, f)
            f.close()
            f = open(TESTFN, "rb")
            f.close()
            self.assertRaises(ValueError, pickle.load, f)
        finally:
            # Closing twice is harmless; the file must be closed before it
            # is removed.
            f.close()
            os.remove(TESTFN)

    def test_highest_protocol(self):
        # Of course this needs to be changed when HIGHEST_PROTOCOL changes.
        self.assertEqual(pickle.HIGHEST_PROTOCOL, 3)

    def test_callapi(self):
        from io import BytesIO
        f = BytesIO()
        # With and without keyword arguments
        pickle.dump(123, f, -1)
        pickle.dump(123, file=f, protocol=-1)
        pickle.dumps(123, -1)
        pickle.dumps(123, protocol=-1)
        pickle.Pickler(f, -1)
        pickle.Pickler(f, protocol=-1)

    def test_bad_init(self):
        # Test issue3664 (pickle can segfault from a badly initialized Pickler).
        from io import BytesIO
        # Override initialization without calling __init__() of the superclass.
        class BadPickler(pickle.Pickler):
            def __init__(self): pass

        class BadUnpickler(pickle.Unpickler):
            def __init__(self): pass

        self.assertRaises(pickle.PicklingError, BadPickler().dump, 0)
        self.assertRaises(pickle.UnpicklingError, BadUnpickler().load)

    def test_bad_input(self):
        # Test issue4298
        s = bytes([0x58, 0, 0, 0, 0x54])
        self.assertRaises(EOFError, pickle.loads, s)
class AbstractPersistentPicklerTests(unittest.TestCase):
    """Tests for pickling through persistent ids.

    This class defines persistent_id() and persistent_load()
    functions that should be used by the pickler.  All even integers
    are pickled using persistent ids.
    """

    def persistent_id(self, object):
        # Anything that is not an even integer gets no persistent id.
        if not isinstance(object, int) or object % 2 != 0:
            return None
        self.id_count += 1
        return str(object)

    def persistent_load(self, oid):
        self.load_count += 1
        value = int(oid)
        assert value % 2 == 0
        return value

    def _check_ten_ints(self, *dump_args):
        # Round-trip list(range(10)), passing dump_args on to self.dumps.
        # The list holds five even integers, so each hook must run five times.
        self.id_count = 0
        self.load_count = 0
        L = list(range(10))
        self.assertEqual(self.loads(self.dumps(L, *dump_args)), L)
        self.assertEqual(self.id_count, 5)
        self.assertEqual(self.load_count, 5)

    def test_persistence(self):
        self._check_ten_ints()

    def test_bin_persistence(self):
        self._check_ten_ints(1)
if __name__ == "__main__":
    # Print some stuff that can be used to rewrite DATA{0,1,2}
    from pickletools import dis

    sample = create_data()
    for proto in range(3):
        pickled = pickle.dumps(sample, proto)

        # The pickle itself, as a parenthesised run of 20-byte literals.
        print("DATA{0} = (".format(proto))
        start = 0
        while start < len(pickled):
            chunk = bytes(pickled[start:start+20])
            print(" {0!r}".format(chunk))
            start += 20
        print(")")
        print()

        # Its disassembly, wrapped in a triple-quoted string.
        print("# Disassembly of DATA{0}".format(proto))
        print("DATA{0}_DIS = \"\"\"\\".format(proto))
        dis(pickled)
        print("\"\"\"")
        print()