|
1 import functools |
|
2 import unittest |
|
3 from test import test_support |
|
4 from weakref import proxy |
|
5 |
|
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial()"""
    # Wrapped in staticmethod so that storing it as a TestCase class
    # attribute (``thetype``) and calling it via ``self.thetype(...)``
    # does not bind it as a method.
    def newfunc(*call_args, **call_keywords):
        # Copy the stored keywords so the partial's own dict is never
        # mutated; keywords given at call time take precedence.
        merged = dict(keywords)
        merged.update(call_keywords)
        all_args = args + call_args
        return func(*all_args, **merged)
    # Expose the same introspection attributes as functools.partial.
    newfunc.func, newfunc.args, newfunc.keywords = func, args, keywords
    return newfunc
|
17 |
|
def capture(*args, **kw):
    """Return the positional and keyword arguments exactly as received."""
    received = (args, kw)
    return received
|
21 |
|
class TestPartial(unittest.TestCase):
    """Behavioural tests for a partial() implementation.

    ``thetype`` is the implementation under test; subclasses of this
    TestCase rebind it to run the same checks against other versions.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        # list() keeps the comparison valid even if map() returns an iterator
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; only enforced when thetype is a
        # real type (a plain-function implementation has ordinary, writable
        # function attributes)
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(TypeError, setattr, p, 'func', map)
        self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
        self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_dict_not_deletable(self):
        # Distinct name so this check no longer shadows test_attributes.
        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        # materialised as a list because it is joined twice below
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')
|
142 |
|
class PartialSubclass(functools.partial):
    """Empty subclass of functools.partial, used as an alternate thetype."""
|
145 |
|
class TestPartialSubclass(TestPartial):
    """Run every TestPartial check against PartialSubclass."""
    thetype = PartialSubclass
|
149 |
|
150 |
|
class TestPythonPartial(TestPartial):
    """Run every TestPartial check against the pure-Python PythonPartial."""
    thetype = PythonPartial
|
154 |
|
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper.

    check_wrapper is also used by subclasses that exercise functools.wraps.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata of *wrapped*.

        Each attribute named in *assigned* must be the identical object on
        both functions.  For each mapping attribute named in *updated*, every
        key of the wrapped function's mapping must be present in the
        wrapper's mapping and bound to the identical object.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name),
                            '%r was not assigned to the wrapper' % (name,))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                # membership is checked first so a missing key is reported
                # as a test failure rather than surfacing as a KeyError
                self.assertTrue(key in wrapper_attr and
                                wrapped_attr[key] is wrapper_attr[key],
                                '%s[%r] was not updated on the wrapper'
                                % (name, key))

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
|
220 |
|
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks via the functools.wraps decorator.

    check_wrapper and test_builtin_update are inherited from
    TestUpdateWrapper.
    """

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # give the wrapper an empty mapping for wraps() to update
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
|
268 |
|
269 |
|
270 |
|
def test_main(verbose=None):
    """Run all functools test classes defined in this module.

    When *verbose* is true and the interpreter exposes
    sys.gettotalrefcount (presumably debug builds only -- confirm), the
    suite is rerun five more times, gc.collect() is called after each run,
    and the list of total reference counts taken after each run is printed.
    """
    import sys
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestWraps
    )
    test_support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        # Pre-sized list: each reading replaces a None slot instead of
        # growing the list, presumably so that storing one reading does not
        # add a reference that skews the next.
        counts = [None] * 5
        for i in xrange(len(counts)):
            test_support.run_unittest(*test_classes)
            gc.collect()
            counts[i] = sys.gettotalrefcount()
        # NOTE(review): a steadily growing sequence here hints at a leak.
        print counts
|
291 |
|
if __name__ == '__main__':
    # Run directly: verbose=True also enables test_main's reference-count
    # loop on interpreters that provide sys.gettotalrefcount.
    test_main(verbose=True)