|
1 """Unit tests for contextlib.py, and other context managers.""" |
|
2 |
|
3 |
|
4 import sys |
|
5 import os |
|
6 import decimal |
|
7 import tempfile |
|
8 import unittest |
|
9 import threading |
|
10 from contextlib import * # Tests __all__ |
|
11 from test import test_support |
|
12 |
|
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # "as", and code after the yield runs on a normal exit.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A finally clause around the yield must run when the with body
        # raises, and the exception must still propagate to the caller.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        try:
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        except ZeroDivisionError:
            pass
        else:
            self.fail("Expected ZeroDivisionError")
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception; it must return
        # a false value so that the with statement re-raises the original.
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields a second time after the exception has been
        # thrown into it is misbehaving; __exit__ must report RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                # Deliberately bare: trap whatever is thrown in, then
                # (incorrectly) yield again.
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # The generator may catch the exception raised in the with body; when
        # it does, the exception is suppressed and execution continues after
        # the with statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_attribs(self):
        # The factory returned by @contextmanager must keep the wrapped
        # function's name, docstring and custom attributes.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')
        self.assertEqual(baz.__doc__, "Whee!")
|
102 |
|
class NestedTestCase(unittest.TestCase):
    """Tests for nested(), which combines several context managers into one."""

    # XXX This needs more work

    def test_nested(self):
        # The values yielded by each manager are delivered as a tuple, in the
        # order the managers were passed.
        @contextmanager
        def a():
            yield 1
        @contextmanager
        def b():
            yield 2
        @contextmanager
        def c():
            yield 3
        with nested(a(), b(), c()) as (x, y, z):
            self.assertEqual(x, 1)
            self.assertEqual(y, 2)
            self.assertEqual(z, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and cleaned up right to left
        # when the body raises; the expected state list records that order.
        state = []
        @contextmanager
        def a():
            state.append(1)
            try:
                yield 2
            finally:
                state.append(3)
        @contextmanager
        def b():
            state.append(4)
            try:
                yield 5
            finally:
                state.append(6)
        try:
            with nested(a(), b()) as (x, y):
                state.append(x)
                state.append(y)
                1/0
        except ZeroDivisionError:
            self.assertEqual(state, [1, 4, 2, 5, 6, 3])
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_right_exception(self):
        # b.__exit__ raises and handles an unrelated exception internally;
        # the ZeroDivisionError from the body must still be the one that
        # reaches the caller.
        @contextmanager
        def a():
            yield 1
        class b(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:
                    pass
        try:
            with nested(a(), b()) as (x, y):
                1/0
        except ZeroDivisionError:
            self.assertEqual((x, y), (1, 2))
        except Exception:
            self.fail("Reraised wrong exception")
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_b_swallows(self):
        # If an inner manager suppresses the exception, nothing may leak out
        # of the with statement.
        @contextmanager
        def a():
            yield
        @contextmanager
        def b():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(a(), b()):
                1/0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # break inside the with body leaves the loop immediately; the
        # statement after the with block must not run.
        @contextmanager
        def a():
            yield
        state = 0
        while True:
            state += 1
            with nested(a(), a()):
                break
            state += 10
        self.assertEqual(state, 1)

    def test_nested_continue(self):
        # continue inside the with body skips the rest of each iteration, so
        # state only ever grows by one per pass.
        @contextmanager
        def a():
            yield
        state = 0
        while state < 3:
            state += 1
            with nested(a(), a()):
                continue
            state += 10
        self.assertEqual(state, 3)

    def test_nested_return(self):
        # return inside the with body must hand back its value even though
        # the managers here use a catch-all except around their yield.
        @contextmanager
        def a():
            try:
                yield
            except:
                pass
        def foo():
            with nested(a(), a()):
                return 1
            return 10
        self.assertEqual(foo(), 1)
|
224 |
|
class ClosingTestCase(unittest.TestCase):
    """Tests for closing(), which calls an object's close() on block exit."""

    # XXX This needs more work

    def test_closing(self):
        # close() must be called exactly once when the block exits normally,
        # and the object itself is what "as" binds.
        close_calls = []
        class Closable:
            def close(self):
                close_calls.append(1)
        resource = Closable()
        self.assertEqual(close_calls, [])
        with closing(resource) as bound:
            self.assertEqual(resource, bound)
        self.assertEqual(close_calls, [1])

    def test_closing_error(self):
        # close() must also be called when the block raises, and the
        # exception must still propagate.
        close_calls = []
        class Closable:
            def close(self):
                close_calls.append(1)
        resource = Closable()
        self.assertEqual(close_calls, [])
        try:
            with closing(resource) as bound:
                self.assertEqual(resource, bound)
                1/0
        except ZeroDivisionError:
            self.assertEqual(close_calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
|
255 |
|
class FileContextTestCase(unittest.TestCase):
    """Tests that file objects close themselves when used in a with statement."""

    def testWithOpen(self):
        # mkstemp() creates the file itself, avoiding the window in which
        # another process could claim a name returned by tempfile.mktemp().
        # Only the path is needed, so the descriptor is closed right away.
        fd, tfn = tempfile.mkstemp()
        os.close(fd)
        try:
            # f is preset to None so it is always bound, even if open() fails.
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                # The file must be closed even though the body raised.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            # Best-effort cleanup; a failure to remove must not mask the
            # outcome of the test itself.
            try:
                os.remove(tfn)
            except os.error:
                pass
|
281 |
|
class LockContextTestCase(unittest.TestCase):
    """Tests that threading's synchronization objects work in a with statement."""

    def boilerPlate(self, lock, locked):
        # Shared check used by every test below.
        #   lock   -- the object to use in the with statement
        #   locked -- zero-argument callable, true while lock is held
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            # The lock must have been released although the body raised.
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    @staticmethod
    def _semaphoreProbe(sem):
        # Build a "locked" callable for a semaphore: try a non-blocking
        # acquire; success means it was free (so give it straight back),
        # failure means it is currently held.
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        self.boilerPlate(lock, lock._is_owned)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphoreProbe(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphoreProbe(lock))
|
331 |
|
# This is needed to make the test actually run under regrtest.py!
def test_main():
    """Run every test case defined in this module through test_support."""
    test_support.run_unittest(__name__)
|
335 |
|
# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    test_main()