|
1 """Unit tests for contextlib.py, and other context managers.""" |
|
2 |
|
3 from __future__ import with_statement |
|
4 |
|
5 import sys |
|
6 import os |
|
7 import decimal |
|
8 import tempfile |
|
9 import unittest |
|
10 import threading |
|
11 from contextlib import * # Tests __all__ |
|
12 from test.test_support import run_suite |
|
13 |
|
14 class ContextManagerTestCase(unittest.TestCase): |
|
15 |
|
16 def test_contextmanager_plain(self): |
|
17 state = [] |
|
18 @contextmanager |
|
19 def woohoo(): |
|
20 state.append(1) |
|
21 yield 42 |
|
22 state.append(999) |
|
23 with woohoo() as x: |
|
24 self.assertEqual(state, [1]) |
|
25 self.assertEqual(x, 42) |
|
26 state.append(x) |
|
27 self.assertEqual(state, [1, 42, 999]) |
|
28 |
|
29 def test_contextmanager_finally(self): |
|
30 state = [] |
|
31 @contextmanager |
|
32 def woohoo(): |
|
33 state.append(1) |
|
34 try: |
|
35 yield 42 |
|
36 finally: |
|
37 state.append(999) |
|
38 try: |
|
39 with woohoo() as x: |
|
40 self.assertEqual(state, [1]) |
|
41 self.assertEqual(x, 42) |
|
42 state.append(x) |
|
43 raise ZeroDivisionError() |
|
44 except ZeroDivisionError: |
|
45 pass |
|
46 else: |
|
47 self.fail("Expected ZeroDivisionError") |
|
48 self.assertEqual(state, [1, 42, 999]) |
|
49 |
|
50 def test_contextmanager_no_reraise(self): |
|
51 @contextmanager |
|
52 def whee(): |
|
53 yield |
|
54 ctx = whee() |
|
55 ctx.__enter__() |
|
56 # Calling __exit__ should not result in an exception |
|
57 self.failIf(ctx.__exit__(TypeError, TypeError("foo"), None)) |
|
58 |
|
59 def test_contextmanager_trap_yield_after_throw(self): |
|
60 @contextmanager |
|
61 def whoo(): |
|
62 try: |
|
63 yield |
|
64 except: |
|
65 yield |
|
66 ctx = whoo() |
|
67 ctx.__enter__() |
|
68 self.assertRaises( |
|
69 RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None |
|
70 ) |
|
71 |
|
72 def test_contextmanager_except(self): |
|
73 state = [] |
|
74 @contextmanager |
|
75 def woohoo(): |
|
76 state.append(1) |
|
77 try: |
|
78 yield 42 |
|
79 except ZeroDivisionError, e: |
|
80 state.append(e.args[0]) |
|
81 self.assertEqual(state, [1, 42, 999]) |
|
82 with woohoo() as x: |
|
83 self.assertEqual(state, [1]) |
|
84 self.assertEqual(x, 42) |
|
85 state.append(x) |
|
86 raise ZeroDivisionError(999) |
|
87 self.assertEqual(state, [1, 42, 999]) |
|
88 |
|
89 def test_contextmanager_attribs(self): |
|
90 def attribs(**kw): |
|
91 def decorate(func): |
|
92 for k,v in kw.items(): |
|
93 setattr(func,k,v) |
|
94 return func |
|
95 return decorate |
|
96 @contextmanager |
|
97 @attribs(foo='bar') |
|
98 def baz(spam): |
|
99 """Whee!""" |
|
100 self.assertEqual(baz.__name__,'baz') |
|
101 self.assertEqual(baz.foo, 'bar') |
|
102 self.assertEqual(baz.__doc__, "Whee!") |
|
103 |
|
104 class NestedTestCase(unittest.TestCase): |
|
105 |
|
106 # XXX This needs more work |
|
107 |
|
108 def test_nested(self): |
|
109 @contextmanager |
|
110 def a(): |
|
111 yield 1 |
|
112 @contextmanager |
|
113 def b(): |
|
114 yield 2 |
|
115 @contextmanager |
|
116 def c(): |
|
117 yield 3 |
|
118 with nested(a(), b(), c()) as (x, y, z): |
|
119 self.assertEqual(x, 1) |
|
120 self.assertEqual(y, 2) |
|
121 self.assertEqual(z, 3) |
|
122 |
|
123 def test_nested_cleanup(self): |
|
124 state = [] |
|
125 @contextmanager |
|
126 def a(): |
|
127 state.append(1) |
|
128 try: |
|
129 yield 2 |
|
130 finally: |
|
131 state.append(3) |
|
132 @contextmanager |
|
133 def b(): |
|
134 state.append(4) |
|
135 try: |
|
136 yield 5 |
|
137 finally: |
|
138 state.append(6) |
|
139 try: |
|
140 with nested(a(), b()) as (x, y): |
|
141 state.append(x) |
|
142 state.append(y) |
|
143 1/0 |
|
144 except ZeroDivisionError: |
|
145 self.assertEqual(state, [1, 4, 2, 5, 6, 3]) |
|
146 else: |
|
147 self.fail("Didn't raise ZeroDivisionError") |
|
148 |
|
149 def test_nested_right_exception(self): |
|
150 state = [] |
|
151 @contextmanager |
|
152 def a(): |
|
153 yield 1 |
|
154 class b(object): |
|
155 def __enter__(self): |
|
156 return 2 |
|
157 def __exit__(self, *exc_info): |
|
158 try: |
|
159 raise Exception() |
|
160 except: |
|
161 pass |
|
162 try: |
|
163 with nested(a(), b()) as (x, y): |
|
164 1/0 |
|
165 except ZeroDivisionError: |
|
166 self.assertEqual((x, y), (1, 2)) |
|
167 except Exception: |
|
168 self.fail("Reraised wrong exception") |
|
169 else: |
|
170 self.fail("Didn't raise ZeroDivisionError") |
|
171 |
|
172 def test_nested_b_swallows(self): |
|
173 @contextmanager |
|
174 def a(): |
|
175 yield |
|
176 @contextmanager |
|
177 def b(): |
|
178 try: |
|
179 yield |
|
180 except: |
|
181 # Swallow the exception |
|
182 pass |
|
183 try: |
|
184 with nested(a(), b()): |
|
185 1/0 |
|
186 except ZeroDivisionError: |
|
187 self.fail("Didn't swallow ZeroDivisionError") |
|
188 |
|
189 def test_nested_break(self): |
|
190 @contextmanager |
|
191 def a(): |
|
192 yield |
|
193 state = 0 |
|
194 while True: |
|
195 state += 1 |
|
196 with nested(a(), a()): |
|
197 break |
|
198 state += 10 |
|
199 self.assertEqual(state, 1) |
|
200 |
|
201 def test_nested_continue(self): |
|
202 @contextmanager |
|
203 def a(): |
|
204 yield |
|
205 state = 0 |
|
206 while state < 3: |
|
207 state += 1 |
|
208 with nested(a(), a()): |
|
209 continue |
|
210 state += 10 |
|
211 self.assertEqual(state, 3) |
|
212 |
|
213 def test_nested_return(self): |
|
214 @contextmanager |
|
215 def a(): |
|
216 try: |
|
217 yield |
|
218 except: |
|
219 pass |
|
220 def foo(): |
|
221 with nested(a(), a()): |
|
222 return 1 |
|
223 return 10 |
|
224 self.assertEqual(foo(), 1) |
|
225 |
|
226 class ClosingTestCase(unittest.TestCase): |
|
227 |
|
228 # XXX This needs more work |
|
229 |
|
230 def test_closing(self): |
|
231 state = [] |
|
232 class C: |
|
233 def close(self): |
|
234 state.append(1) |
|
235 x = C() |
|
236 self.assertEqual(state, []) |
|
237 with closing(x) as y: |
|
238 self.assertEqual(x, y) |
|
239 self.assertEqual(state, [1]) |
|
240 |
|
241 def test_closing_error(self): |
|
242 state = [] |
|
243 class C: |
|
244 def close(self): |
|
245 state.append(1) |
|
246 x = C() |
|
247 self.assertEqual(state, []) |
|
248 try: |
|
249 with closing(x) as y: |
|
250 self.assertEqual(x, y) |
|
251 1/0 |
|
252 except ZeroDivisionError: |
|
253 self.assertEqual(state, [1]) |
|
254 else: |
|
255 self.fail("Didn't raise ZeroDivisionError") |
|
256 |
|
257 class FileContextTestCase(unittest.TestCase): |
|
258 |
|
259 def testWithOpen(self): |
|
260 tfn = tempfile.mktemp() |
|
261 try: |
|
262 f = None |
|
263 with open(tfn, "w") as f: |
|
264 self.failIf(f.closed) |
|
265 f.write("Booh\n") |
|
266 self.failUnless(f.closed) |
|
267 f = None |
|
268 try: |
|
269 with open(tfn, "r") as f: |
|
270 self.failIf(f.closed) |
|
271 self.assertEqual(f.read(), "Booh\n") |
|
272 1/0 |
|
273 except ZeroDivisionError: |
|
274 self.failUnless(f.closed) |
|
275 else: |
|
276 self.fail("Didn't raise ZeroDivisionError") |
|
277 finally: |
|
278 try: |
|
279 os.remove(tfn) |
|
280 except os.error: |
|
281 pass |
|
282 |
|
283 class LockContextTestCase(unittest.TestCase): |
|
284 |
|
285 def boilerPlate(self, lock, locked): |
|
286 self.failIf(locked()) |
|
287 with lock: |
|
288 self.failUnless(locked()) |
|
289 self.failIf(locked()) |
|
290 try: |
|
291 with lock: |
|
292 self.failUnless(locked()) |
|
293 1/0 |
|
294 except ZeroDivisionError: |
|
295 self.failIf(locked()) |
|
296 else: |
|
297 self.fail("Didn't raise ZeroDivisionError") |
|
298 |
|
299 def testWithLock(self): |
|
300 lock = threading.Lock() |
|
301 self.boilerPlate(lock, lock.locked) |
|
302 |
|
303 def testWithRLock(self): |
|
304 lock = threading.RLock() |
|
305 self.boilerPlate(lock, lock._is_owned) |
|
306 |
|
307 def testWithCondition(self): |
|
308 lock = threading.Condition() |
|
309 def locked(): |
|
310 return lock._is_owned() |
|
311 self.boilerPlate(lock, locked) |
|
312 |
|
313 def testWithSemaphore(self): |
|
314 lock = threading.Semaphore() |
|
315 def locked(): |
|
316 if lock.acquire(False): |
|
317 lock.release() |
|
318 return False |
|
319 else: |
|
320 return True |
|
321 self.boilerPlate(lock, locked) |
|
322 |
|
323 def testWithBoundedSemaphore(self): |
|
324 lock = threading.BoundedSemaphore() |
|
325 def locked(): |
|
326 if lock.acquire(False): |
|
327 lock.release() |
|
328 return False |
|
329 else: |
|
330 return True |
|
331 self.boilerPlate(lock, locked) |
|
332 |
|
333 # This is needed to make the test actually run under regrtest.py! |
|
334 def test_main(): |
|
335 run_suite( |
|
336 unittest.defaultTestLoader.loadTestsFromModule(sys.modules[__name__]) |
|
337 ) |
|
338 |
|
339 if __name__ == "__main__": |
|
340 test_main() |